Refactor MemoryComponent class and add ZepMessageReaderComponent (#1771)
* Add BaseMemoryComponent class to langflow.base.memory.memory.py (#1750) * Add BaseMemoryComponent class to langflow.base.memory.memory.py * Update MemoryComponent class in langflow.components.helpers.MemoryComponent.py to inherit from BaseMemoryComponent * ✨ (ZepMessageReader.py): Add ZepMessageReaderComponent to retrieve chat messages from Zep 📝 (ZepMessageWriter.py): Add ZepMessageWriterComponent to add messages to ZepChatMessageHistory 📝 (Langflow Memory Conversation.json): Refactor MemoryComponent class to inherit from BaseMemoryComponent for better code organization and reusability. Move get_messages method to the class level and validate kwargs for correct keys before processing. * Update WeaviateSearch.py to include index_name parameter in build method Update ZepMessageWriter.py to include metadata parameter in __init__ method Update ZepMessageReader.py to include cast function for memory parameter Update schema.py to include cast function for metadata parameter Update process.py to include tweaks_dict variable and use it in apply_tweaks method Update Weaviate.py to include index_name parameter in build method and raise ValueError if index_name is not provided * Update process.py to include tweaks_dict variable and use it in apply_tweaks method * ✨ (ZepMessageReader.py): Update ZepMessageReaderComponent build method to handle optional url and api_key parameters and improve error handling for zep-python package import 📝 (ZepMessageWriter.py): Refactor ZepMessageWriterComponent to use 'text' instead of 'message' for consistency and update add_message method to reflect this change. Add 'input_value' configuration option for specifying the record to write to Zep. Update build_config method to reflect changes in input parameters. Update add_message method to use 'text' parameter instead of 'message'. Update build method to handle optional url and api_key parameters and improve error handling for zep-python package import. * Update zep-python package to version 2.0.0rc5 * 📝 (memory.py): update parameter name from 'message' to 'text' for better clarity and consistency
This commit is contained in:
parent
ce322f1ba1
commit
19e46de04c
14 changed files with 345 additions and 47 deletions
17
poetry.lock
generated
17
poetry.lock
generated
|
|
@ -10169,6 +10169,21 @@ files = [
|
|||
idna = ">=2.0"
|
||||
multidict = ">=4.0"
|
||||
|
||||
[[package]]
|
||||
name = "zep-python"
|
||||
version = "2.0.0rc5"
|
||||
description = "Long-Term Memory for AI Assistants. This is the Python client for the Zep service."
|
||||
optional = false
|
||||
python-versions = "<4,>=3.9.0"
|
||||
files = [
|
||||
{file = "zep_python-2.0.0rc5-py3-none-any.whl", hash = "sha256:8b1b5c22c9e1ef439c9ef3d785347abf89b1243c7149e32025dd065cc022af40"},
|
||||
{file = "zep_python-2.0.0rc5.tar.gz", hash = "sha256:e6ced8089760374dead948d6b4b88fceb09a356bf9a7fe182b4ceb6e828f0bb1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.24.0,<0.29.0"
|
||||
pydantic = ">=2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "zipp"
|
||||
version = "3.18.1"
|
||||
|
|
@ -10262,4 +10277,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.10,<3.12"
|
||||
content-hash = "9dd152b30031767c522c77e2ad5fc4597a8d1590b13968af143bd382e056b2a1"
|
||||
content-hash = "bec34397b534f882551511558c76785c7cd67e6a1eefc1d45f6a64d97175d886"
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ chromadb = "^0.4.24"
|
|||
langchain-anthropic = "^0.1.6"
|
||||
langchain-astradb = "^0.1.0"
|
||||
langchain-openai = "^0.1.1"
|
||||
zep-python = {version = "^2.0.0rc5", allow-prereleases = true}
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
types-redis = "^4.6.0.5"
|
||||
|
|
|
|||
|
|
@ -130,7 +130,7 @@ async def simplified_run_flow(
|
|||
graph_data = flow.data
|
||||
|
||||
graph_data = process_tweaks(graph_data, input_request.tweaks or {}, stream=stream)
|
||||
graph = Graph.from_payload(graph_data, flow_id=flow_id, user_id=api_key_user.id)
|
||||
graph = Graph.from_payload(graph_data, flow_id=flow_id, user_id=str(api_key_user.id))
|
||||
inputs = [
|
||||
InputValueRequest(components=[], input_value=input_request.input_value, type=input_request.input_type)
|
||||
]
|
||||
|
|
|
|||
0
src/backend/base/langflow/base/memory/__init__.py
Normal file
0
src/backend/base/langflow/base/memory/__init__.py
Normal file
51
src/backend/base/langflow/base/memory/memory.py
Normal file
51
src/backend/base/langflow/base/memory/memory.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
from typing import Optional
|
||||
|
||||
from langflow.field_typing import Text
|
||||
from langflow.helpers.record import records_to_text
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.schema.schema import Record
|
||||
|
||||
|
||||
class BaseMemoryComponent(CustomComponent):
|
||||
display_name = "Chat Memory"
|
||||
description = "Retrieves stored chat messages given a specific Session ID."
|
||||
beta: bool = True
|
||||
icon = "history"
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"sender": {
|
||||
"options": ["Machine", "User", "Machine and User"],
|
||||
"display_name": "Sender Type",
|
||||
},
|
||||
"sender_name": {"display_name": "Sender Name", "advanced": True},
|
||||
"n_messages": {
|
||||
"display_name": "Number of Messages",
|
||||
"info": "Number of messages to retrieve.",
|
||||
},
|
||||
"session_id": {
|
||||
"display_name": "Session ID",
|
||||
"info": "Session ID of the chat history.",
|
||||
"input_types": ["Text"],
|
||||
},
|
||||
"order": {
|
||||
"options": ["Ascending", "Descending"],
|
||||
"display_name": "Order",
|
||||
"info": "Order of the messages.",
|
||||
"advanced": True,
|
||||
},
|
||||
"record_template": {
|
||||
"display_name": "Record Template",
|
||||
"multiline": True,
|
||||
"info": "Template to convert Record to Text. If left empty, it will be dynamically set to the Record's text key.",
|
||||
"advanced": True,
|
||||
},
|
||||
}
|
||||
|
||||
def get_messages(self, **kwargs) -> list[Record]:
|
||||
raise NotImplementedError
|
||||
|
||||
def add_message(
|
||||
self, sender: str, sender_name: str, text: str, session_id: str, metadata: Optional[dict] = None, **kwargs
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
|
@ -1,12 +1,13 @@
|
|||
from typing import Optional
|
||||
|
||||
from langflow.base.memory.memory import BaseMemoryComponent
|
||||
from langflow.field_typing import Text
|
||||
from langflow.helpers.record import records_to_text
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.memory import get_messages
|
||||
from langflow.schema.schema import Record
|
||||
|
||||
|
||||
class MemoryComponent(CustomComponent):
|
||||
class MemoryComponent(BaseMemoryComponent):
|
||||
display_name = "Chat Memory"
|
||||
description = "Retrieves stored chat messages given a specific Session ID."
|
||||
beta: bool = True
|
||||
|
|
@ -42,6 +43,24 @@ class MemoryComponent(CustomComponent):
|
|||
},
|
||||
}
|
||||
|
||||
def get_messages(self, **kwargs) -> list[Record]:
|
||||
# Validate kwargs by checking if it contains the correct keys
|
||||
if "sender" not in kwargs:
|
||||
kwargs["sender"] = None
|
||||
if "sender_name" not in kwargs:
|
||||
kwargs["sender_name"] = None
|
||||
if "session_id" not in kwargs:
|
||||
kwargs["session_id"] = None
|
||||
if "n_messages" not in kwargs:
|
||||
kwargs["n_messages"] = 5
|
||||
if "order" not in kwargs:
|
||||
kwargs["order"] = "Descending"
|
||||
|
||||
kwargs["order"] = "DESC" if kwargs["order"] == "Descending" else "ASC"
|
||||
if kwargs["sender"] == "Machine and User":
|
||||
kwargs["sender"] = None
|
||||
return get_messages(**kwargs)
|
||||
|
||||
def build(
|
||||
self,
|
||||
sender: Optional[str] = "Machine and User",
|
||||
|
|
@ -51,14 +70,11 @@ class MemoryComponent(CustomComponent):
|
|||
order: Optional[str] = "Descending",
|
||||
record_template: Optional[str] = "{sender_name}: {text}",
|
||||
) -> Text:
|
||||
order = "DESC" if order == "Descending" else "ASC"
|
||||
if sender == "Machine and User":
|
||||
sender = None
|
||||
messages = get_messages(
|
||||
messages = self.get_messages(
|
||||
sender=sender,
|
||||
sender_name=sender_name,
|
||||
session_id=session_id,
|
||||
limit=n_messages,
|
||||
n_messages=n_messages,
|
||||
order=order,
|
||||
)
|
||||
messages_str = records_to_text(template=record_template or "", records=messages)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,137 @@
|
|||
from typing import Optional, cast
|
||||
|
||||
from langchain_community.chat_message_histories.zep import SearchScope, SearchType, ZepChatMessageHistory
|
||||
|
||||
from langflow.base.memory.memory import BaseMemoryComponent
|
||||
from langflow.field_typing import Text
|
||||
from langflow.schema.schema import Record
|
||||
|
||||
|
||||
class ZepMessageReaderComponent(BaseMemoryComponent):
|
||||
display_name = "Zep Message Reader"
|
||||
description = "Retrieves stored chat messages from Zep."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"session_id": {
|
||||
"display_name": "Session ID",
|
||||
"info": "Session ID of the chat history.",
|
||||
"input_types": ["Text"],
|
||||
},
|
||||
"url": {
|
||||
"display_name": "Zep URL",
|
||||
"info": "URL of the Zep instance.",
|
||||
"input_types": ["Text"],
|
||||
},
|
||||
"api_key": {
|
||||
"display_name": "Zep API Key",
|
||||
"info": "API Key for the Zep instance.",
|
||||
"password": True,
|
||||
},
|
||||
"query": {
|
||||
"display_name": "Query",
|
||||
"info": "Query to search for in the chat history.",
|
||||
},
|
||||
"metadata": {
|
||||
"display_name": "Metadata",
|
||||
"info": "Optional metadata to attach to the message.",
|
||||
"advanced": True,
|
||||
},
|
||||
"search_scope": {
|
||||
"options": ["Messages", "Summary"],
|
||||
"display_name": "Search Scope",
|
||||
"info": "Scope of the search.",
|
||||
"advanced": True,
|
||||
},
|
||||
"search_type": {
|
||||
"options": ["Similarity", "MMR"],
|
||||
"display_name": "Search Type",
|
||||
"info": "Type of search.",
|
||||
"advanced": True,
|
||||
},
|
||||
"limit": {
|
||||
"display_name": "Limit",
|
||||
"info": "Limit of search results.",
|
||||
"advanced": True,
|
||||
},
|
||||
}
|
||||
|
||||
def get_messages(self, **kwargs) -> list[Record]:
|
||||
"""
|
||||
Retrieves messages from the ZepChatMessageHistory memory.
|
||||
|
||||
If a query is provided, the search method is used to search for messages in the memory, otherwise all messages are returned.
|
||||
|
||||
Args:
|
||||
memory (ZepChatMessageHistory): The ZepChatMessageHistory instance to retrieve messages from.
|
||||
query (str, optional): The query string to search for messages. Defaults to None.
|
||||
metadata (dict, optional): Additional metadata to filter the search results. Defaults to None.
|
||||
search_scope (str, optional): The scope of the search. Can be 'messages' or 'summary'. Defaults to 'messages'.
|
||||
search_type (str, optional): The type of search. Can be 'similarity' or 'exact'. Defaults to 'similarity'.
|
||||
limit (int, optional): The maximum number of search results to return. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list[Record]: A list of Record objects representing the search results.
|
||||
"""
|
||||
memory: ZepChatMessageHistory = cast(ZepChatMessageHistory, kwargs.get("memory"))
|
||||
if not memory:
|
||||
raise ValueError("ZepChatMessageHistory instance is required.")
|
||||
query = kwargs.get("query")
|
||||
search_scope = kwargs.get("search_scope", SearchScope.messages).lower()
|
||||
search_type = kwargs.get("search_type", SearchType.similarity).lower()
|
||||
limit = kwargs.get("limit")
|
||||
|
||||
if query:
|
||||
memory_search_results = memory.search(
|
||||
query,
|
||||
search_scope=search_scope,
|
||||
search_type=search_type,
|
||||
limit=limit,
|
||||
)
|
||||
# Get the messages from the search results if the search scope is messages
|
||||
result_dicts = []
|
||||
for result in memory_search_results:
|
||||
result_dict = {}
|
||||
if search_scope == SearchScope.messages:
|
||||
result_dict["text"] = result.message
|
||||
else:
|
||||
result_dict["text"] = result.summary
|
||||
result_dict["metadata"] = result.metadata
|
||||
result_dict["score"] = result.score
|
||||
result_dicts.append(result_dict)
|
||||
results = [Record(data=result_dict) for result_dict in result_dicts]
|
||||
else:
|
||||
messages = memory.messages
|
||||
results = [Record.from_lc_message(message) for message in messages]
|
||||
return results
|
||||
|
||||
def build(
|
||||
self,
|
||||
session_id: Text,
|
||||
url: Optional[Text] = None,
|
||||
api_key: Optional[Text] = None,
|
||||
query: Optional[Text] = None,
|
||||
search_scope: SearchScope = SearchScope.messages,
|
||||
search_type: SearchType = SearchType.similarity,
|
||||
limit: Optional[int] = None,
|
||||
) -> list[Record]:
|
||||
try:
|
||||
from zep_python import ZepClient
|
||||
from zep_python.langchain import ZepChatMessageHistory
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import zep-python package. " "Please install it with `pip install zep-python`."
|
||||
)
|
||||
if url == "":
|
||||
url = None
|
||||
zep_client = ZepClient(api_url=url, api_key=api_key)
|
||||
memory = ZepChatMessageHistory(session_id=session_id, zep_client=zep_client)
|
||||
records = self.get_messages(
|
||||
memory=memory,
|
||||
query=query,
|
||||
search_scope=search_scope,
|
||||
search_type=search_type,
|
||||
limit=limit,
|
||||
)
|
||||
self.status = records
|
||||
return records
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
|
||||
from langflow.base.memory.memory import BaseMemoryComponent
|
||||
from langflow.field_typing import Text
|
||||
from langflow.schema.schema import Record
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from zep_python.langchain import ZepChatMessageHistory
|
||||
|
||||
|
||||
class ZepMessageWriterComponent(BaseMemoryComponent):
|
||||
display_name = "Zep Message Writer"
|
||||
description = "Writes a message to Zep."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"session_id": {
|
||||
"display_name": "Session ID",
|
||||
"info": "Session ID of the chat history.",
|
||||
"input_types": ["Text"],
|
||||
},
|
||||
"url": {
|
||||
"display_name": "Zep URL",
|
||||
"info": "URL of the Zep instance.",
|
||||
"input_types": ["Text"],
|
||||
},
|
||||
"api_key": {
|
||||
"display_name": "Zep API Key",
|
||||
"info": "API Key for the Zep instance.",
|
||||
"password": True,
|
||||
},
|
||||
"limit": {
|
||||
"display_name": "Limit",
|
||||
"info": "Limit of search results.",
|
||||
"advanced": True,
|
||||
},
|
||||
"input_value": {
|
||||
"display_name": "Input Record",
|
||||
"info": "Record to write to Zep.",
|
||||
},
|
||||
}
|
||||
|
||||
def add_message(
|
||||
self, sender: Text, sender_name: Text, text: Text, session_id: Text, metadata: dict | None = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Adds a message to the ZepChatMessageHistory memory.
|
||||
|
||||
Args:
|
||||
sender (Text): The type of the message sender. Valid values are "Machine" or "User".
|
||||
sender_name (Text): The name of the message sender.
|
||||
text (Text): The content of the message.
|
||||
session_id (Text): The session ID associated with the message.
|
||||
metadata (dict | None, optional): Additional metadata for the message. Defaults to None.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Raises:
|
||||
ValueError: If the ZepChatMessageHistory instance is not provided.
|
||||
|
||||
"""
|
||||
memory: ZepChatMessageHistory | None = kwargs.pop("memory", None)
|
||||
if memory is None:
|
||||
raise ValueError("ZepChatMessageHistory instance is required.")
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
metadata["sender_name"] = sender_name
|
||||
metadata.update(kwargs)
|
||||
if sender == "Machine":
|
||||
memory.add_ai_message(text, metadata=metadata)
|
||||
elif sender == "User":
|
||||
memory.add_user_message(text, metadata=metadata)
|
||||
else:
|
||||
raise ValueError(f"Invalid sender type: {sender}")
|
||||
|
||||
def build(
|
||||
self,
|
||||
input_value: Record,
|
||||
session_id: Text,
|
||||
url: Optional[Text] = None,
|
||||
api_key: Optional[Text] = None,
|
||||
) -> Record:
|
||||
try:
|
||||
from zep_python import ZepClient
|
||||
from zep_python.langchain import ZepChatMessageHistory
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import zep-python package. " "Please install it with `pip install zep-python`."
|
||||
)
|
||||
if url == "":
|
||||
url = None
|
||||
zep_client = ZepClient(api_url=url, api_key=api_key)
|
||||
memory = ZepChatMessageHistory(session_id=session_id, zep_client=zep_client)
|
||||
self.add_message(**input_value.data, memory=memory)
|
||||
self.status = f"Added message to Zep memory for session {session_id}"
|
||||
return input_value
|
||||
|
|
@ -61,10 +61,10 @@ class WeaviateSearchVectorStore(WeaviateVectorStoreComponent, LCVectorStoreCompo
|
|||
input_value: Text,
|
||||
search_type: str,
|
||||
url: str,
|
||||
index_name: str,
|
||||
number_of_results: int = 4,
|
||||
search_by_text: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
index_name: Optional[str] = None,
|
||||
text_key: str = "text",
|
||||
embedding: Optional[Embeddings] = None,
|
||||
attributes: Optional[list] = None,
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import weaviate # type: ignore
|
|||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import BaseRetriever
|
||||
from langchain_community.vectorstores import VectorStore, Weaviate
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.schema.schema import Record
|
||||
|
|
@ -50,9 +51,9 @@ class WeaviateVectorStoreComponent(CustomComponent):
|
|||
def build(
|
||||
self,
|
||||
url: str,
|
||||
index_name: str,
|
||||
search_by_text: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
index_name: Optional[str] = None,
|
||||
text_key: str = "text",
|
||||
embedding: Optional[Embeddings] = None,
|
||||
inputs: Optional[Record] = None,
|
||||
|
|
@ -78,11 +79,13 @@ class WeaviateVectorStoreComponent(CustomComponent):
|
|||
return pascal_case_word
|
||||
|
||||
index_name = _to_pascal_case(index_name) if index_name else None
|
||||
documents = []
|
||||
if not index_name:
|
||||
raise ValueError("Index name is required")
|
||||
documents: list[Document] = []
|
||||
for _input in inputs or []:
|
||||
if isinstance(_input, Record):
|
||||
documents.append(_input.to_lc_document())
|
||||
else:
|
||||
elif isinstance(_input, Document):
|
||||
documents.append(_input)
|
||||
|
||||
if documents and embedding is not None:
|
||||
|
|
|
|||
|
|
@ -377,7 +377,7 @@
|
|||
"list": false,
|
||||
"show": true,
|
||||
"multiline": true,
|
||||
"value": "from typing import Optional\n\nfrom langflow.field_typing import Text\nfrom langflow.helpers.record import records_to_text\nfrom langflow.interface.custom.custom_component import CustomComponent\nfrom langflow.memory import get_messages\n\n\nclass MemoryComponent(CustomComponent):\n display_name = \"Chat Memory\"\n description = \"Retrieves stored chat messages given a specific Session ID.\"\n beta: bool = True\n icon = \"history\"\n\n def build_config(self):\n return {\n \"sender\": {\n \"options\": [\"Machine\", \"User\", \"Machine and User\"],\n \"display_name\": \"Sender Type\",\n },\n \"sender_name\": {\"display_name\": \"Sender Name\", \"advanced\": True},\n \"n_messages\": {\n \"display_name\": \"Number of Messages\",\n \"info\": \"Number of messages to retrieve.\",\n },\n \"session_id\": {\n \"display_name\": \"Session ID\",\n \"info\": \"Session ID of the chat history.\",\n \"input_types\": [\"Text\"],\n },\n \"order\": {\n \"options\": [\"Ascending\", \"Descending\"],\n \"display_name\": \"Order\",\n \"info\": \"Order of the messages.\",\n \"advanced\": True,\n },\n \"record_template\": {\n \"display_name\": \"Record Template\",\n \"multiline\": True,\n \"info\": \"Template to convert Record to Text. If left empty, it will be dynamically set to the Record's text key.\",\n \"advanced\": True,\n },\n }\n\n def build(\n self,\n sender: Optional[str] = \"Machine and User\",\n sender_name: Optional[str] = None,\n session_id: Optional[str] = None,\n n_messages: int = 5,\n order: Optional[str] = \"Descending\",\n record_template: Optional[str] = \"{sender_name}: {text}\",\n ) -> Text:\n order = \"DESC\" if order == \"Descending\" else \"ASC\"\n if sender == \"Machine and User\":\n sender = None\n messages = get_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n limit=n_messages,\n order=order,\n )\n messages_str = records_to_text(template=record_template or \"\", records=messages)\n self.status = messages_str\n return messages_str\n",
|
||||
"value": "from typing import Optional\n\nfrom langflow.base.memory.memory import BaseMemoryComponent\nfrom langflow.field_typing import Text\nfrom langflow.helpers.record import records_to_text\nfrom langflow.memory import get_messages\nfrom langflow.schema.schema import Record\n\n\nclass MemoryComponent(BaseMemoryComponent):\n display_name = \"Chat Memory\"\n description = \"Retrieves stored chat messages given a specific Session ID.\"\n beta: bool = True\n icon = \"history\"\n\n def build_config(self):\n return {\n \"sender\": {\n \"options\": [\"Machine\", \"User\", \"Machine and User\"],\n \"display_name\": \"Sender Type\",\n },\n \"sender_name\": {\"display_name\": \"Sender Name\", \"advanced\": True},\n \"n_messages\": {\n \"display_name\": \"Number of Messages\",\n \"info\": \"Number of messages to retrieve.\",\n },\n \"session_id\": {\n \"display_name\": \"Session ID\",\n \"info\": \"Session ID of the chat history.\",\n \"input_types\": [\"Text\"],\n },\n \"order\": {\n \"options\": [\"Ascending\", \"Descending\"],\n \"display_name\": \"Order\",\n \"info\": \"Order of the messages.\",\n \"advanced\": True,\n },\n \"record_template\": {\n \"display_name\": \"Record Template\",\n \"multiline\": True,\n \"info\": \"Template to convert Record to Text. If left empty, it will be dynamically set to the Record's text key.\",\n \"advanced\": True,\n },\n }\n\n def get_messages(self, **kwargs) -> list[Record]:\n # Validate kwargs by checking if it contains the correct keys\n if \"sender\" not in kwargs:\n kwargs[\"sender\"] = None\n if \"sender_name\" not in kwargs:\n kwargs[\"sender_name\"] = None\n if \"session_id\" not in kwargs:\n kwargs[\"session_id\"] = None\n if \"n_messages\" not in kwargs:\n kwargs[\"n_messages\"] = 5\n if \"order\" not in kwargs:\n kwargs[\"order\"] = \"Descending\"\n\n kwargs[\"order\"] = \"DESC\" if kwargs[\"order\"] == \"Descending\" else \"ASC\"\n if kwargs[\"sender\"] == \"Machine and User\":\n kwargs[\"sender\"] = None\n return get_messages(**kwargs)\n\n def build(\n self,\n sender: Optional[str] = \"Machine and User\",\n sender_name: Optional[str] = None,\n session_id: Optional[str] = None,\n n_messages: int = 5,\n order: Optional[str] = \"Descending\",\n record_template: Optional[str] = \"{sender_name}: {text}\",\n ) -> Text:\n messages = self.get_messages(\n sender=sender,\n sender_name=sender_name,\n session_id=session_id,\n n_messages=n_messages,\n order=order,\n )\n messages_str = records_to_text(template=record_template or \"\", records=messages)\n self.status = messages_str\n return messages_str\n",
|
||||
"fileTypes": [],
|
||||
"file_path": "",
|
||||
"password": false,
|
||||
|
|
|
|||
|
|
@ -227,34 +227,11 @@ def initialize_qdrant(class_object: Type[Qdrant], params: dict):
|
|||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def initialize_elasticsearch(class_object: Type[ElasticsearchStore], params: dict):
|
||||
"""Initialize elastic and return the class object"""
|
||||
if "index_name" not in params:
|
||||
raise ValueError("Elasticsearch Index must be provided in the params")
|
||||
if "es_url" not in params:
|
||||
raise ValueError("Elasticsearch URL must be provided in the params")
|
||||
if not docs_in_params(params):
|
||||
existing_index_params = {
|
||||
"embedding": params.pop("embedding"),
|
||||
}
|
||||
if "index_name" in params:
|
||||
existing_index_params["index_name"] = params.pop("index_name")
|
||||
if "es_url" in params:
|
||||
existing_index_params["es_url"] = params.pop("es_url")
|
||||
|
||||
return class_object.from_existing_index(**existing_index_params)
|
||||
# If there are docs in the params, create a new index
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
vecstore_initializer: Dict[str, Callable[[Type[Any], dict], Any]] = {
|
||||
"Pinecone": initialize_pinecone,
|
||||
"Chroma": initialize_chroma,
|
||||
"Qdrant": initialize_qdrant,
|
||||
"Weaviate": initialize_weaviate,
|
||||
"ElasticsearchStore": initialize_elasticsearch,
|
||||
"FAISS": initialize_faiss,
|
||||
"SupabaseVectorStore": initialize_supabase,
|
||||
"MongoDBAtlasVectorSearch": initialize_mongodb,
|
||||
|
|
|
|||
|
|
@ -269,16 +269,19 @@ def process_tweaks(
|
|||
:return: The modified graph_data dictionary.
|
||||
:raises ValueError: If the input is not in the expected format.
|
||||
"""
|
||||
tweaks_dict = {}
|
||||
if not isinstance(tweaks, dict):
|
||||
tweaks = tweaks.model_dump()
|
||||
if "stream" not in tweaks:
|
||||
tweaks["stream"] = stream
|
||||
nodes = validate_input(graph_data, tweaks)
|
||||
tweaks_dict = tweaks.model_dump()
|
||||
else:
|
||||
tweaks_dict = tweaks
|
||||
if "stream" not in tweaks_dict:
|
||||
tweaks_dict["stream"] = stream
|
||||
nodes = validate_input(graph_data, tweaks_dict)
|
||||
nodes_map = {node.get("id"): node for node in nodes}
|
||||
nodes_display_name_map = {node.get("data", {}).get("node", {}).get("display_name"): node for node in nodes}
|
||||
|
||||
all_nodes_tweaks = {}
|
||||
for key, value in tweaks.items():
|
||||
for key, value in tweaks_dict.items():
|
||||
if isinstance(value, dict):
|
||||
if node := nodes_map.get(key):
|
||||
apply_tweaks(node, value)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
import copy
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, cast
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
|
||||
from pydantic import BaseModel, model_validator
|
||||
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
|
||||
|
||||
|
||||
class Record(BaseModel):
|
||||
|
|
@ -67,8 +66,8 @@ class Record(BaseModel):
|
|||
Returns:
|
||||
Record: The converted Record.
|
||||
"""
|
||||
data = {"text": message.content}
|
||||
data["metadata"] = message.to_json()
|
||||
data: dict = {"text": message.content}
|
||||
data["metadata"] = cast(dict, message.to_json())
|
||||
return cls(data=data, text_key="text")
|
||||
|
||||
def __add__(self, other: "Record") -> "Record":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue