Refactor MemoryComponent class and add ZepMessageReaderComponent (#1771)

* Add BaseMemoryComponent class to langflow.base.memory.memory.py (#1750)

* Add BaseMemoryComponent class to langflow.base.memory.memory.py

* Update MemoryComponent class in langflow.components.helpers.MemoryComponent.py to inherit from BaseMemoryComponent

*  (ZepMessageReader.py): Add ZepMessageReaderComponent to retrieve chat messages from Zep
📝 (ZepMessageWriter.py): Add ZepMessageWriterComponent to add messages to ZepChatMessageHistory

📝 (Langflow Memory Conversation.json): Refactor MemoryComponent class to inherit from BaseMemoryComponent for better code organization and reusability. Move get_messages method to the class level and validate kwargs for correct keys before processing.

* Update WeaviateSearch.py to include index_name parameter in build method

Update ZepMessageWriter.py to include metadata parameter in __init__ method

Update ZepMessageReader.py to include cast function for memory parameter

Update schema.py to include cast function for metadata parameter

Update process.py to include tweaks_dict variable and use it in apply_tweaks method

Update Weaviate.py to include index_name parameter in build method and raise ValueError if index_name is not provided

* Update process.py to include tweaks_dict variable and use it in apply_tweaks method

*  (ZepMessageReader.py): Update ZepMessageReaderComponent build method to handle optional url and api_key parameters and improve error handling for zep-python package import
📝 (ZepMessageWriter.py): Refactor ZepMessageWriterComponent to use 'text' instead of 'message' for consistency and update add_message method to reflect this change. Add 'input_value' configuration option for specifying the record to write to Zep. Update build_config method to reflect changes in input parameters. Update add_message method to use 'text' parameter instead of 'message'. Update build method to handle optional url and api_key parameters and improve error handling for zep-python package import.

* Update zep-python package to version 2.0.0rc5

* 📝 (memory.py): update parameter name from 'message' to 'text' for better clarity and consistency
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-04-23 23:06:54 -03:00 committed by GitHub
commit 19e46de04c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 345 additions and 47 deletions

17
poetry.lock generated
View file

@ -10169,6 +10169,21 @@ files = [
idna = ">=2.0"
multidict = ">=4.0"
[[package]]
name = "zep-python"
version = "2.0.0rc5"
description = "Long-Term Memory for AI Assistants. This is the Python client for the Zep service."
optional = false
python-versions = "<4,>=3.9.0"
files = [
{file = "zep_python-2.0.0rc5-py3-none-any.whl", hash = "sha256:8b1b5c22c9e1ef439c9ef3d785347abf89b1243c7149e32025dd065cc022af40"},
{file = "zep_python-2.0.0rc5.tar.gz", hash = "sha256:e6ced8089760374dead948d6b4b88fceb09a356bf9a7fe182b4ceb6e828f0bb1"},
]
[package.dependencies]
httpx = ">=0.24.0,<0.29.0"
pydantic = ">=2.0.0"
[[package]]
name = "zipp"
version = "3.18.1"
@ -10262,4 +10277,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.10,<3.12"
content-hash = "9dd152b30031767c522c77e2ad5fc4597a8d1590b13968af143bd382e056b2a1"
content-hash = "bec34397b534f882551511558c76785c7cd67e6a1eefc1d45f6a64d97175d886"

View file

@ -81,6 +81,7 @@ chromadb = "^0.4.24"
langchain-anthropic = "^0.1.6"
langchain-astradb = "^0.1.0"
langchain-openai = "^0.1.1"
zep-python = {version = "^2.0.0rc5", allow-prereleases = true}
[tool.poetry.group.dev.dependencies]
types-redis = "^4.6.0.5"

View file

@ -130,7 +130,7 @@ async def simplified_run_flow(
graph_data = flow.data
graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream)
graph = Graph.from_payload(graph_data, flow_id=flow_id, user_id=api_key_user.id)
graph = Graph.from_payload(graph_data, flow_id=flow_id, user_id=str(api_key_user.id))
inputs = [
InputValueRequest(components=[], input_value=input_request.input_value, type=input_request.input_type)
]

View file

@ -0,0 +1,51 @@
from typing import Optional
from langflow.field_typing import Text
from langflow.helpers.record import records_to_text
from langflow.interface.custom.custom_component import CustomComponent
from langflow.schema.schema import Record
class BaseMemoryComponent(CustomComponent):
display_name = "Chat Memory"
description = "Retrieves stored chat messages given a specific Session ID."
beta: bool = True
icon = "history"
def build_config(self):
return {
"sender": {
"options": ["Machine", "User", "Machine and User"],
"display_name": "Sender Type",
},
"sender_name": {"display_name": "Sender Name", "advanced": True},
"n_messages": {
"display_name": "Number of Messages",
"info": "Number of messages to retrieve.",
},
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"order": {
"options": ["Ascending", "Descending"],
"display_name": "Order",
"info": "Order of the messages.",
"advanced": True,
},
"record_template": {
"display_name": "Record Template",
"multiline": True,
"info": "Template to convert Record to Text. If left empty, it will be dynamically set to the Record's text key.",
"advanced": True,
},
}
def get_messages(self, **kwargs) -> list[Record]:
raise NotImplementedError
def add_message(
self, sender: str, sender_name: str, text: str, session_id: str, metadata: Optional[dict] = None, **kwargs
):
raise NotImplementedError

View file

@ -1,12 +1,13 @@
from typing import Optional
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.field_typing import Text
from langflow.helpers.record import records_to_text
from langflow.interface.custom.custom_component import CustomComponent
from langflow.memory import get_messages
from langflow.schema.schema import Record
class MemoryComponent(CustomComponent):
class MemoryComponent(BaseMemoryComponent):
display_name = "Chat Memory"
description = "Retrieves stored chat messages given a specific Session ID."
beta: bool = True
@ -42,6 +43,24 @@ class MemoryComponent(CustomComponent):
},
}
def get_messages(self, **kwargs) -> list[Record]:
# Validate kwargs by checking if it contains the correct keys
if "sender" not in kwargs:
kwargs["sender"] = None
if "sender_name" not in kwargs:
kwargs["sender_name"] = None
if "session_id" not in kwargs:
kwargs["session_id"] = None
if "n_messages" not in kwargs:
kwargs["n_messages"] = 5
if "order" not in kwargs:
kwargs["order"] = "Descending"
kwargs["order"] = "DESC" if kwargs["order"] == "Descending" else "ASC"
if kwargs["sender"] == "Machine and User":
kwargs["sender"] = None
return get_messages(**kwargs)
def build(
self,
sender: Optional[str] = "Machine and User",
@ -51,14 +70,11 @@ class MemoryComponent(CustomComponent):
order: Optional[str] = "Descending",
record_template: Optional[str] = "{sender_name}: {text}",
) -> Text:
order = "DESC" if order == "Descending" else "ASC"
if sender == "Machine and User":
sender = None
messages = get_messages(
messages = self.get_messages(
sender=sender,
sender_name=sender_name,
session_id=session_id,
limit=n_messages,
n_messages=n_messages,
order=order,
)
messages_str = records_to_text(template=record_template or "", records=messages)

View file

@ -0,0 +1,137 @@
from typing import Optional, cast
from langchain_community.chat_message_histories.zep import SearchScope, SearchType, ZepChatMessageHistory
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.field_typing import Text
from langflow.schema.schema import Record
class ZepMessageReaderComponent(BaseMemoryComponent):
display_name = "Zep Message Reader"
description = "Retrieves stored chat messages from Zep."
def build_config(self):
return {
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"url": {
"display_name": "Zep URL",
"info": "URL of the Zep instance.",
"input_types": ["Text"],
},
"api_key": {
"display_name": "Zep API Key",
"info": "API Key for the Zep instance.",
"password": True,
},
"query": {
"display_name": "Query",
"info": "Query to search for in the chat history.",
},
"metadata": {
"display_name": "Metadata",
"info": "Optional metadata to attach to the message.",
"advanced": True,
},
"search_scope": {
"options": ["Messages", "Summary"],
"display_name": "Search Scope",
"info": "Scope of the search.",
"advanced": True,
},
"search_type": {
"options": ["Similarity", "MMR"],
"display_name": "Search Type",
"info": "Type of search.",
"advanced": True,
},
"limit": {
"display_name": "Limit",
"info": "Limit of search results.",
"advanced": True,
},
}
def get_messages(self, **kwargs) -> list[Record]:
"""
Retrieves messages from the ZepChatMessageHistory memory.
If a query is provided, the search method is used to search for messages in the memory, otherwise all messages are returned.
Args:
memory (ZepChatMessageHistory): The ZepChatMessageHistory instance to retrieve messages from.
query (str, optional): The query string to search for messages. Defaults to None.
metadata (dict, optional): Additional metadata to filter the search results. Defaults to None.
search_scope (str, optional): The scope of the search. Can be 'messages' or 'summary'. Defaults to 'messages'.
search_type (str, optional): The type of search. Can be 'similarity' or 'exact'. Defaults to 'similarity'.
limit (int, optional): The maximum number of search results to return. Defaults to None.
Returns:
list[Record]: A list of Record objects representing the search results.
"""
memory: ZepChatMessageHistory = cast(ZepChatMessageHistory, kwargs.get("memory"))
if not memory:
raise ValueError("ZepChatMessageHistory instance is required.")
query = kwargs.get("query")
search_scope = kwargs.get("search_scope", SearchScope.messages).lower()
search_type = kwargs.get("search_type", SearchType.similarity).lower()
limit = kwargs.get("limit")
if query:
memory_search_results = memory.search(
query,
search_scope=search_scope,
search_type=search_type,
limit=limit,
)
# Get the messages from the search results if the search scope is messages
result_dicts = []
for result in memory_search_results:
result_dict = {}
if search_scope == SearchScope.messages:
result_dict["text"] = result.message
else:
result_dict["text"] = result.summary
result_dict["metadata"] = result.metadata
result_dict["score"] = result.score
result_dicts.append(result_dict)
results = [Record(data=result_dict) for result_dict in result_dicts]
else:
messages = memory.messages
results = [Record.from_lc_message(message) for message in messages]
return results
def build(
self,
session_id: Text,
url: Optional[Text] = None,
api_key: Optional[Text] = None,
query: Optional[Text] = None,
search_scope: SearchScope = SearchScope.messages,
search_type: SearchType = SearchType.similarity,
limit: Optional[int] = None,
) -> list[Record]:
try:
from zep_python import ZepClient
from zep_python.langchain import ZepChatMessageHistory
except ImportError:
raise ImportError(
"Could not import zep-python package. " "Please install it with `pip install zep-python`."
)
if url == "":
url = None
zep_client = ZepClient(api_url=url, api_key=api_key)
memory = ZepChatMessageHistory(session_id=session_id, zep_client=zep_client)
records = self.get_messages(
memory=memory,
query=query,
search_scope=search_scope,
search_type=search_type,
limit=limit,
)
self.status = records
return records

View file

@ -0,0 +1,96 @@
from typing import Optional, TYPE_CHECKING
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.field_typing import Text
from langflow.schema.schema import Record
if TYPE_CHECKING:
from zep_python.langchain import ZepChatMessageHistory
class ZepMessageWriterComponent(BaseMemoryComponent):
display_name = "Zep Message Writer"
description = "Writes a message to Zep."
def build_config(self):
return {
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"url": {
"display_name": "Zep URL",
"info": "URL of the Zep instance.",
"input_types": ["Text"],
},
"api_key": {
"display_name": "Zep API Key",
"info": "API Key for the Zep instance.",
"password": True,
},
"limit": {
"display_name": "Limit",
"info": "Limit of search results.",
"advanced": True,
},
"input_value": {
"display_name": "Input Record",
"info": "Record to write to Zep.",
},
}
def add_message(
self, sender: Text, sender_name: Text, text: Text, session_id: Text, metadata: dict | None = None, **kwargs
):
"""
Adds a message to the ZepChatMessageHistory memory.
Args:
sender (Text): The type of the message sender. Valid values are "Machine" or "User".
sender_name (Text): The name of the message sender.
text (Text): The content of the message.
session_id (Text): The session ID associated with the message.
metadata (dict | None, optional): Additional metadata for the message. Defaults to None.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the ZepChatMessageHistory instance is not provided.
"""
memory: ZepChatMessageHistory | None = kwargs.pop("memory", None)
if memory is None:
raise ValueError("ZepChatMessageHistory instance is required.")
if metadata is None:
metadata = {}
metadata["sender_name"] = sender_name
metadata.update(kwargs)
if sender == "Machine":
memory.add_ai_message(text, metadata=metadata)
elif sender == "User":
memory.add_user_message(text, metadata=metadata)
else:
raise ValueError(f"Invalid sender type: {sender}")
def build(
self,
input_value: Record,
session_id: Text,
url: Optional[Text] = None,
api_key: Optional[Text] = None,
) -> Record:
try:
from zep_python import ZepClient
from zep_python.langchain import ZepChatMessageHistory
except ImportError:
raise ImportError(
"Could not import zep-python package. " "Please install it with `pip install zep-python`."
)
if url == "":
url = None
zep_client = ZepClient(api_url=url, api_key=api_key)
memory = ZepChatMessageHistory(session_id=session_id, zep_client=zep_client)
self.add_message(**input_value.data, memory=memory)
self.status = f"Added message to Zep memory for session {session_id}"
return input_value

View file

@ -61,10 +61,10 @@ class WeaviateSearchVectorStore(WeaviateVectorStoreComponent, LCVectorStoreCompo
input_value: Text,
search_type: str,
url: str,
index_name: str,
number_of_results: int = 4,
search_by_text: bool = False,
api_key: Optional[str] = None,
index_name: Optional[str] = None,
text_key: str = "text",
embedding: Optional[Embeddings] = None,
attributes: Optional[list] = None,

View file

@ -4,6 +4,7 @@ import weaviate # type: ignore
from langchain.embeddings.base import Embeddings
from langchain.schema import BaseRetriever
from langchain_community.vectorstores import VectorStore, Weaviate
from langchain_core.documents import Document
from langflow.interface.custom.custom_component import CustomComponent
from langflow.schema.schema import Record
@ -50,9 +51,9 @@ class WeaviateVectorStoreComponent(CustomComponent):
def build(
self,
url: str,
index_name: str,
search_by_text: bool = False,
api_key: Optional[str] = None,
index_name: Optional[str] = None,
text_key: str = "text",
embedding: Optional[Embeddings] = None,
inputs: Optional[Record] = None,
@ -78,11 +79,13 @@ class WeaviateVectorStoreComponent(CustomComponent):
return pascal_case_word
index_name = _to_pascal_case(index_name) if index_name else None
documents = []
if not index_name:
raise ValueError("Index name is required")
documents: list[Document] = []
for _input in inputs or []:
if isinstance(_input, Record):
documents.append(_input.to_lc_document())
else:
elif isinstance(_input, Document):
documents.append(_input)
if documents and embedding is not None:

View file

@ -377,7 +377,7 @@
"list": false,
"show": true,
"multiline": true,
"value": "from typing import Optional\n\nfrom langflow.field_typing import Text\nfrom langflow.helpers.record import records_to_text\nfrom langflow.interface.custom.custom_component import CustomComponent\nfrom langflow.memory import get_messages\n\n\nclass MemoryComponent(CustomComponent):\n display_name = \"Chat Memory\"\n description = \"Retrieves stored chat messages given a specific Session ID.\"\n beta: bool = True\n icon = \"history\"\n\n def build_config(self):\n return {\n \"sender\": {\n \"options\": [\"Machine\", \"User\", \"Machine and User\"],\n \"display_name\": \"Sender Type\",\n },\n \"sender_name\": {\"display_name\": \"Sender Name\", \"advanced\": True},\n \"n_messages\": {\n \"display_name\": \"Number of Messages\",\n \"info\": \"Number of messages to retrieve.\",\n },\n \"session_id\": {\n \"display_name\": \"Session ID\",\n \"info\": \"Session ID of the chat history.\",\n \"input_types\": [\"Text\"],\n },\n \"order\": {\n \"options\": [\"Ascending\", \"Descending\"],\n \"display_name\": \"Order\",\n \"info\": \"Order of the messages.\",\n \"advanced\": True,\n },\n \"record_template\": {\n \"display_name\": \"Record Template\",\n \"multiline\": True,\n \"info\": \"Template to convert Record to Text. If left empty, it will be dynamically set to the Record's text key.\",\n \"advanced\": True,\n },\n }\n\n def build(\n self,\n sender: Optional[str] = \"Machine and User\",\n sender_name: Optional[str] = None,\n session_id: Optional[str] = None,\n n_messages: int = 5,\n order: Optional[str] = \"Descending\",\n record_template: Optional[str] = \"{sender_name}: {text}\",\n ) -> Text:\n order = \"DESC\" if order == \"Descending\" else \"ASC\"\n if sender == \"Machine and User\":\n sender = None\n messages = get_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n limit=n_messages,\n order=order,\n )\n messages_str = records_to_text(template=record_template or \"\", records=messages)\n self.status = messages_str\n return messages_str\n",
"value": "from typing import Optional\n\nfrom langflow.base.memory.memory import BaseMemoryComponent\nfrom langflow.field_typing import Text\nfrom langflow.helpers.record import records_to_text\nfrom langflow.memory import get_messages\nfrom langflow.schema.schema import Record\n\n\nclass MemoryComponent(BaseMemoryComponent):\n display_name = \"Chat Memory\"\n description = \"Retrieves stored chat messages given a specific Session ID.\"\n beta: bool = True\n icon = \"history\"\n\n def build_config(self):\n return {\n \"sender\": {\n \"options\": [\"Machine\", \"User\", \"Machine and User\"],\n \"display_name\": \"Sender Type\",\n },\n \"sender_name\": {\"display_name\": \"Sender Name\", \"advanced\": True},\n \"n_messages\": {\n \"display_name\": \"Number of Messages\",\n \"info\": \"Number of messages to retrieve.\",\n },\n \"session_id\": {\n \"display_name\": \"Session ID\",\n \"info\": \"Session ID of the chat history.\",\n \"input_types\": [\"Text\"],\n },\n \"order\": {\n \"options\": [\"Ascending\", \"Descending\"],\n \"display_name\": \"Order\",\n \"info\": \"Order of the messages.\",\n \"advanced\": True,\n },\n \"record_template\": {\n \"display_name\": \"Record Template\",\n \"multiline\": True,\n \"info\": \"Template to convert Record to Text. If left empty, it will be dynamically set to the Record's text key.\",\n \"advanced\": True,\n },\n }\n\n def get_messages(self, **kwargs) -> list[Record]:\n # Validate kwargs by checking if it contains the correct keys\n if \"sender\" not in kwargs:\n kwargs[\"sender\"] = None\n if \"sender_name\" not in kwargs:\n kwargs[\"sender_name\"] = None\n if \"session_id\" not in kwargs:\n kwargs[\"session_id\"] = None\n if \"n_messages\" not in kwargs:\n kwargs[\"n_messages\"] = 5\n if \"order\" not in kwargs:\n kwargs[\"order\"] = \"Descending\"\n\n kwargs[\"order\"] = \"DESC\" if kwargs[\"order\"] == \"Descending\" else \"ASC\"\n if kwargs[\"sender\"] == \"Machine and User\":\n kwargs[\"sender\"] = None\n return get_messages(**kwargs)\n\n def build(\n self,\n sender: Optional[str] = \"Machine and User\",\n sender_name: Optional[str] = None,\n session_id: Optional[str] = None,\n n_messages: int = 5,\n order: Optional[str] = \"Descending\",\n record_template: Optional[str] = \"{sender_name}: {text}\",\n ) -> Text:\n messages = self.get_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n n_messages=n_messages,\n order=order,\n )\n messages_str = records_to_text(template=record_template or \"\", records=messages)\n self.status = messages_str\n return messages_str\n",
"fileTypes": [],
"file_path": "",
"password": false,

View file

@ -227,34 +227,11 @@ def initialize_qdrant(class_object: Type[Qdrant], params: dict):
return class_object.from_documents(**params)
def initialize_elasticsearch(class_object: Type[ElasticsearchStore], params: dict):
"""Initialize elastic and return the class object"""
if "index_name" not in params:
raise ValueError("Elasticsearch Index must be provided in the params")
if "es_url" not in params:
raise ValueError("Elasticsearch URL must be provided in the params")
if not docs_in_params(params):
existing_index_params = {
"embedding": params.pop("embedding"),
}
if "index_name" in params:
existing_index_params["index_name"] = params.pop("index_name")
if "es_url" in params:
existing_index_params["es_url"] = params.pop("es_url")
return class_object.from_existing_index(**existing_index_params)
# If there are docs in the params, create a new index
if "texts" in params:
params["documents"] = params.pop("texts")
return class_object.from_documents(**params)
vecstore_initializer: Dict[str, Callable[[Type[Any], dict], Any]] = {
"Pinecone": initialize_pinecone,
"Chroma": initialize_chroma,
"Qdrant": initialize_qdrant,
"Weaviate": initialize_weaviate,
"ElasticsearchStore": initialize_elasticsearch,
"FAISS": initialize_faiss,
"SupabaseVectorStore": initialize_supabase,
"MongoDBAtlasVectorSearch": initialize_mongodb,

View file

@ -269,16 +269,19 @@ def process_tweaks(
:return: The modified graph_data dictionary.
:raises ValueError: If the input is not in the expected format.
"""
tweaks_dict = {}
if not isinstance(tweaks, dict):
tweaks = tweaks.model_dump()
if "stream" not in tweaks:
tweaks["stream"] = stream
nodes = validate_input(graph_data, tweaks)
tweaks_dict = tweaks.model_dump()
else:
tweaks_dict = tweaks
if "stream" not in tweaks_dict:
tweaks_dict["stream"] = stream
nodes = validate_input(graph_data, tweaks_dict)
nodes_map = {node.get("id"): node for node in nodes}
nodes_display_name_map = {node.get("data", {}).get("node", {}).get("display_name"): node for node in nodes}
all_nodes_tweaks = {}
for key, value in tweaks.items():
for key, value in tweaks_dict.items():
if isinstance(value, dict):
if node := nodes_map.get(key):
apply_tweaks(node, value)

View file

@ -1,10 +1,9 @@
import copy
from typing import Literal, Optional
from typing import Literal, Optional, cast
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from pydantic import BaseModel, model_validator
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
class Record(BaseModel):
@ -67,8 +66,8 @@ class Record(BaseModel):
Returns:
Record: The converted Record.
"""
data = {"text": message.content}
data["metadata"] = message.to_json()
data: dict = {"text": message.content}
data["metadata"] = cast(dict, message.to_json())
return cls(data=data, text_key="text")
def __add__(self, other: "Record") -> "Record":