Merge pull request #110 from logspace-ai/memory_fixes

Gabriel Luiz Freitas Almeida 2023-04-05 11:52:38 -03:00 committed by GitHub
commit 550a1706cf
7 changed files with 91 additions and 50 deletions


@@ -41,13 +41,13 @@ build:
dev:
make install_frontend
ifeq ($(build),1)
@echo 'Running docker compose up with build'
docker compose up $(if $(debug),-f docker-compose.debug.yml) --build
else
@echo 'Running docker compose up without build'
docker compose up $(if $(debug),-f docker-compose.debug.yml)
endif
publish:
make build
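
For reference, the dev target is driven entirely by command-line variables: make dev build=1 forces an image rebuild before starting, and adding debug=1 (for example, make dev build=1 debug=1) is intended to layer docker-compose.debug.yml into the invocation through the $(if $(debug),...) expansion.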


@@ -46,6 +46,7 @@ toolkits:
memories:
  - ConversationBufferMemory
  - ConversationSummaryMemory
+  - ConversationKGMemory
embeddings: []
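
These three entries correspond to LangChain memory classes. A minimal sketch of what each needs at construction time, assuming a langchain release contemporary with this commit (ConversationSummaryMemory and ConversationKGMemory require an LLM; ConversationBufferMemory does not; the OpenAI model and API key are assumptions of the sketch):

from langchain.llms import OpenAI
from langchain.memory import (
    ConversationBufferMemory,
    ConversationKGMemory,
    ConversationSummaryMemory,
)

llm = OpenAI(temperature=0)  # assumes OPENAI_API_KEY is set in the environment

buffer_memory = ConversationBufferMemory(memory_key="chat_history")             # stores raw turns
summary_memory = ConversationSummaryMemory(llm=llm, memory_key="chat_history")  # keeps a rolling summary
kg_memory = ConversationKGMemory(llm=llm, memory_key="chat_history")            # extracts a knowledge graph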


@@ -6,6 +6,7 @@ from langflow.graph.nodes import (
    ChainNode,
    FileToolNode,
    LLMNode,
+    MemoryNode,
    PromptNode,
    ToolkitNode,
    ToolNode,
@@ -20,6 +21,7 @@ from langflow.interface.tools.base import tool_creator
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.interface.tools.util import get_tools_dict
from langflow.interface.wrappers.base import wrapper_creator
+from langflow.interface.memories.base import memory_creator
from langflow.utils import payload
@@ -65,6 +67,8 @@ class Graph:
    def build(self) -> List[Node]:
        # Get root node
        root_node = payload.get_root_node(self)
+        if root_node is None:
+            raise ValueError("No root node found")
        return root_node.build()

    def get_node_neighbors(self, node: Node) -> Dict[Node, int]:
@@ -130,6 +134,11 @@ class Graph:
                or node_lc_type in llm_creator.to_list()
            ):
                nodes.append(LLMNode(node))
+            elif (
+                node_type in memory_creator.to_list()
+                or node_lc_type in memory_creator.to_list()
+            ):
+                nodes.append(MemoryNode(node))
            else:
                nodes.append(Node(node))
        return nodes
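
The elif chain added here is a creator-registry dispatch: each creator singleton (llm_creator, memory_creator, and so on) exposes to_list(), the type names it can build, and the first list containing the node's type decides the Node subclass. A toy, self-contained sketch of the same pattern (stand-in names, not the actual langflow API):

class _Creator:
    """Stand-in for langflow's creator singletons."""
    def __init__(self, names):
        self._names = names

    def to_list(self):
        return self._names

llm_creator = _Creator(["OpenAI"])  # hypothetical registry contents
memory_creator = _Creator(["ConversationBufferMemory", "ConversationKGMemory"])

def node_class_name(node_type: str) -> str:
    # First registry that knows the type wins; otherwise fall back to Node.
    for creator, cls_name in [(llm_creator, "LLMNode"), (memory_creator, "MemoryNode")]:
        if node_type in creator.to_list():
            return cls_name
    return "Node"

print(node_class_name("ConversationKGMemory"))  # MemoryNode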


@@ -144,3 +144,13 @@ class WrapperNode(Node):
            self.params["headers"] = eval(self.params["headers"])
        self._build()
        return deepcopy(self._built_object)
+
+
+class MemoryNode(Node):
+    def __init__(self, data: Dict):
+        super().__init__(data, base_type="memory")
+
+    def build(self, force: bool = False) -> Any:
+        if not self._built or force:
+            self._build()
+        return deepcopy(self._built_object)
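
MemoryNode reuses the lazy build-and-cache pattern of the other node classes: _build() runs once, and every build() call hands back a deepcopy so callers cannot mutate the cached object. A hedged usage sketch (node_data stands in for the dict langflow passes to the constructor):

node = MemoryNode(node_data)    # node_data: the node's dict from the flow JSON
first = node.build()            # runs _build() and caches the result
second = node.build()           # deep copy of the cache; no rebuild
assert first is not second      # each caller gets an independent copy
fresh = node.build(force=True)  # force=True rebuilds from scratch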


@@ -26,7 +26,14 @@ def load_file(file_name, file_content, accepted_types) -> Any:
        return json.loads(decoded_string)
    elif suffix in ["yaml", "yml"]:
        # Return the yaml content
-        return yaml.safe_load(decoded_string)
+        loaded_yaml = yaml.load(decoded_string, Loader=yaml.FullLoader)
+        try:
+            from langchain.agents.agent_toolkits.openapi.spec import reduce_openapi_spec  # type: ignore
+
+            return reduce_openapi_spec(loaded_yaml)
+        except ImportError:
+            return loaded_yaml
    elif suffix == "csv":
        # Load the csv content
        csv_reader = csv.DictReader(io.StringIO(decoded_string))
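
The new YAML branch tries to treat the uploaded file as an OpenAPI spec: LangChain's reduce_openapi_spec helper (from its OpenAPI agent toolkit) trims a full spec down to the servers, endpoints, and descriptions an agent needs, and the ImportError fallback keeps older langchain versions returning plain YAML. Note that only ImportError is caught, so a YAML file that is not a valid OpenAPI document would presumably raise inside reduce_openapi_spec itself. A rough sketch of the call in isolation (openapi.yaml is a hypothetical spec file):

import yaml
from langchain.agents.agent_toolkits.openapi.spec import reduce_openapi_spec

with open("openapi.yaml") as f:  # hypothetical OpenAPI 3 document
    raw_spec = yaml.load(f, Loader=yaml.FullLoader)

reduced = reduce_openapi_spec(raw_spec)  # keeps servers, descriptions, endpoints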


@@ -29,6 +29,8 @@ def import_module(module_path: str) -> Any:
def import_by_type(_type: str, name: str) -> Any:
    """Import class by type and name"""
+    if _type is None:
+        raise ValueError(f"Type cannot be None. Check if {name} is in the config file.")
    func_dict = {
        "agents": import_agent,
        "prompts": import_prompt,


@@ -72,41 +72,56 @@ def process_graph(data_graph: Dict[str, Any]):
    return {"result": str(result), "thought": thought.strip()}

-def fix_memory_inputs(langchain_object):
+def get_memory_key(langchain_object):
    """
-    Fix memory inputs by replacing the memory key with the input key.
+    Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.
+    It then checks if the key exists in a dictionary of known memory keys and returns the corresponding key,
+    or None if the current key is not recognized.
    """
-    # Possible memory keys
-    # "chat_history", "history"
-    # if memory_key is "chat_history" and input_keys has "history"
-    # we need to replace "chat_history" with "history"
    mem_key_dict = {
        "chat_history": "history",
        "history": "chat_history",
    }
    memory_key = langchain_object.memory.memory_key
-    possible_new_mem_key = mem_key_dict.get(memory_key)
-    if possible_new_mem_key is not None:
-        # get input_key
-        input_key = [
-            key
-            for key in langchain_object.input_keys
-            if key not in [memory_key, possible_new_mem_key]
-        ][0]
+    return mem_key_dict.get(memory_key)
-        # get output_key
-        output_key = [
-            key
-            for key in langchain_object.output_keys
-            if key not in [memory_key, possible_new_mem_key]
-        ][0]
-        # set input_key and output_key in memory
-        langchain_object.memory.input_key = input_key
-        langchain_object.memory.output_key = output_key
-        for input_key in langchain_object.input_keys:
-            if input_key == possible_new_mem_key:
-                langchain_object.memory.memory_key = possible_new_mem_key

+def update_memory_keys(langchain_object, possible_new_mem_key):
+    """
+    Given a LangChain object and a possible new memory key, this function updates the input and output keys in the
+    object's memory attribute to exclude the current memory key and the possible new key. It then sets the memory key
+    to the possible new key.
+    """
+    input_key = [
+        key
+        for key in langchain_object.input_keys
+        if key not in [langchain_object.memory.memory_key, possible_new_mem_key]
+    ][0]
+    output_key = [
+        key
+        for key in langchain_object.output_keys
+        if key not in [langchain_object.memory.memory_key, possible_new_mem_key]
+    ][0]
+    langchain_object.memory.input_key = input_key
+    langchain_object.memory.output_key = output_key
+    langchain_object.memory.memory_key = possible_new_mem_key

+def fix_memory_inputs(langchain_object):
+    """
+    Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the
+    object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
+    get_memory_key function and updates the memory keys using the update_memory_keys function.
+    """
+    if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
+        if langchain_object.memory.memory_key in langchain_object.input_variables:
+            return
+        possible_new_mem_key = get_memory_key(langchain_object)
+        if possible_new_mem_key is not None:
+            update_memory_keys(langchain_object, possible_new_mem_key)

def get_result_and_thought_using_graph(langchain_object, message: str):
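
Concretely, the situation these three helpers repair is a memory object built with one conventional key attached to a chain whose prompt expects the other. A self-contained sketch with stand-in classes (not the real langflow or LangChain objects) that exercises the swap, assuming the three functions above are in scope:

class FakeMemory:
    def __init__(self):
        self.memory_key = "chat_history"  # what the memory was constructed with
        self.input_key = None
        self.output_key = None

class FakeChain:
    def __init__(self):
        self.memory = FakeMemory()
        self.input_keys = ["history", "question"]      # what the prompt expects
        self.output_keys = ["answer"]
        self.input_variables = ["history", "question"]

chain = FakeChain()
fix_memory_inputs(chain)
print(chain.memory.memory_key)  # -> "history" (remapped via mem_key_dict)
print(chain.memory.input_key)   # -> "question" (the non-memory input key)
print(chain.memory.output_key)  # -> "answer"
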
@@ -114,27 +129,24 @@ def get_result_and_thought_using_graph(langchain_object, message: str):
    try:
        if hasattr(langchain_object, "verbose"):
            langchain_object.verbose = True
+        chat_input = None
+        memory_key = ""
+        if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
+            memory_key = langchain_object.memory.memory_key
+
+        for key in langchain_object.input_keys:
+            if key not in [memory_key, "chat_history"]:
+                chat_input = {key: message}
+
+        if hasattr(langchain_object, "return_intermediate_steps"):
+            # https://github.com/hwchase17/langchain/issues/2068
+            # Deactivating until we have a frontend solution
+            # to display intermediate steps
+            langchain_object.return_intermediate_steps = False
+
+        fix_memory_inputs(langchain_object)
+
        with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
-            chat_input = None
-            memory_key = ""
-            if (
-                hasattr(langchain_object, "memory")
-                and langchain_object.memory is not None
-            ):
-                memory_key = langchain_object.memory.memory_key
-
-            for key in langchain_object.input_keys:
-                if key not in [memory_key, "chat_history"]:
-                    chat_input = {key: message}
-
-            if hasattr(langchain_object, "return_intermediate_steps"):
-                # https://github.com/hwchase17/langchain/issues/2068
-                # Deactivating until we have a frontend solution
-                # to display intermediate steps
-                langchain_object.return_intermediate_steps = False
-
-            fix_memory_inputs(langchain_object)
-
            try:
                output = langchain_object(chat_input)
            except ValueError as exc: