Merge pull request #110 from logspace-ai/memory_fixes
This commit is contained in:
commit
550a1706cf
7 changed files with 91 additions and 50 deletions
6
Makefile
6
Makefile
|
|
@ -41,13 +41,13 @@ build:
|
|||
|
||||
dev:
|
||||
make install_frontend
|
||||
ifeq ($(build),1)
|
||||
ifeq ($(build),1)
|
||||
@echo 'Running docker compose up with build'
|
||||
docker compose up $(if $(debug),-f docker-compose.debug.yml) --build
|
||||
else
|
||||
else
|
||||
@echo 'Running docker compose up without build'
|
||||
docker compose up $(if $(debug),-f docker-compose.debug.yml)
|
||||
endif
|
||||
endif
|
||||
|
||||
publish:
|
||||
make build
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@ toolkits:
|
|||
memories:
|
||||
- ConversationBufferMemory
|
||||
- ConversationSummaryMemory
|
||||
- ConversationKGMemory
|
||||
|
||||
embeddings: []
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langflow.graph.nodes import (
|
|||
ChainNode,
|
||||
FileToolNode,
|
||||
LLMNode,
|
||||
MemoryNode,
|
||||
PromptNode,
|
||||
ToolkitNode,
|
||||
ToolNode,
|
||||
|
|
@ -20,6 +21,7 @@ from langflow.interface.tools.base import tool_creator
|
|||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.interface.tools.util import get_tools_dict
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.memories.base import memory_creator
|
||||
from langflow.utils import payload
|
||||
|
||||
|
||||
|
|
@ -65,6 +67,8 @@ class Graph:
|
|||
def build(self) -> List[Node]:
|
||||
# Get root node
|
||||
root_node = payload.get_root_node(self)
|
||||
if root_node is None:
|
||||
raise ValueError("No root node found")
|
||||
return root_node.build()
|
||||
|
||||
def get_node_neighbors(self, node: Node) -> Dict[Node, int]:
|
||||
|
|
@ -130,6 +134,11 @@ class Graph:
|
|||
or node_lc_type in llm_creator.to_list()
|
||||
):
|
||||
nodes.append(LLMNode(node))
|
||||
elif (
|
||||
node_type in memory_creator.to_list()
|
||||
or node_lc_type in memory_creator.to_list()
|
||||
):
|
||||
nodes.append(MemoryNode(node))
|
||||
else:
|
||||
nodes.append(Node(node))
|
||||
return nodes
|
||||
|
|
|
|||
|
|
@ -144,3 +144,13 @@ class WrapperNode(Node):
|
|||
self.params["headers"] = eval(self.params["headers"])
|
||||
self._build()
|
||||
return deepcopy(self._built_object)
|
||||
|
||||
|
||||
class MemoryNode(Node):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="memory")
|
||||
|
||||
def build(self, force: bool = False) -> Any:
|
||||
if not self._built or force:
|
||||
self._build()
|
||||
return deepcopy(self._built_object)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,14 @@ def load_file(file_name, file_content, accepted_types) -> Any:
|
|||
return json.loads(decoded_string)
|
||||
elif suffix in ["yaml", "yml"]:
|
||||
# Return the yaml content
|
||||
return yaml.safe_load(decoded_string)
|
||||
loaded_yaml = yaml.load(decoded_string, Loader=yaml.FullLoader)
|
||||
try:
|
||||
from langchain.agents.agent_toolkits.openapi.spec import reduce_openapi_spec # type: ignore
|
||||
|
||||
return reduce_openapi_spec(loaded_yaml)
|
||||
except ImportError:
|
||||
return loaded_yaml
|
||||
|
||||
elif suffix == "csv":
|
||||
# Load the csv content
|
||||
csv_reader = csv.DictReader(io.StringIO(decoded_string))
|
||||
|
|
|
|||
|
|
@ -29,6 +29,8 @@ def import_module(module_path: str) -> Any:
|
|||
|
||||
def import_by_type(_type: str, name: str) -> Any:
|
||||
"""Import class by type and name"""
|
||||
if _type is None:
|
||||
raise ValueError(f"Type cannot be None. Check if {name} is in the config file.")
|
||||
func_dict = {
|
||||
"agents": import_agent,
|
||||
"prompts": import_prompt,
|
||||
|
|
|
|||
|
|
@ -72,41 +72,56 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
||||
def fix_memory_inputs(langchain_object):
|
||||
def get_memory_key(langchain_object):
|
||||
"""
|
||||
Fix memory inputs by replacing the memory key with the input key.
|
||||
Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.
|
||||
It then checks if the key exists in a dictionary of known memory keys and returns the corresponding key,
|
||||
or None if the current key is not recognized.
|
||||
"""
|
||||
# Possible memory keys
|
||||
# "chat_history", "history"
|
||||
# if memory_key is "chat_history" and input_keys has "history"
|
||||
# we need to replace "chat_history" with "history"
|
||||
mem_key_dict = {
|
||||
"chat_history": "history",
|
||||
"history": "chat_history",
|
||||
}
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
possible_new_mem_key = mem_key_dict.get(memory_key)
|
||||
if possible_new_mem_key is not None:
|
||||
# get input_key
|
||||
input_key = [
|
||||
key
|
||||
for key in langchain_object.input_keys
|
||||
if key not in [memory_key, possible_new_mem_key]
|
||||
][0]
|
||||
return mem_key_dict.get(memory_key)
|
||||
|
||||
# get output_key
|
||||
output_key = [
|
||||
key
|
||||
for key in langchain_object.output_keys
|
||||
if key not in [memory_key, possible_new_mem_key]
|
||||
][0]
|
||||
|
||||
# set input_key and output_key in memory
|
||||
langchain_object.memory.input_key = input_key
|
||||
langchain_object.memory.output_key = output_key
|
||||
for input_key in langchain_object.input_keys:
|
||||
if input_key == possible_new_mem_key:
|
||||
langchain_object.memory.memory_key = possible_new_mem_key
|
||||
def update_memory_keys(langchain_object, possible_new_mem_key):
|
||||
"""
|
||||
Given a LangChain object and a possible new memory key, this function updates the input and output keys in the
|
||||
object's memory attribute to exclude the current memory key and the possible new key. It then sets the memory key
|
||||
to the possible new key.
|
||||
"""
|
||||
input_key = [
|
||||
key
|
||||
for key in langchain_object.input_keys
|
||||
if key not in [langchain_object.memory.memory_key, possible_new_mem_key]
|
||||
][0]
|
||||
|
||||
output_key = [
|
||||
key
|
||||
for key in langchain_object.output_keys
|
||||
if key not in [langchain_object.memory.memory_key, possible_new_mem_key]
|
||||
][0]
|
||||
|
||||
langchain_object.memory.input_key = input_key
|
||||
langchain_object.memory.output_key = output_key
|
||||
langchain_object.memory.memory_key = possible_new_mem_key
|
||||
|
||||
|
||||
def fix_memory_inputs(langchain_object):
|
||||
"""
|
||||
Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the
|
||||
object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
|
||||
get_memory_key function and updates the memory keys using the update_memory_keys function.
|
||||
"""
|
||||
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
|
||||
if langchain_object.memory.memory_key in langchain_object.input_variables:
|
||||
return
|
||||
|
||||
possible_new_mem_key = get_memory_key(langchain_object)
|
||||
if possible_new_mem_key is not None:
|
||||
update_memory_keys(langchain_object, possible_new_mem_key)
|
||||
|
||||
|
||||
def get_result_and_thought_using_graph(langchain_object, message: str):
|
||||
|
|
@ -114,27 +129,24 @@ def get_result_and_thought_using_graph(langchain_object, message: str):
|
|||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
chat_input = None
|
||||
memory_key = ""
|
||||
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
|
||||
for key in langchain_object.input_keys:
|
||||
if key not in [memory_key, "chat_history"]:
|
||||
chat_input = {key: message}
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
chat_input = None
|
||||
memory_key = ""
|
||||
if (
|
||||
hasattr(langchain_object, "memory")
|
||||
and langchain_object.memory is not None
|
||||
):
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
|
||||
for key in langchain_object.input_keys:
|
||||
if key not in [memory_key, "chat_history"]:
|
||||
chat_input = {key: message}
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
|
||||
try:
|
||||
output = langchain_object(chat_input)
|
||||
except ValueError as exc:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue