From eb3420523ea756ac53c53e56fb9cad2627c025f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Boschi?= Date: Fri, 5 Jul 2024 18:30:23 +0200 Subject: [PATCH] feat: migrate chains and memories to Component syntax (#2528) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: migrate chains and memories to Component syntax * use base class * add classes * [autofix.ci] apply automated fixes * fix tests * fix tests * ✅ (filterSidebar.spec.ts): increase waitForTimeout from 1000ms to 2000ms to ensure elements are fully loaded before interaction --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Gabriel Luiz Freitas Almeida Co-authored-by: cristhianzl Co-authored-by: Cristhian Zanforlin Lousa <72977554+Cristhianzl@users.noreply.github.com> --- .../base/langflow/base/chains/__init__.py | 0 .../base/langflow/base/chains/model.py | 17 ++ .../base/langflow/base/memory/model.py | 35 ++++ .../components/chains/ConversationChain.py | 52 +++--- .../langflow/components/chains/LLMChain.py | 34 ---- .../components/chains/LLMCheckerChain.py | 39 ++--- .../components/chains/LLMMathChain.py | 60 +++---- .../langflow/components/chains/RetrievalQA.py | 111 ++++++------- .../chains/RetrievalQAWithSourcesChain.py | 64 -------- .../components/chains/SQLGenerator.py | 63 +++----- .../langflow/components/chains/__init__.py | 4 - .../components/memories/AstraDBChatMemory.py | 60 +++++++ .../memories/AstraDBMessageReader.py | 97 ----------- .../memories/AstraDBMessageWriter.py | 127 --------------- .../memories/CassandraChatMemory.py | 93 +++++++++++ .../memories/CassandraMessageReader.py | 87 ---------- .../memories/CassandraMessageWriter.py | 123 -------------- .../components/memories/ZepChatMemory.py | 43 +++++ .../components/memories/ZepMessageReader.py | 151 ------------------ .../components/memories/ZepMessageWriter.py | 109 ------------- .../base/langflow/field_typing/constants.py | 2 + .../end-to-end/filterEdge-shard-0.spec.ts | 65 ++++---- .../end-to-end/filterEdge-shard-1.spec.ts | 63 ++++---- .../tests/end-to-end/filterSidebar.spec.ts | 128 ++++++++------- 24 files changed, 531 insertions(+), 1096 deletions(-) create mode 100644 src/backend/base/langflow/base/chains/__init__.py create mode 100644 src/backend/base/langflow/base/chains/model.py create mode 100644 src/backend/base/langflow/base/memory/model.py delete mode 100644 src/backend/base/langflow/components/chains/LLMChain.py delete mode 100644 src/backend/base/langflow/components/chains/RetrievalQAWithSourcesChain.py create mode 100644 src/backend/base/langflow/components/memories/AstraDBChatMemory.py delete mode 100644 src/backend/base/langflow/components/memories/AstraDBMessageReader.py delete mode 100644 src/backend/base/langflow/components/memories/AstraDBMessageWriter.py create mode 100644 src/backend/base/langflow/components/memories/CassandraChatMemory.py delete mode 100644 src/backend/base/langflow/components/memories/CassandraMessageReader.py delete mode 100644 src/backend/base/langflow/components/memories/CassandraMessageWriter.py create mode 100644 src/backend/base/langflow/components/memories/ZepChatMemory.py delete mode 100644 src/backend/base/langflow/components/memories/ZepMessageReader.py delete mode 100644 src/backend/base/langflow/components/memories/ZepMessageWriter.py diff --git a/src/backend/base/langflow/base/chains/__init__.py b/src/backend/base/langflow/base/chains/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/base/langflow/base/chains/model.py b/src/backend/base/langflow/base/chains/model.py new file mode 100644 index 000000000..efbe4ea19 --- /dev/null +++ b/src/backend/base/langflow/base/chains/model.py @@ -0,0 +1,17 @@ +from langflow.custom import Component +from langflow.template import Output + + +class LCChainComponent(Component): + trace_type = "chain" + + outputs = [Output(display_name="Text", name="text", method="invoke_chain")] + + def _validate_outputs(self): + required_output_methods = ["invoke_chain"] + output_names = [output.name for output in self.outputs] + for method_name in required_output_methods: + if method_name not in output_names: + raise ValueError(f"Output with name '{method_name}' must be defined.") + elif not hasattr(self, method_name): + raise ValueError(f"Method '{method_name}' must be defined.") diff --git a/src/backend/base/langflow/base/memory/model.py b/src/backend/base/langflow/base/memory/model.py new file mode 100644 index 000000000..22f9a02a3 --- /dev/null +++ b/src/backend/base/langflow/base/memory/model.py @@ -0,0 +1,35 @@ +from abc import abstractmethod + +from langflow.custom import Component +from langflow.field_typing import BaseChatMessageHistory, BaseChatMemory +from langflow.template import Output +from langchain.memory import ConversationBufferMemory + + +class LCChatMemoryComponent(Component): + trace_type = "chat_memory" + outputs = [ + Output( + display_name="Memory", + name="base_memory", + method="build_base_memory", + ) + ] + + def _validate_outputs(self): + required_output_methods = ["build_base_memory"] + output_names = [output.name for output in self.outputs] + for method_name in required_output_methods: + if method_name not in output_names: + raise ValueError(f"Output with name '{method_name}' must be defined.") + elif not hasattr(self, method_name): + raise ValueError(f"Method '{method_name}' must be defined.") + + def build_base_memory(self) -> BaseChatMemory: + return ConversationBufferMemory(chat_memory=self.build_message_history()) + + @abstractmethod + def build_message_history(self) -> BaseChatMessageHistory: + """ + Builds the chat message history memory. + """ diff --git a/src/backend/base/langflow/components/chains/ConversationChain.py b/src/backend/base/langflow/components/chains/ConversationChain.py index f4a6e1b39..727a9d8cf 100644 --- a/src/backend/base/langflow/components/chains/ConversationChain.py +++ b/src/backend/base/langflow/components/chains/ConversationChain.py @@ -1,41 +1,34 @@ -from typing import Optional - from langchain.chains import ConversationChain -from langflow.custom import CustomComponent -from langflow.field_typing import BaseMemory, LanguageModel, Text +from langflow.base.chains.model import LCChainComponent +from langflow.field_typing import Message +from langflow.inputs import MultilineInput, HandleInput -class ConversationChainComponent(CustomComponent): +class ConversationChainComponent(LCChainComponent): display_name = "ConversationChain" description = "Chain to have a conversation and load context from memory." name = "ConversationChain" - def build_config(self): - return { - "prompt": {"display_name": "Prompt"}, - "llm": {"display_name": "LLM"}, - "memory": { - "display_name": "Memory", - "info": "Memory to load context from. If none is provided, a ConversationBufferMemory will be used.", - }, - "input_value": { - "display_name": "Input Value", - "info": "The input value to pass to the chain.", - }, - } + inputs = [ + MultilineInput( + name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True + ), + HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), + HandleInput( + name="memory", + display_name="Memory", + input_types=["BaseChatMemory"], + ), + ] - def build( - self, - input_value: Text, - llm: LanguageModel, - memory: Optional[BaseMemory] = None, - ) -> Text: - if memory is None: - chain = ConversationChain(llm=llm) + def invoke_chain(self) -> Message: + if not self.memory: + chain = ConversationChain(llm=self.llm) else: - chain = ConversationChain(llm=llm, memory=memory) - result = chain.invoke({"input": input_value}) + chain = ConversationChain(llm=self.llm, memory=self.memory) + + result = chain.invoke({"input": self.input_value}) if isinstance(result, dict): result = result.get(chain.output_key, "") # type: ignore @@ -43,5 +36,6 @@ class ConversationChainComponent(CustomComponent): result = result else: result = result.get("response") + result = str(result) self.status = result - return str(result) + return Message(text=result) diff --git a/src/backend/base/langflow/components/chains/LLMChain.py b/src/backend/base/langflow/components/chains/LLMChain.py deleted file mode 100644 index d0f0c0975..000000000 --- a/src/backend/base/langflow/components/chains/LLMChain.py +++ /dev/null @@ -1,34 +0,0 @@ -from typing import Optional - -from langchain.chains.llm import LLMChain -from langchain_core.prompts import PromptTemplate - -from langflow.custom import CustomComponent -from langflow.field_typing import BaseMemory, LanguageModel, Text - - -class LLMChainComponent(CustomComponent): - display_name = "LLMChain" - description = "Chain to run queries against LLMs" - name = "LLMChain" - - def build_config(self): - return { - "prompt": {"display_name": "Prompt"}, - "llm": {"display_name": "LLM"}, - "memory": {"display_name": "Memory"}, - } - - def build( - self, - template: Text, - llm: LanguageModel, - memory: Optional[BaseMemory] = None, - ) -> Text: - prompt = PromptTemplate.from_template(template) - runnable = LLMChain(prompt=prompt, llm=llm, memory=memory) - result_dict = runnable.invoke({}) - output_key = runnable.output_key - result = result_dict[output_key] - self.status = result - return result diff --git a/src/backend/base/langflow/components/chains/LLMCheckerChain.py b/src/backend/base/langflow/components/chains/LLMCheckerChain.py index 080a8d7dd..127674701 100644 --- a/src/backend/base/langflow/components/chains/LLMCheckerChain.py +++ b/src/backend/base/langflow/components/chains/LLMCheckerChain.py @@ -1,32 +1,27 @@ from langchain.chains import LLMCheckerChain -from langflow.custom import CustomComponent -from langflow.field_typing import LanguageModel, Text +from langflow.base.chains.model import LCChainComponent +from langflow.field_typing import Message +from langflow.inputs import MultilineInput, HandleInput -class LLMCheckerChainComponent(CustomComponent): +class LLMCheckerChainComponent(LCChainComponent): display_name = "LLMCheckerChain" - description = "" + description = "Chain for question-answering with self-verification." documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_checker" name = "LLMCheckerChain" - def build_config(self): - return { - "llm": {"display_name": "LLM"}, - "input_value": { - "display_name": "Input Value", - "info": "The input value to pass to the chain.", - }, - } + inputs = [ + MultilineInput( + name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True + ), + HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), + ] - def build( - self, - input_value: Text, - llm: LanguageModel, - ) -> Text: - chain = LLMCheckerChain.from_llm(llm=llm) - response = chain.invoke({chain.input_key: input_value}) + def invoke_chain(self) -> Message: + chain = LLMCheckerChain.from_llm(llm=self.llm) + response = chain.invoke({chain.input_key: self.input_value}) result = response.get(chain.output_key, "") - result_str = str(result) - self.status = result_str - return result_str + result = str(result) + self.status = result + return Message(text=result) diff --git a/src/backend/base/langflow/components/chains/LLMMathChain.py b/src/backend/base/langflow/components/chains/LLMMathChain.py index b23ad6350..30fa21d66 100644 --- a/src/backend/base/langflow/components/chains/LLMMathChain.py +++ b/src/backend/base/langflow/components/chains/LLMMathChain.py @@ -1,48 +1,30 @@ -from typing import Optional +from langchain.chains import LLMMathChain -from langchain.chains import LLMChain, LLMMathChain - -from langflow.custom import CustomComponent -from langflow.field_typing import BaseMemory, LanguageModel, Text +from langflow.base.chains.model import LCChainComponent +from langflow.field_typing import Message +from langflow.inputs import MultilineInput, HandleInput +from langflow.template import Output -class LLMMathChainComponent(CustomComponent): +class LLMMathChainComponent(LCChainComponent): display_name = "LLMMathChain" description = "Chain that interprets a prompt and executes python code to do math." documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_math" name = "LLMMathChain" - def build_config(self): - return { - "llm": {"display_name": "LLM"}, - "llm_chain": {"display_name": "LLM Chain"}, - "memory": {"display_name": "Memory"}, - "input_key": {"display_name": "Input Key"}, - "output_key": {"display_name": "Output Key"}, - "input_value": { - "display_name": "Input Value", - "info": "The input value to pass to the chain.", - }, - } + inputs = [ + MultilineInput( + name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True + ), + HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), + ] - def build( - self, - input_value: Text, - llm: LanguageModel, - llm_chain: LLMChain, - input_key: str = "question", - output_key: str = "answer", - memory: Optional[BaseMemory] = None, - ) -> Text: - chain = LLMMathChain( - llm=llm, - llm_chain=llm_chain, - input_key=input_key, - output_key=output_key, - memory=memory, - ) - response = chain.invoke({input_key: input_value}) - result = response.get(output_key) - result_str = str(result) - self.status = result_str - return result_str + outputs = [Output(display_name="Text", name="text", method="invoke_chain")] + + def invoke_chain(self) -> Message: + chain = LLMMathChain.from_llm(llm=self.llm) + response = chain.invoke({chain.input_key: self.input_value}) + result = response.get(chain.output_key, "") + result = str(result) + self.status = result + return Message(text=result) diff --git a/src/backend/base/langflow/components/chains/RetrievalQA.py b/src/backend/base/langflow/components/chains/RetrievalQA.py index a15b35853..070d67e13 100644 --- a/src/backend/base/langflow/components/chains/RetrievalQA.py +++ b/src/backend/base/langflow/components/chains/RetrievalQA.py @@ -1,69 +1,64 @@ -from typing import Optional +from langchain.chains import RetrievalQA -from langchain.chains.retrieval_qa.base import RetrievalQA -from langchain_core.documents import Document - -from langflow.custom import CustomComponent -from langflow.field_typing import BaseMemory, BaseRetriever, LanguageModel, Text -from langflow.schema import Data +from langflow.base.chains.model import LCChainComponent +from langflow.field_typing import Message +from langflow.inputs import HandleInput, MultilineInput, BoolInput, DropdownInput -class RetrievalQAComponent(CustomComponent): +class RetrievalQAComponent(LCChainComponent): display_name = "Retrieval QA" - description = "Chain for question-answering against an index." + description = "Chain for question-answering querying sources from a retriever." name = "RetrievalQA" - def build_config(self): - return { - "llm": {"display_name": "LLM"}, - "chain_type": {"display_name": "Chain Type", "options": ["Stuff", "Map Reduce", "Refine", "Map Rerank"]}, - "retriever": {"display_name": "Retriever"}, - "memory": {"display_name": "Memory", "required": False}, - "input_key": {"display_name": "Input Key", "advanced": True}, - "output_key": {"display_name": "Output Key", "advanced": True}, - "return_source_documents": {"display_name": "Return Source Documents"}, - "input_value": { - "display_name": "Input", - "input_types": ["Data", "Document"], - }, - } + inputs = [ + MultilineInput( + name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True + ), + DropdownInput( + name="chain_type", + display_name="Chain Type", + info="Chain type to use.", + options=["Stuff", "Map Reduce", "Refine", "Map Rerank"], + value="Stuff", + advanced=True, + ), + HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), + HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"], required=True), + HandleInput( + name="memory", + display_name="Memory", + input_types=["BaseChatMemory"], + ), + BoolInput( + name="return_source_documents", + display_name="Return Source Documents", + value=False, + ), + ] + + def invoke_chain(self) -> Message: + chain_type = self.chain_type.lower().replace(" ", "_") + if self.memory: + self.memory.input_key = "query" + self.memory.output_key = "result" - def build( - self, - llm: LanguageModel, - chain_type: str, - retriever: BaseRetriever, - input_value: str = "", - memory: Optional[BaseMemory] = None, - input_key: str = "query", - output_key: str = "result", - return_source_documents: bool = True, - ) -> Text: - chain_type = chain_type.lower().replace(" ", "_") runnable = RetrievalQA.from_chain_type( - llm=llm, + llm=self.llm, chain_type=chain_type, - retriever=retriever, - memory=memory, - input_key=input_key, - output_key=output_key, - return_source_documents=return_source_documents, + retriever=self.retriever, + memory=self.memory, + # always include to help debugging + # + return_source_documents=True, ) - if isinstance(input_value, Document): - input_value = input_value.page_content - if isinstance(input_value, Data): - input_value = input_value.get_text() - self.status = runnable - result = runnable.invoke({input_key: input_value}) - result = result.content if hasattr(result, "content") else result - # Result is a dict with keys "query", "result" and "source_documents" - # for now we just return the result - data = self.to_data(result.get("source_documents")) - references_str = "" - if return_source_documents: - references_str = self.create_references_from_data(data) - result_str = result.get("result", "") - final_result = "\n".join([str(result_str), references_str]) - self.status = final_result - return final_result # OK + result = runnable.invoke({"query": self.input_value}) + + source_docs = self.to_data(result.get("source_documents", [])) + result_str = str(result.get("result", "")) + if self.return_source_documents and len(source_docs): + references_str = self.create_references_from_data(source_docs) + result_str = "\n".join([result_str, references_str]) + # put the entire result to debug history, query and content + self.status = {**result, "source_documents": source_docs, "output": result_str} + return result_str diff --git a/src/backend/base/langflow/components/chains/RetrievalQAWithSourcesChain.py b/src/backend/base/langflow/components/chains/RetrievalQAWithSourcesChain.py deleted file mode 100644 index 41efb1cc9..000000000 --- a/src/backend/base/langflow/components/chains/RetrievalQAWithSourcesChain.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Optional - -from langchain.chains import RetrievalQAWithSourcesChain -from langchain_core.documents import Document - -from langflow.custom import CustomComponent -from langflow.field_typing import BaseMemory, BaseRetriever, LanguageModel, Text - - -class RetrievalQAWithSourcesChainComponent(CustomComponent): - display_name = "RetrievalQAWithSourcesChain" - description = "Question-answering with sources over an index." - name = "RetrievalQAWithSourcesChain" - - def build_config(self): - return { - "llm": {"display_name": "LLM"}, - "chain_type": { - "display_name": "Chain Type", - "options": ["Stuff", "Map Reduce", "Refine", "Map Rerank"], - "info": "The type of chain to use to combined Documents.", - }, - "memory": {"display_name": "Memory"}, - "return_source_documents": {"display_name": "Return Source Documents"}, - "retriever": {"display_name": "Retriever"}, - "input_value": { - "display_name": "Input Value", - "info": "The input value to pass to the chain.", - }, - } - - def build( - self, - input_value: Text, - retriever: BaseRetriever, - llm: LanguageModel, - chain_type: str, - memory: Optional[BaseMemory] = None, - return_source_documents: Optional[bool] = True, - ) -> Text: - chain_type = chain_type.lower().replace(" ", "_") - runnable = RetrievalQAWithSourcesChain.from_chain_type( - llm=llm, - chain_type=chain_type, - memory=memory, - return_source_documents=return_source_documents, - retriever=retriever, - ) - if isinstance(input_value, Document): - input_value = input_value.page_content - self.status = runnable - input_key = runnable.input_keys[0] - result = runnable.invoke({input_key: input_value}) - result = result.content if hasattr(result, "content") else result - # Result is a dict with keys "query", "result" and "source_documents" - # for now we just return the result - data = self.to_data(result.get("source_documents")) - references_str = "" - if return_source_documents: - references_str = self.create_references_from_data(data) - result_str = str(result.get("answer", "")) - final_result = "\n".join([result_str, references_str]) - self.status = final_result - return final_result diff --git a/src/backend/base/langflow/components/chains/SQLGenerator.py b/src/backend/base/langflow/components/chains/SQLGenerator.py index b81b59470..d5f917417 100644 --- a/src/backend/base/langflow/components/chains/SQLGenerator.py +++ b/src/backend/base/langflow/components/chains/SQLGenerator.py @@ -1,62 +1,49 @@ -from typing import Optional - from langchain.chains import create_sql_query_chain -from langchain_community.utilities.sql_database import SQLDatabase from langchain_core.prompts import PromptTemplate from langchain_core.runnables import Runnable - -from langflow.custom import CustomComponent -from langflow.field_typing import LanguageModel, Text +from langflow.base.chains.model import LCChainComponent +from langflow.field_typing import Message +from langflow.inputs import MultilineInput, HandleInput, IntInput +from langflow.template import Output -class SQLGeneratorComponent(CustomComponent): +class SQLGeneratorComponent(LCChainComponent): display_name = "Natural Language to SQL" description = "Generate SQL from natural language." name = "SQLGenerator" - def build_config(self): - return { - "db": {"display_name": "Database"}, - "llm": {"display_name": "LLM"}, - "prompt": { - "display_name": "Prompt", - "info": "The prompt must contain `{question}`.", - }, - "top_k": { - "display_name": "Top K", - "info": "The number of results per select statement to return. If 0, no limit.", - }, - "input_value": { - "display_name": "Input Value", - "info": "The input value to pass to the chain.", - }, - } + inputs = [ + MultilineInput( + name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True + ), + HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True), + HandleInput(name="db", display_name="SQLDatabase", input_types=["SQLDatabase"], required=True), + IntInput( + name="top_k", display_name="Top K", info="The number of results per select statement to return.", value=5 + ), + MultilineInput(name="prompt", display_name="Prompt", info="The prompt must contain `{question}`."), + ] - def build( - self, - input_value: Text, - db: SQLDatabase, - llm: LanguageModel, - top_k: int = 5, - prompt: Optional[Text] = None, - ) -> Text: - if prompt: - prompt_template = PromptTemplate.from_template(template=prompt) + outputs = [Output(display_name="Text", name="text", method="invoke_chain")] + + def invoke_chain(self) -> Message: + if self.prompt: + prompt_template = PromptTemplate.from_template(template=self.prompt) else: prompt_template = None - if top_k < 1: + if self.top_k < 1: raise ValueError("Top K must be greater than 0.") if not prompt_template: - sql_query_chain = create_sql_query_chain(llm=llm, db=db, k=top_k) + sql_query_chain = create_sql_query_chain(llm=self.llm, db=self.db, k=self.top_k) else: # Check if {question} is in the prompt if "{question}" not in prompt_template.template or "question" not in prompt_template.input_variables: raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.") - sql_query_chain = create_sql_query_chain(llm=llm, db=db, prompt=prompt_template, k=top_k) + sql_query_chain = create_sql_query_chain(llm=self.llm, db=self.db, prompt=prompt_template, k=self.top_k) query_writer: Runnable = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()} - response = query_writer.invoke({"question": input_value}) + response = query_writer.invoke({"question": self.input_value}) query = response.get("query") self.status = query return query diff --git a/src/backend/base/langflow/components/chains/__init__.py b/src/backend/base/langflow/components/chains/__init__.py index 365a80eb6..39faca5e9 100644 --- a/src/backend/base/langflow/components/chains/__init__.py +++ b/src/backend/base/langflow/components/chains/__init__.py @@ -1,17 +1,13 @@ from .ConversationChain import ConversationChainComponent -from .LLMChain import LLMChainComponent from .LLMCheckerChain import LLMCheckerChainComponent from .LLMMathChain import LLMMathChainComponent from .RetrievalQA import RetrievalQAComponent -from .RetrievalQAWithSourcesChain import RetrievalQAWithSourcesChainComponent from .SQLGenerator import SQLGeneratorComponent __all__ = [ "ConversationChainComponent", - "LLMChainComponent", "LLMCheckerChainComponent", "LLMMathChainComponent", "RetrievalQAComponent", - "RetrievalQAWithSourcesChainComponent", "SQLGeneratorComponent", ] diff --git a/src/backend/base/langflow/components/memories/AstraDBChatMemory.py b/src/backend/base/langflow/components/memories/AstraDBChatMemory.py new file mode 100644 index 000000000..15561badb --- /dev/null +++ b/src/backend/base/langflow/components/memories/AstraDBChatMemory.py @@ -0,0 +1,60 @@ +from langflow.base.memory.model import LCChatMemoryComponent +from langflow.inputs import MessageTextInput, StrInput, SecretStrInput +from langflow.field_typing import BaseChatMessageHistory + + +class AstraDBChatMemory(LCChatMemoryComponent): + display_name = "Astra DB Chat Memory" + description = "Retrieves and store chat messages from Astra DB." + name = "AstraDBChatMemory" + icon: str = "AstraDB" + + inputs = [ + StrInput( + name="collection_name", + display_name="Collection Name", + info="The name of the collection within Astra DB where the vectors will be stored.", + required=True, + ), + SecretStrInput( + name="token", + display_name="Astra DB Application Token", + info="Authentication token for accessing Astra DB.", + value="ASTRA_DB_APPLICATION_TOKEN", + required=True, + ), + SecretStrInput( + name="api_endpoint", + display_name="API Endpoint", + info="API endpoint URL for the Astra DB service.", + value="ASTRA_DB_API_ENDPOINT", + required=True, + ), + StrInput( + name="namespace", + display_name="Namespace", + info="Optional namespace within Astra DB to use for the collection.", + advanced=True, + ), + MessageTextInput( + name="session_id", display_name="Session ID", info="Session ID for the message.", advanced=True + ), + ] + + def build_message_history(self) -> BaseChatMessageHistory: + try: + from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory + except ImportError: + raise ImportError( + "Could not import langchain Astra DB integration package. " + "Please install it with `pip install langchain-astradb`." + ) + + memory = AstraDBChatMessageHistory( + session_id=self.session_id, + collection_name=self.collection_name, + token=self.token, + api_endpoint=self.api_endpoint, + namespace=self.namespace or None, + ) + return memory diff --git a/src/backend/base/langflow/components/memories/AstraDBMessageReader.py b/src/backend/base/langflow/components/memories/AstraDBMessageReader.py deleted file mode 100644 index 52cdccdca..000000000 --- a/src/backend/base/langflow/components/memories/AstraDBMessageReader.py +++ /dev/null @@ -1,97 +0,0 @@ -from typing import Optional, cast - -from langflow.base.memory.memory import BaseMemoryComponent -from langflow.schema import Data - - -class AstraDBMessageReaderComponent(BaseMemoryComponent): - display_name = "Astra DB Message Reader" - description = "Retrieves stored chat messages from Astra DB." - name = "AstraDBMessageReader" - - def build_config(self): - return { - "session_id": { - "display_name": "Session ID", - "info": "Session ID of the chat history.", - "input_types": ["Text"], - }, - "collection_name": { - "display_name": "Collection Name", - "info": "Collection name for Astra DB.", - "input_types": ["Text"], - }, - "token": { - "display_name": "Astra DB Application Token", - "info": "Token for the Astra DB instance.", - "password": True, - }, - "api_endpoint": { - "display_name": "Astra DB API Endpoint", - "info": "API Endpoint for the Astra DB instance.", - "password": True, - }, - "namespace": { - "display_name": "Namespace", - "info": "Namespace for the Astra DB instance.", - "input_types": ["Text"], - "advanced": True, - }, - } - - def get_messages(self, **kwargs) -> list[Data]: - """ - Retrieves messages from the AstraDBChatMessageHistory memory. - - Args: - memory (AstraDBChatMessageHistory): The AstraDBChatMessageHistory instance to retrieve messages from. - - Returns: - list[Data]: A list of Data objects representing the search results. - """ - try: - from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory - except ImportError: - raise ImportError( - "Could not import langchain Astra DB integration package. " - "Please install it with `pip install langchain-astradb`." - ) - - memory: AstraDBChatMessageHistory = cast(AstraDBChatMessageHistory, kwargs.get("memory")) - if not memory: - raise ValueError("AstraDBChatMessageHistory instance is required.") - - # Get messages from the memory - messages = memory.messages - results = [Data.from_lc_message(message) for message in messages] - - return list(results) - - def build( - self, - session_id: str, - collection_name: str, - token: str, - api_endpoint: str, - namespace: Optional[str] = None, - ) -> list[Data]: - try: - from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory - except ImportError: - raise ImportError( - "Could not import langchain Astra DB integration package. " - "Please install it with `pip install langchain-astradb`." - ) - - memory = AstraDBChatMessageHistory( - session_id=session_id, - collection_name=collection_name, - token=token, - api_endpoint=api_endpoint, - namespace=namespace, - ) - - data = self.get_messages(memory=memory) - self.status = data - - return data diff --git a/src/backend/base/langflow/components/memories/AstraDBMessageWriter.py b/src/backend/base/langflow/components/memories/AstraDBMessageWriter.py deleted file mode 100644 index b36b32882..000000000 --- a/src/backend/base/langflow/components/memories/AstraDBMessageWriter.py +++ /dev/null @@ -1,127 +0,0 @@ -from typing import Optional - -from langchain_core.messages import BaseMessage - -from langflow.base.memory.memory import BaseMemoryComponent -from langflow.schema import Data - - -class AstraDBMessageWriterComponent(BaseMemoryComponent): - display_name = "Astra DB Message Writer" - description = "Writes a message to Astra DB." - name = "AstraDBMessageWriter" - - def build_config(self): - return { - "input_value": { - "display_name": "Input Data", - "info": "Data to write to Astra DB.", - }, - "session_id": { - "display_name": "Session ID", - "info": "Session ID of the chat history.", - "input_types": ["Text"], - }, - "collection_name": { - "display_name": "Collection Name", - "info": "Collection name for Astra DB.", - "input_types": ["Text"], - }, - "token": { - "display_name": "Astra DB Application Token", - "info": "Token for the Astra DB instance.", - "password": True, - }, - "api_endpoint": { - "display_name": "Astra DB API Endpoint", - "info": "API Endpoint for the Astra DB instance.", - "password": True, - }, - "namespace": { - "display_name": "Namespace", - "info": "Namespace for the Astra DB instance.", - "input_types": ["Text"], - "advanced": True, - }, - } - - def add_message( - self, - sender: str, - sender_name: str, - text: str, - session_id: str, - metadata: Optional[dict] = None, - **kwargs, - ): - """ - Adds a message to the AstraDBChatMessageHistory memory. - - Args: - sender (str): The type of the message sender. Typically "ai" or "human". - sender_name (str): The name of the message sender. - text (str): The content of the message. - session_id (str): The session ID associated with the message. - metadata (dict | None, optional): Additional metadata for the message. Defaults to None. - **kwargs: Additional keyword arguments, including: - memory (AstraDBChatMessageHistory | None): The memory instance to add the message to. - - - Raises: - ValueError: If the AstraDBChatMessageHistory instance is not provided. - - """ - try: - from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory - except ImportError: - raise ImportError( - "Could not import langchain Astra DB integration package. " - "Please install it with `pip install langchain-astradb`." - ) - - memory: AstraDBChatMessageHistory | None = kwargs.pop("memory", None) - if memory is None: - raise ValueError("AstraDBChatMessageHistory instance is required.") - - text_list = [ - BaseMessage( - content=text, - sender=sender, - sender_name=sender_name, - metadata=metadata, - session_id=session_id, - type=sender, - ) - ] - - memory.add_messages(text_list) - - def build( - self, - input_value: Data, - session_id: str, - collection_name: str, - token: str, - api_endpoint: str, - namespace: Optional[str] = None, - ) -> Data: - try: - from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory - except ImportError: - raise ImportError( - "Could not import langchain Astra DB integration package. " - "Please install it with `pip install langchain-astradb`." - ) - - memory = AstraDBChatMessageHistory( - session_id=session_id, - collection_name=collection_name, - token=token, - api_endpoint=api_endpoint, - namespace=namespace, - ) - - self.add_message(**input_value.data, memory=memory) - self.status = f"Added message to Astra DB memory for session {session_id}" - - return input_value diff --git a/src/backend/base/langflow/components/memories/CassandraChatMemory.py b/src/backend/base/langflow/components/memories/CassandraChatMemory.py new file mode 100644 index 000000000..4891122ab --- /dev/null +++ b/src/backend/base/langflow/components/memories/CassandraChatMemory.py @@ -0,0 +1,93 @@ +from langflow.base.memory.model import LCChatMemoryComponent +from langflow.inputs import MessageTextInput, SecretStrInput, DictInput +from langflow.field_typing import BaseChatMessageHistory + + +class CassandraChatMemory(LCChatMemoryComponent): + display_name = "Cassandra Chat Memory" + description = "Retrieves and store chat messages from Apache Cassandra." + name = "CassandraChatMemory" + icon = "Cassandra" + + inputs = [ + MessageTextInput( + name="database_ref", + display_name="Contact Points / Astra Database ID", + info="Contact points for the database (or AstraDB database ID)", + required=True, + ), + MessageTextInput( + name="username", display_name="Username", info="Username for the database (leave empty for AstraDB)." + ), + SecretStrInput( + name="token", + display_name="Password / AstraDB Token", + info="User password for the database (or AstraDB token).", + required=True, + ), + MessageTextInput( + name="keyspace", + display_name="Keyspace", + info="Table Keyspace (or AstraDB namespace).", + required=True, + ), + MessageTextInput( + name="table_name", + display_name="Table Name", + info="The name of the table (or AstraDB collection) where vectors will be stored.", + required=True, + ), + MessageTextInput( + name="session_id", display_name="Session ID", info="Session ID for the message.", advanced=True + ), + DictInput( + name="cluster_kwargs", + display_name="Cluster arguments", + info="Optional dictionary of additional keyword arguments for the Cassandra cluster.", + advanced=True, + is_list=True, + ), + ] + + def build_message_history(self) -> BaseChatMessageHistory: + from langchain_community.chat_message_histories import CassandraChatMessageHistory + + try: + import cassio + except ImportError: + raise ImportError( + "Could not import cassio integration package. " "Please install it with `pip install cassio`." + ) + + from uuid import UUID + + database_ref = self.database_ref + + try: + UUID(self.database_ref) + is_astra = True + except ValueError: + is_astra = False + if "," in self.database_ref: + # use a copy because we can't change the type of the parameter + database_ref = self.database_ref.split(",") + + if is_astra: + cassio.init( + database_id=database_ref, + token=self.token, + cluster_kwargs=self.cluster_kwargs, + ) + else: + cassio.init( + contact_points=database_ref, + username=self.username, + password=self.token, + cluster_kwargs=self.cluster_kwargs, + ) + + return CassandraChatMessageHistory( + session_id=self.session_id, + table_name=self.table_name, + keyspace=self.keyspace, + ) diff --git a/src/backend/base/langflow/components/memories/CassandraMessageReader.py b/src/backend/base/langflow/components/memories/CassandraMessageReader.py deleted file mode 100644 index 889be0a9a..000000000 --- a/src/backend/base/langflow/components/memories/CassandraMessageReader.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Optional, cast - -from langchain_community.chat_message_histories import CassandraChatMessageHistory - -from langflow.base.memory.memory import BaseMemoryComponent -from langflow.schema.data import Data - - -class CassandraMessageReaderComponent(BaseMemoryComponent): - display_name = "Cassandra Message Reader" - description = "Retrieves stored chat messages from a Cassandra table on Astra DB." - name = "CassandraMessageReader" - - def build_config(self): - return { - "session_id": { - "display_name": "Session ID", - "info": "Session ID of the chat history.", - "input_types": ["Text"], - }, - "database_id": { - "display_name": "Database ID", - "info": "The Astra database ID.", - }, - "table_name": { - "display_name": "Table Name", - "info": "The name of the table where messages are stored.", - }, - "token": { - "display_name": "Token", - "info": "Authentication token for accessing Cassandra on Astra DB.", - "password": True, - }, - "keyspace": { - "display_name": "Keyspace", - "info": "Optional key space within Astra DB. The keyspace should already be created.", - "input_types": ["Text"], - "advanced": True, - }, - } - - def get_messages(self, **kwargs) -> list[Data]: - """ - Retrieves messages from the CassandraChatMessageHistory memory. - - Args: - memory (CassandraChatMessageHistory): The CassandraChatMessageHistory instance to retrieve messages from. - - Returns: - list[Data]: A list of Data objects representing the search results. - """ - memory: CassandraChatMessageHistory = cast(CassandraChatMessageHistory, kwargs.get("memory")) - if not memory: - raise ValueError("CassandraChatMessageHistory instance is required.") - - # Get messages from the memory - messages = memory.messages - results = [Data.from_lc_message(message) for message in messages] - - return list(results) - - def build( - self, - session_id: str, - table_name: str, - token: str, - database_id: str, - keyspace: Optional[str] = None, - ) -> list[Data]: - try: - import cassio - except ImportError: - raise ImportError( - "Could not import cassio integration package. " "Please install it with `pip install cassio`." - ) - - cassio.init(token=token, database_id=database_id) - memory = CassandraChatMessageHistory( - session_id=session_id, - table_name=table_name, - keyspace=keyspace, - ) - - data = self.get_messages(memory=memory) - self.status = data - - return data diff --git a/src/backend/base/langflow/components/memories/CassandraMessageWriter.py b/src/backend/base/langflow/components/memories/CassandraMessageWriter.py deleted file mode 100644 index 33ca0fed1..000000000 --- a/src/backend/base/langflow/components/memories/CassandraMessageWriter.py +++ /dev/null @@ -1,123 +0,0 @@ -from typing import Optional - -from langchain_community.chat_message_histories import CassandraChatMessageHistory -from langchain_core.messages import BaseMessage - -from langflow.base.memory.memory import BaseMemoryComponent -from langflow.schema.data import Data - - -class CassandraMessageWriterComponent(BaseMemoryComponent): - display_name = "Cassandra Message Writer" - description = "Writes a message to a Cassandra table on Astra DB." - name = "CassandraMessageWriter" - - def build_config(self): - return { - "input_value": { - "display_name": "Input Data", - "info": "Data to write to Cassandra.", - }, - "session_id": { - "display_name": "Session ID", - "info": "Session ID of the chat history.", - "input_types": ["Text"], - }, - "database_id": { - "display_name": "Database ID", - "info": "The Astra database ID.", - }, - "table_name": { - "display_name": "Table Name", - "info": "The name of the table where messages will be stored.", - }, - "token": { - "display_name": "Token", - "info": "Authentication token for accessing Cassandra on Astra DB.", - "password": True, - }, - "keyspace": { - "display_name": "Keyspace", - "info": "Optional key space within Astra DB. The keyspace should already be created.", - "input_types": ["Text"], - "advanced": True, - }, - "ttl_seconds": { - "display_name": "TTL Seconds", - "info": "Optional time-to-live for the messages.", - "input_types": ["Number"], - "advanced": True, - }, - } - - def add_message( - self, - sender: str, - sender_name: str, - text: str, - session_id: str, - metadata: Optional[dict] = None, - **kwargs, - ): - """ - Adds a message to the CassandraChatMessageHistory memory. - - Args: - sender (str): The type of the message sender. Typically "ai" or "human". - sender_name (str): The name of the message sender. - text (str): The content of the message. - session_id (str): The session ID associated with the message. - metadata (dict | None, optional): Additional metadata for the message. Defaults to None. - **kwargs: Additional keyword arguments, including: - memory (CassandraChatMessageHistory | None): The memory instance to add the message to. - - - Raises: - ValueError: If the CassandraChatMessageHistory instance is not provided. - - """ - memory: CassandraChatMessageHistory | None = kwargs.pop("memory", None) - if memory is None: - raise ValueError("CassandraChatMessageHistory instance is required.") - - text_list = [ - BaseMessage( - content=text, - sender=sender, - sender_name=sender_name, - metadata=metadata, - session_id=session_id, - ) - ] - - memory.add_messages(text_list) - - def build( - self, - input_value: Data, - session_id: str, - table_name: str, - token: str, - database_id: str, - keyspace: Optional[str] = None, - ttl_seconds: Optional[int] = None, - ) -> Data: - try: - import cassio - except ImportError: - raise ImportError( - "Could not import cassio integration package. " "Please install it with `pip install cassio`." - ) - - cassio.init(token=token, database_id=database_id) - memory = CassandraChatMessageHistory( - session_id=session_id, - table_name=table_name, - keyspace=keyspace, - ttl_seconds=ttl_seconds, - ) - - self.add_message(**input_value.data, memory=memory) - self.status = f"Added message to Cassandra memory for session {session_id}" - - return input_value diff --git a/src/backend/base/langflow/components/memories/ZepChatMemory.py b/src/backend/base/langflow/components/memories/ZepChatMemory.py new file mode 100644 index 000000000..36d740a52 --- /dev/null +++ b/src/backend/base/langflow/components/memories/ZepChatMemory.py @@ -0,0 +1,43 @@ +from langflow.base.memory.model import LCChatMemoryComponent +from langflow.inputs import MessageTextInput, SecretStrInput, DropdownInput +from langflow.field_typing import BaseChatMessageHistory + + +class ZepChatMemory(LCChatMemoryComponent): + display_name = "Zep Chat Memory" + description = "Retrieves and store chat messages from Zep." + name = "ZepChatMemory" + + inputs = [ + MessageTextInput(name="url", display_name="Zep URL", info="URL of the Zep instance."), + SecretStrInput(name="api_key", display_name="API Key", info="API Key for the Zep instance."), + DropdownInput( + name="api_base_path", + display_name="API Base Path", + options=["api/v1", "api/v2"], + value="api/v1", + advanced=True, + ), + MessageTextInput( + name="session_id", display_name="Session ID", info="Session ID for the message.", advanced=True + ), + ] + + def build_message_history(self) -> BaseChatMessageHistory: + try: + # Monkeypatch API_BASE_PATH to + # avoid 404 + # This is a workaround for the local Zep instance + # cloud Zep works with v2 + import zep_python.zep_client + from zep_python import ZepClient + from zep_python.langchain import ZepChatMessageHistory + + zep_python.zep_client.API_BASE_PATH = self.api_base_path + except ImportError: + raise ImportError( + "Could not import zep-python package. " "Please install it with `pip install zep-python`." + ) + + zep_client = ZepClient(api_url=self.url, api_key=self.api_key) + return ZepChatMessageHistory(session_id=self.session_id, zep_client=zep_client) diff --git a/src/backend/base/langflow/components/memories/ZepMessageReader.py b/src/backend/base/langflow/components/memories/ZepMessageReader.py deleted file mode 100644 index 935706794..000000000 --- a/src/backend/base/langflow/components/memories/ZepMessageReader.py +++ /dev/null @@ -1,151 +0,0 @@ -from typing import Optional, cast - -from langchain_community.chat_message_histories.zep import SearchScope, SearchType, ZepChatMessageHistory - -from langflow.base.memory.memory import BaseMemoryComponent -from langflow.field_typing import Text -from langflow.schema import Data - - -class ZepMessageReaderComponent(BaseMemoryComponent): - display_name = "Zep Message Reader" - description = "Retrieves stored chat messages from Zep." - name = "ZepMessageReader" - - def build_config(self): - return { - "session_id": { - "display_name": "Session ID", - "info": "Session ID of the chat history.", - "input_types": ["Text"], - }, - "url": { - "display_name": "Zep URL", - "info": "URL of the Zep instance.", - "input_types": ["Text"], - }, - "api_key": { - "display_name": "Zep API Key", - "info": "API Key for the Zep instance.", - "password": True, - }, - "query": { - "display_name": "Query", - "info": "Query to search for in the chat history.", - }, - "metadata": { - "display_name": "Metadata", - "info": "Optional metadata to attach to the message.", - "advanced": True, - }, - "search_scope": { - "options": ["Messages", "Summary"], - "display_name": "Search Scope", - "info": "Scope of the search.", - "advanced": True, - }, - "search_type": { - "options": ["Similarity", "MMR"], - "display_name": "Search Type", - "info": "Type of search.", - "advanced": True, - }, - "limit": { - "display_name": "Limit", - "info": "Limit of search results.", - "advanced": True, - }, - "api_base_path": { - "display_name": "API Base Path", - "options": ["api/v1", "api/v2"], - }, - } - - def get_messages(self, **kwargs) -> list[Data]: - """ - Retrieves messages from the ZepChatMessageHistory memory. - - If a query is provided, the search method is used to search for messages in the memory, otherwise all messages are returned. - - Args: - memory (ZepChatMessageHistory): The ZepChatMessageHistory instance to retrieve messages from. - query (str, optional): The query string to search for messages. Defaults to None. - metadata (dict, optional): Additional metadata to filter the search results. Defaults to None. - search_scope (str, optional): The scope of the search. Can be 'messages' or 'summary'. Defaults to 'messages'. - search_type (str, optional): The type of search. Can be 'similarity' or 'exact'. Defaults to 'similarity'. - limit (int, optional): The maximum number of search results to return. Defaults to None. - - Returns: - list[Data]: A list of Data objects representing the search results. - """ - memory: ZepChatMessageHistory = cast(ZepChatMessageHistory, kwargs.get("memory")) - if not memory: - raise ValueError("ZepChatMessageHistory instance is required.") - query = kwargs.get("query") - search_scope = kwargs.get("search_scope", SearchScope.messages).lower() - search_type = kwargs.get("search_type", SearchType.similarity).lower() - limit = kwargs.get("limit") - - if query: - memory_search_results = memory.search( - query, - search_scope=search_scope, - search_type=search_type, - limit=limit, - ) - # Get the messages from the search results if the search scope is messages - result_dicts = [] - for result in memory_search_results: - result_dict = {} - if search_scope == SearchScope.messages: - result_dict["text"] = result.message - else: - result_dict["text"] = result.summary - result_dict["metadata"] = result.metadata - result_dict["score"] = result.score - result_dicts.append(result_dict) - results = [Data(data=result_dict) for result_dict in result_dicts] - else: - messages = memory.messages - results = [Data.from_lc_message(message) for message in messages] - return results - - def build( - self, - session_id: Text, - api_base_path: str = "api/v1", - url: Optional[Text] = None, - api_key: Optional[Text] = None, - query: Optional[Text] = None, - search_scope: str = SearchScope.messages, - search_type: str = SearchType.similarity, - limit: Optional[int] = None, - ) -> list[Data]: - try: - # Monkeypatch API_BASE_PATH to - # avoid 404 - # This is a workaround for the local Zep instance - # cloud Zep works with v2 - import zep_python.zep_client - from zep_python import ZepClient - from zep_python.langchain import ZepChatMessageHistory - - zep_python.zep_client.API_BASE_PATH = api_base_path - except ImportError: - raise ImportError( - "Could not import zep-python package. " "Please install it with `pip install zep-python`." - ) - if url == "": - url = None - - zep_client = ZepClient(api_url=url, api_key=api_key) - memory = ZepChatMessageHistory(session_id=session_id, zep_client=zep_client) - data = self.get_messages( - memory=memory, - query=query, - search_scope=search_scope, - search_type=search_type, - limit=limit, - ) - self.status = data - return data diff --git a/src/backend/base/langflow/components/memories/ZepMessageWriter.py b/src/backend/base/langflow/components/memories/ZepMessageWriter.py deleted file mode 100644 index 82382e701..000000000 --- a/src/backend/base/langflow/components/memories/ZepMessageWriter.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import TYPE_CHECKING, Optional - -from langflow.base.memory.memory import BaseMemoryComponent -from langflow.field_typing import Text -from langflow.schema import Data - -if TYPE_CHECKING: - from zep_python.langchain import ZepChatMessageHistory - - -class ZepMessageWriterComponent(BaseMemoryComponent): - display_name = "Zep Message Writer" - description = "Writes a message to Zep." - name = "ZepMessageWriter" - - def build_config(self): - return { - "session_id": { - "display_name": "Session ID", - "info": "Session ID of the chat history.", - "input_types": ["Text"], - }, - "url": { - "display_name": "Zep URL", - "info": "URL of the Zep instance.", - "input_types": ["Text"], - }, - "api_key": { - "display_name": "Zep API Key", - "info": "API Key for the Zep instance.", - "password": True, - }, - "limit": { - "display_name": "Limit", - "info": "Limit of search results.", - "advanced": True, - }, - "input_value": { - "display_name": "Input Data", - "info": "Data to write to Zep.", - }, - "api_base_path": { - "display_name": "API Base Path", - "options": ["api/v1", "api/v2"], - }, - } - - def add_message( - self, sender: Text, sender_name: Text, text: Text, session_id: Text, metadata: dict | None = None, **kwargs - ): - """ - Adds a message to the ZepChatMessageHistory memory. - - Args: - sender (Text): The type of the message sender. Valid values are "Machine" or "User". - sender_name (Text): The name of the message sender. - text (Text): The content of the message. - session_id (Text): The session ID associated with the message. - metadata (dict | None, optional): Additional metadata for the message. Defaults to None. - **kwargs: Additional keyword arguments. - - Raises: - ValueError: If the ZepChatMessageHistory instance is not provided. - - """ - memory: ZepChatMessageHistory | None = kwargs.pop("memory", None) - if memory is None: - raise ValueError("ZepChatMessageHistory instance is required.") - if metadata is None: - metadata = {} - metadata["sender_name"] = sender_name - metadata.update(kwargs) - if sender == "Machine": - memory.add_ai_message(text, metadata=metadata) - elif sender == "User": - memory.add_user_message(text, metadata=metadata) - else: - raise ValueError(f"Invalid sender type: {sender}") - - def build( - self, - input_value: Data, - session_id: Text, - api_base_path: str = "api/v1", - url: Optional[Text] = None, - api_key: Optional[Text] = None, - ) -> Data: - try: - # Monkeypatch API_BASE_PATH to - # avoid 404 - # This is a workaround for the local Zep instance - # cloud Zep works with v2 - import zep_python.zep_client - from zep_python import ZepClient - from zep_python.langchain import ZepChatMessageHistory - - zep_python.zep_client.API_BASE_PATH = api_base_path - except ImportError: - raise ImportError( - "Could not import zep-python package. " "Please install it with `pip install zep-python`." - ) - if url == "": - url = None - - zep_client = ZepClient(api_url=url, api_key=api_key) - memory = ZepChatMessageHistory(session_id=session_id, zep_client=zep_client) - self.add_message(**input_value.data, memory=memory) - self.status = f"Added message to Zep memory for session {session_id}" - return input_value diff --git a/src/backend/base/langflow/field_typing/constants.py b/src/backend/base/langflow/field_typing/constants.py index a857de59e..9ac0a1a3f 100644 --- a/src/backend/base/langflow/field_typing/constants.py +++ b/src/backend/base/langflow/field_typing/constants.py @@ -3,6 +3,7 @@ from typing import Callable, Dict, Text, TypeAlias, TypeVar, Union from langchain.agents.agent import AgentExecutor from langchain.chains.base import Chain from langchain.memory.chat_memory import BaseChatMemory +from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.document_loaders import BaseLoader from langchain_core.documents import Document from langchain_core.embeddings import Embeddings @@ -58,6 +59,7 @@ LANGCHAIN_BASE_TYPES = { "BaseMemory": BaseMemory, "BaseChatMemory": BaseChatMemory, "BaseChatModel": BaseChatModel, + "BaseChatMessageHistory": BaseChatMessageHistory, } # Langchain base types plus Python base types CUSTOM_COMPONENT_SUPPORTED_TYPES = { diff --git a/src/frontend/tests/end-to-end/filterEdge-shard-0.spec.ts b/src/frontend/tests/end-to-end/filterEdge-shard-0.spec.ts index 9aef2bccb..07322b752 100644 --- a/src/frontend/tests/end-to-end/filterEdge-shard-0.spec.ts +++ b/src/frontend/tests/end-to-end/filterEdge-shard-0.spec.ts @@ -1,6 +1,6 @@ import { expect, test } from "@playwright/test"; -test("LLMChain - Tooltip", async ({ page }) => { +test("RetrievalQA - Tooltip", async ({ page }) => { await page.goto("/"); await page.waitForTimeout(1000); @@ -26,11 +26,11 @@ test("LLMChain - Tooltip", async ({ page }) => { }); await page.getByTestId("extended-disclosure").click(); await page.getByPlaceholder("Search").click(); - await page.getByPlaceholder("Search").fill("llmchain"); + await page.getByPlaceholder("Search").fill("retrievalqa"); await page.waitForTimeout(1000); await page - .getByTestId("chainsLLMChain") + .getByTestId("chainsRetrieval QA") .dragTo(page.locator('//*[@id="react-flow-id"]')); await page.mouse.up(); await page.mouse.down(); @@ -39,12 +39,12 @@ test("LLMChain - Tooltip", async ({ page }) => { await page.getByTitle("zoom out").click(); await page.getByTitle("zoom out").click(); - const llmChainOutputElements = await page - .getByTestId("handle-llmchain-shownode-text-right") + const outputElements = await page + .getByTestId("handle-retrievalqa-shownode-text-right") .all(); let visibleElementHandle; - for (const element of llmChainOutputElements) { + for (const element of outputElements) { if (await element.isVisible()) { visibleElementHandle = element; break; @@ -52,9 +52,15 @@ test("LLMChain - Tooltip", async ({ page }) => { } await visibleElementHandle.hover().then(async () => { + await expect( + page.getByTestId("available-output-inputs").first(), + ).toBeVisible(); await expect( page.getByTestId("available-output-chains").first(), ).toBeVisible(); + await expect( + page.getByTestId("available-output-textsplitters").first(), + ).toBeVisible(); await expect( page.getByTestId("available-output-retrievers").first(), ).toBeVisible(); @@ -62,13 +68,23 @@ test("LLMChain - Tooltip", async ({ page }) => { page.getByTestId("available-output-prototypes").first(), ).toBeVisible(); await expect( - page.getByTestId("available-output-tools").first(), + page.getByTestId("available-output-embeddings").first(), + ).toBeVisible(); + await expect( + page.getByTestId("available-output-data").first(), + ).toBeVisible(); + await expect( + page.getByTestId("available-output-vectorstores").first(), ).toBeVisible(); await expect( page.getByTestId("available-output-memories").first(), ).toBeVisible(); await expect( - page.getByTestId("available-output-toolkits").first(), + page.getByTestId("available-output-models").first(), + ).toBeVisible(); + + await expect( + page.getByTestId("available-output-outputs").first(), ).toBeVisible(); await expect( page.getByTestId("available-output-agents").first(), @@ -76,9 +92,6 @@ test("LLMChain - Tooltip", async ({ page }) => { await expect( page.getByTestId("available-output-helpers").first(), ).toBeVisible(); - await expect( - page.getByTestId("available-output-langchain_utilities").first(), - ).toBeVisible(); await page.getByTestId("icon-X").click(); await page.waitForTimeout(500); @@ -89,11 +102,11 @@ test("LLMChain - Tooltip", async ({ page }) => { await page.getByTitle("zoom out").click(); await page.getByTitle("zoom out").click(); - const llmChainInputElements1 = await page - .getByTestId("handle-llmchain-shownode-llm-left") + const rqaChainInputElements1 = await page + .getByTestId("handle-retrievalqa-shownode-language model-left") .all(); - for (const element of llmChainInputElements1) { + for (const element of rqaChainInputElements1) { if (await element.isVisible()) { visibleElementHandle = element; break; @@ -115,11 +128,11 @@ test("LLMChain - Tooltip", async ({ page }) => { await page.getByTitle("zoom out").click(); await page.getByTitle("zoom out").click(); - const llmChainInputElements0 = await page - .getByTestId("handle-llmchain-shownode-template-left") + const rqaChainInputElements0 = await page + .getByTestId("handle-retrievalqa-shownode-retriever-left") .all(); - for (const element of llmChainInputElements0) { + for (const element of rqaChainInputElements0) { if (await element.isVisible()) { visibleElementHandle = element; break; @@ -130,16 +143,10 @@ test("LLMChain - Tooltip", async ({ page }) => { await page.waitForTimeout(2000); await expect( - page.getByTestId("available-input-chains").first(), + page.getByTestId("available-input-retrievers").first(), ).toBeVisible(); await expect( - page.getByTestId("available-input-prototypes").first(), - ).toBeVisible(); - await expect( - page.getByTestId("available-input-agents").first(), - ).toBeVisible(); - await expect( - page.getByTestId("available-input-helpers").first(), + page.getByTestId("available-input-vectorstores").first(), ).toBeVisible(); await page.waitForTimeout(500); @@ -150,11 +157,11 @@ test("LLMChain - Tooltip", async ({ page }) => { await page.getByTitle("zoom out").click(); await page.getByTitle("zoom out").click(); - const llmChainInputElements2 = await page - .getByTestId("handle-llmchain-shownode-memory-left") + const rqaChainInputElements2 = await page + .getByTestId("handle-retrievalqa-shownode-memory-left") .all(); - for (const element of llmChainInputElements2) { + for (const element of rqaChainInputElements2) { if (await element.isVisible()) { visibleElementHandle = element; break; @@ -163,7 +170,7 @@ test("LLMChain - Tooltip", async ({ page }) => { await visibleElementHandle.hover().then(async () => { await expect( - page.getByTestId("empty-tooltip-filter").first(), + page.getByTestId("available-input-memories").first(), ).toBeVisible(); }); }); diff --git a/src/frontend/tests/end-to-end/filterEdge-shard-1.spec.ts b/src/frontend/tests/end-to-end/filterEdge-shard-1.spec.ts index af72e5d78..8274a0797 100644 --- a/src/frontend/tests/end-to-end/filterEdge-shard-1.spec.ts +++ b/src/frontend/tests/end-to-end/filterEdge-shard-1.spec.ts @@ -1,6 +1,6 @@ import { expect, test } from "@playwright/test"; -test("LLMChain - Filter", async ({ page }) => { +test("RetrievalQA - Filter", async ({ page }) => { await page.goto("/"); await page.waitForTimeout(2000); @@ -31,11 +31,11 @@ test("LLMChain - Filter", async ({ page }) => { }); await page.getByTestId("extended-disclosure").click(); await page.getByPlaceholder("Search").click(); - await page.getByPlaceholder("Search").fill("llmchain"); + await page.getByPlaceholder("Search").fill("retrievalqa"); await page.waitForTimeout(1000); await page - .getByTestId("chainsLLMChain") + .getByTestId("chainsRetrieval QA") .dragTo(page.locator('//*[@id="react-flow-id"]')); await page.mouse.up(); await page.mouse.down(); @@ -47,11 +47,11 @@ test("LLMChain - Filter", async ({ page }) => { let visibleElementHandle; - const llmChainOutputElements = await page - .getByTestId("handle-llmchain-shownode-text-right") + const outputElements = await page + .getByTestId("handle-retrievalqa-shownode-text-right") .all(); - for (const element of llmChainOutputElements) { + for (const element of outputElements) { if (await element.isVisible()) { visibleElementHandle = element; break; @@ -62,38 +62,47 @@ test("LLMChain - Filter", async ({ page }) => { force: true, }); + await expect(page.getByTestId("disclosure-inputs")).toBeVisible(); + await expect(page.getByTestId("disclosure-outputs")).toBeVisible(); + await expect(page.getByTestId("disclosure-data")).toBeVisible(); + await expect(page.getByTestId("disclosure-models")).toBeVisible(); await expect(page.getByTestId("disclosure-helpers")).toBeVisible(); + await expect(page.getByTestId("disclosure-vector stores")).toBeVisible(); + await expect(page.getByTestId("disclosure-embeddings")).toBeVisible(); await expect(page.getByTestId("disclosure-agents")).toBeVisible(); await expect(page.getByTestId("disclosure-chains")).toBeVisible(); - await expect(page.getByTestId("disclosure-utilities")).toBeVisible(); await expect(page.getByTestId("disclosure-memories")).toBeVisible(); await expect(page.getByTestId("disclosure-prototypes")).toBeVisible(); await expect(page.getByTestId("disclosure-retrievers")).toBeVisible(); - await expect(page.getByTestId("disclosure-toolkits")).toBeVisible(); - await expect(page.getByTestId("disclosure-tools")).toBeVisible(); + await expect(page.getByTestId("disclosure-text splitters")).toBeVisible(); - await expect(page.getByTestId("helpersID Generator").first()).toBeVisible(); - - await expect(page.getByTestId("agentsCSVAgent").first()).toBeVisible(); - - await expect(page.getByTestId("chainsLLMChain").first()).toBeVisible(); + await expect(page.getByTestId("inputsChat Input").first()).toBeVisible(); + await expect(page.getByTestId("outputsChat Output").first()).toBeVisible(); + await expect(page.getByTestId("dataAPI Request").first()).toBeVisible(); + await expect(page.getByTestId("modelsAmazon Bedrock").first()).toBeVisible(); + await expect(page.getByTestId("helpersChat Memory").first()).toBeVisible(); + await expect(page.getByTestId("vectorstoresAstra DB").first()).toBeVisible(); await expect( - page.getByTestId("langchain_utilitiesSearchApi").first(), + page.getByTestId("embeddingsAmazon Bedrock Embeddings").first(), ).toBeVisible(); await expect( - page.getByTestId("memoriesAstra DB Message Reader").first(), + page.getByTestId("agentsTool Calling Agent").first(), ).toBeVisible(); await expect( - page.getByTestId("prototypesFlow as Tool").first(), + page.getByTestId("chainsConversationChain").first(), ).toBeVisible(); await expect( - page.getByTestId("retrieversAmazon Kendra Retriever").first(), + page.getByTestId("memoriesAstra DB Chat Memory").first(), ).toBeVisible(); - await expect( - page.getByTestId("toolkitsVectorStoreInfo").first(), + page.getByTestId("prototypesConditional Router").first(), + ).toBeVisible(); + await expect( + page.getByTestId("retrieversSelf Query Retriever").first(), + ).toBeVisible(); + await expect( + page.getByTestId("textsplittersCharacterTextSplitter").first(), ).toBeVisible(); - await expect(page.getByTestId("toolsSearchApi").first()).toBeVisible(); await page.getByPlaceholder("Search").click(); @@ -110,11 +119,11 @@ test("LLMChain - Filter", async ({ page }) => { await expect(page.getByTestId("model_specsChatOpenAI")).not.toBeVisible(); await expect(page.getByTestId("model_specsChatVertexAI")).not.toBeVisible(); - const llmChainInputElements1 = await page - .getByTestId("handle-llmchain-shownode-llm-left") + const chainInputElements1 = await page + .getByTestId("handle-retrievalqa-shownode-llm-left") .all(); - for (const element of llmChainInputElements1) { + for (const element of chainInputElements1) { if (await element.isVisible()) { visibleElementHandle = element; break; @@ -129,11 +138,11 @@ test("LLMChain - Filter", async ({ page }) => { await expect(page.getByTestId("disclosure-models")).toBeVisible(); - const llmChainInputElements0 = await page - .getByTestId("handle-llmchain-shownode-template-left") + const rqaChainInputElements0 = await page + .getByTestId("handle-retrievalqa-shownode-template-left") .all(); - for (const element of llmChainInputElements0) { + for (const element of rqaChainInputElements0) { if (await element.isVisible()) { visibleElementHandle = element; break; diff --git a/src/frontend/tests/end-to-end/filterSidebar.spec.ts b/src/frontend/tests/end-to-end/filterSidebar.spec.ts index c8c3b1df1..9b5273fbc 100644 --- a/src/frontend/tests/end-to-end/filterSidebar.spec.ts +++ b/src/frontend/tests/end-to-end/filterSidebar.spec.ts @@ -30,11 +30,11 @@ test("LLMChain - Filter", async ({ page }) => { await page.getByTestId("extended-disclosure").click(); await page.getByPlaceholder("Search").click(); - await page.getByPlaceholder("Search").fill("llmchain"); + await page.getByPlaceholder("Search").fill("api request"); - await page.waitForTimeout(1000); + await page.waitForTimeout(2000); await page - .getByTestId("chainsLLMChain") + .getByTestId("dataAPI Request") .dragTo(page.locator('//*[@id="react-flow-id"]')); await page.mouse.up(); await page.mouse.down(); @@ -44,65 +44,77 @@ test("LLMChain - Filter", async ({ page }) => { await page.getByTitle("zoom out").click(); await page.waitForTimeout(500); - await page - .locator( - '//*[@id="react-flow-id"]/div/div[1]/div[1]/div/div[2]/div/div/div[2]/div[7]/button/div[1]', - ) - .click(); - - await expect(page.getByTestId("helpersID Generator")).toBeVisible(); - await expect(page.getByTestId("disclosure-agents")).toBeVisible(); - - await expect(page.getByTestId("chainsLLMChain").first()).toBeVisible(); - await expect( - page.getByTestId("langchain_utilitiesSearchApi").first(), - ).toBeVisible(); - await expect( - page.getByTestId("memoriesAstra DB Message Reader").first(), - ).toBeVisible(); - await expect( - page.getByTestId("prototypesFlow as Tool").first(), - ).toBeVisible(); - await expect( - page.getByTestId("retrieversAmazon Kendra Retriever").first(), - ).toBeVisible(); - - await expect( - page.getByTestId("toolkitsVectorStoreInfo").first(), - ).toBeVisible(); - await expect(page.getByTestId("toolsSearchApi").first()).toBeVisible(); - - await page.getByPlaceholder("Search").click(); - - await expect(page.getByTestId("model_specsVertexAI")).not.toBeVisible(); - await expect(page.getByTestId("model_specsCTransformers")).not.toBeVisible(); - await expect(page.getByTestId("model_specsAmazon Bedrock")).not.toBeVisible(); - await expect(page.getByTestId("modelsAzure OpenAI")).not.toBeVisible(); - await expect( - page.getByTestId("model_specsAzureChatOpenAI"), - ).not.toBeVisible(); - await expect(page.getByTestId("model_specsChatAnthropic")).not.toBeVisible(); - await expect(page.getByTestId("model_specsChatLiteLLM")).not.toBeVisible(); - await expect(page.getByTestId("model_specsChatOllama")).not.toBeVisible(); - await expect(page.getByTestId("model_specsChatOpenAI")).not.toBeVisible(); - await expect(page.getByTestId("model_specsChatVertexAI")).not.toBeVisible(); - - await page - .locator( - '//*[@id="react-flow-id"]/div/div[1]/div[1]/div/div[2]/div/div/div[2]/div[4]/div/button/div[1]', - ) - .click(); + await page.getByTestId("handle-apirequest-shownode-urls-left").click(); + await expect(page.getByTestId("disclosure-inputs")).toBeVisible(); + await expect(page.getByTestId("disclosure-outputs")).toBeVisible(); + await expect(page.getByTestId("disclosure-prompts")).toBeVisible(); await expect(page.getByTestId("disclosure-models")).toBeVisible(); - - await page - .locator( - '//*[@id="react-flow-id"]/div/div[1]/div[1]/div/div[2]/div/div/div[2]/div[3]/div/button/div[1]', - ) - .click(); - await expect(page.getByTestId("disclosure-helpers")).toBeVisible(); await expect(page.getByTestId("disclosure-agents")).toBeVisible(); await expect(page.getByTestId("disclosure-chains")).toBeVisible(); await expect(page.getByTestId("disclosure-prototypes")).toBeVisible(); + + await expect(page.getByTestId("inputsChat Input")).toBeVisible(); + await expect(page.getByTestId("outputsChat Output")).toBeVisible(); + await expect(page.getByTestId("promptsPrompt")).toBeVisible(); + await expect(page.getByTestId("modelsAmazon Bedrock")).toBeVisible(); + await expect(page.getByTestId("helpersChat Memory")).toBeVisible(); + await expect(page.getByTestId("agentsTool Calling Agent")).toBeVisible(); + await expect(page.getByTestId("chainsConversationChain")).toBeVisible(); + await expect(page.getByTestId("prototypesConditional Router")).toBeVisible(); + + await page.getByPlaceholder("Search").click(); + + await expect(page.getByTestId("inputsChat Input")).not.toBeVisible(); + await expect(page.getByTestId("outputsChat Output")).not.toBeVisible(); + await expect(page.getByTestId("promptsPrompt")).not.toBeVisible(); + await expect(page.getByTestId("modelsAmazon Bedrock")).not.toBeVisible(); + await expect(page.getByTestId("helpersChat Memory")).not.toBeVisible(); + await expect(page.getByTestId("agentsTool Calling Agent")).not.toBeVisible(); + await expect(page.getByTestId("chainsConversationChain")).not.toBeVisible(); + await expect( + page.getByTestId("prototypesConditional Router"), + ).not.toBeVisible(); + + await page.getByTestId("handle-apirequest-shownode-headers-left").click(); + + await expect(page.getByTestId("disclosure-data")).toBeVisible(); + await expect(page.getByTestId("disclosure-helpers")).toBeVisible(); + await expect(page.getByTestId("disclosure-vector stores")).toBeVisible(); + await expect(page.getByTestId("disclosure-utilities")).toBeVisible(); + await expect(page.getByTestId("disclosure-prototypes")).toBeVisible(); + await expect(page.getByTestId("disclosure-retrievers")).toBeVisible(); + await expect(page.getByTestId("disclosure-text splitters")).toBeVisible(); + await expect(page.getByTestId("disclosure-tools")).toBeVisible(); + + await expect(page.getByTestId("dataAPI Request")).toBeVisible(); + await expect(page.getByTestId("helpersChat Memory")).toBeVisible(); + await expect(page.getByTestId("vectorstoresAstra DB")).toBeVisible(); + await expect(page.getByTestId("langchain_utilitiesSearchApi")).toBeVisible(); + await expect(page.getByTestId("prototypesSub Flow")).toBeVisible(); + await expect( + page.getByTestId("retrieversSelf Query Retriever"), + ).toBeVisible(); + await expect( + page.getByTestId("textsplittersCharacterTextSplitter"), + ).toBeVisible(); + await expect(page.getByTestId("toolsSearchApi")).toBeVisible(); + + await page.getByPlaceholder("Search").click(); + + await expect(page.getByTestId("dataAPI Request")).not.toBeVisible(); + await expect(page.getByTestId("helpersChat Memory")).not.toBeVisible(); + await expect(page.getByTestId("vectorstoresAstra DB")).not.toBeVisible(); + await expect( + page.getByTestId("langchain_utilitiesSearchApi"), + ).not.toBeVisible(); + await expect(page.getByTestId("prototypesSub Flow")).not.toBeVisible(); + await expect( + page.getByTestId("retrieversSelf Query Retriever"), + ).not.toBeVisible(); + await expect( + page.getByTestId("textsplittersCharacterTextSplitter"), + ).not.toBeVisible(); + await expect(page.getByTestId("toolsSearchApi")).not.toBeVisible(); });