feat: migrate chains and memories to Component syntax (#2528)

* feat: migrate chains and memories to Component syntax

* use base class

* add classes

* [autofix.ci] apply automated fixes

* fix tests

* fix tests

*  (filterSidebar.spec.ts): increase waitForTimeout from 1000ms to 2000ms to ensure elements are fully loaded before interaction

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
Co-authored-by: Cristhian Zanforlin Lousa <72977554+Cristhianzl@users.noreply.github.com>
This commit is contained in:
Nicolò Boschi 2024-07-05 18:30:23 +02:00 committed by GitHub
commit eb3420523e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
24 changed files with 531 additions and 1096 deletions

View file

@ -0,0 +1,17 @@
from langflow.custom import Component
from langflow.template import Output
class LCChainComponent(Component):
trace_type = "chain"
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]
def _validate_outputs(self):
required_output_methods = ["invoke_chain"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
if method_name not in output_names:
raise ValueError(f"Output with name '{method_name}' must be defined.")
elif not hasattr(self, method_name):
raise ValueError(f"Method '{method_name}' must be defined.")

View file

@ -0,0 +1,35 @@
from abc import abstractmethod
from langflow.custom import Component
from langflow.field_typing import BaseChatMessageHistory, BaseChatMemory
from langflow.template import Output
from langchain.memory import ConversationBufferMemory
class LCChatMemoryComponent(Component):
trace_type = "chat_memory"
outputs = [
Output(
display_name="Memory",
name="base_memory",
method="build_base_memory",
)
]
def _validate_outputs(self):
required_output_methods = ["build_base_memory"]
output_names = [output.name for output in self.outputs]
for method_name in required_output_methods:
if method_name not in output_names:
raise ValueError(f"Output with name '{method_name}' must be defined.")
elif not hasattr(self, method_name):
raise ValueError(f"Method '{method_name}' must be defined.")
def build_base_memory(self) -> BaseChatMemory:
return ConversationBufferMemory(chat_memory=self.build_message_history())
@abstractmethod
def build_message_history(self) -> BaseChatMessageHistory:
"""
Builds the chat message history memory.
"""

View file

@ -1,41 +1,34 @@
from typing import Optional
from langchain.chains import ConversationChain
from langflow.custom import CustomComponent
from langflow.field_typing import BaseMemory, LanguageModel, Text
from langflow.base.chains.model import LCChainComponent
from langflow.field_typing import Message
from langflow.inputs import MultilineInput, HandleInput
class ConversationChainComponent(CustomComponent):
class ConversationChainComponent(LCChainComponent):
display_name = "ConversationChain"
description = "Chain to have a conversation and load context from memory."
name = "ConversationChain"
def build_config(self):
return {
"prompt": {"display_name": "Prompt"},
"llm": {"display_name": "LLM"},
"memory": {
"display_name": "Memory",
"info": "Memory to load context from. If none is provided, a ConversationBufferMemory will be used.",
},
"input_value": {
"display_name": "Input Value",
"info": "The input value to pass to the chain.",
},
}
inputs = [
MultilineInput(
name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
HandleInput(
name="memory",
display_name="Memory",
input_types=["BaseChatMemory"],
),
]
def build(
self,
input_value: Text,
llm: LanguageModel,
memory: Optional[BaseMemory] = None,
) -> Text:
if memory is None:
chain = ConversationChain(llm=llm)
def invoke_chain(self) -> Message:
if not self.memory:
chain = ConversationChain(llm=self.llm)
else:
chain = ConversationChain(llm=llm, memory=memory)
result = chain.invoke({"input": input_value})
chain = ConversationChain(llm=self.llm, memory=self.memory)
result = chain.invoke({"input": self.input_value})
if isinstance(result, dict):
result = result.get(chain.output_key, "") # type: ignore
@ -43,5 +36,6 @@ class ConversationChainComponent(CustomComponent):
result = result
else:
result = result.get("response")
result = str(result)
self.status = result
return str(result)
return Message(text=result)

View file

@ -1,34 +0,0 @@
from typing import Optional
from langchain.chains.llm import LLMChain
from langchain_core.prompts import PromptTemplate
from langflow.custom import CustomComponent
from langflow.field_typing import BaseMemory, LanguageModel, Text
class LLMChainComponent(CustomComponent):
display_name = "LLMChain"
description = "Chain to run queries against LLMs"
name = "LLMChain"
def build_config(self):
return {
"prompt": {"display_name": "Prompt"},
"llm": {"display_name": "LLM"},
"memory": {"display_name": "Memory"},
}
def build(
self,
template: Text,
llm: LanguageModel,
memory: Optional[BaseMemory] = None,
) -> Text:
prompt = PromptTemplate.from_template(template)
runnable = LLMChain(prompt=prompt, llm=llm, memory=memory)
result_dict = runnable.invoke({})
output_key = runnable.output_key
result = result_dict[output_key]
self.status = result
return result

View file

@ -1,32 +1,27 @@
from langchain.chains import LLMCheckerChain
from langflow.custom import CustomComponent
from langflow.field_typing import LanguageModel, Text
from langflow.base.chains.model import LCChainComponent
from langflow.field_typing import Message
from langflow.inputs import MultilineInput, HandleInput
class LLMCheckerChainComponent(CustomComponent):
class LLMCheckerChainComponent(LCChainComponent):
display_name = "LLMCheckerChain"
description = ""
description = "Chain for question-answering with self-verification."
documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_checker"
name = "LLMCheckerChain"
def build_config(self):
return {
"llm": {"display_name": "LLM"},
"input_value": {
"display_name": "Input Value",
"info": "The input value to pass to the chain.",
},
}
inputs = [
MultilineInput(
name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
]
def build(
self,
input_value: Text,
llm: LanguageModel,
) -> Text:
chain = LLMCheckerChain.from_llm(llm=llm)
response = chain.invoke({chain.input_key: input_value})
def invoke_chain(self) -> Message:
chain = LLMCheckerChain.from_llm(llm=self.llm)
response = chain.invoke({chain.input_key: self.input_value})
result = response.get(chain.output_key, "")
result_str = str(result)
self.status = result_str
return result_str
result = str(result)
self.status = result
return Message(text=result)

View file

@ -1,48 +1,30 @@
from typing import Optional
from langchain.chains import LLMMathChain
from langchain.chains import LLMChain, LLMMathChain
from langflow.custom import CustomComponent
from langflow.field_typing import BaseMemory, LanguageModel, Text
from langflow.base.chains.model import LCChainComponent
from langflow.field_typing import Message
from langflow.inputs import MultilineInput, HandleInput
from langflow.template import Output
class LLMMathChainComponent(CustomComponent):
class LLMMathChainComponent(LCChainComponent):
display_name = "LLMMathChain"
description = "Chain that interprets a prompt and executes python code to do math."
documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_math"
name = "LLMMathChain"
def build_config(self):
return {
"llm": {"display_name": "LLM"},
"llm_chain": {"display_name": "LLM Chain"},
"memory": {"display_name": "Memory"},
"input_key": {"display_name": "Input Key"},
"output_key": {"display_name": "Output Key"},
"input_value": {
"display_name": "Input Value",
"info": "The input value to pass to the chain.",
},
}
inputs = [
MultilineInput(
name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
]
def build(
self,
input_value: Text,
llm: LanguageModel,
llm_chain: LLMChain,
input_key: str = "question",
output_key: str = "answer",
memory: Optional[BaseMemory] = None,
) -> Text:
chain = LLMMathChain(
llm=llm,
llm_chain=llm_chain,
input_key=input_key,
output_key=output_key,
memory=memory,
)
response = chain.invoke({input_key: input_value})
result = response.get(output_key)
result_str = str(result)
self.status = result_str
return result_str
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]
def invoke_chain(self) -> Message:
chain = LLMMathChain.from_llm(llm=self.llm)
response = chain.invoke({chain.input_key: self.input_value})
result = response.get(chain.output_key, "")
result = str(result)
self.status = result
return Message(text=result)

View file

@ -1,69 +1,64 @@
from typing import Optional
from langchain.chains import RetrievalQA
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_core.documents import Document
from langflow.custom import CustomComponent
from langflow.field_typing import BaseMemory, BaseRetriever, LanguageModel, Text
from langflow.schema import Data
from langflow.base.chains.model import LCChainComponent
from langflow.field_typing import Message
from langflow.inputs import HandleInput, MultilineInput, BoolInput, DropdownInput
class RetrievalQAComponent(CustomComponent):
class RetrievalQAComponent(LCChainComponent):
display_name = "Retrieval QA"
description = "Chain for question-answering against an index."
description = "Chain for question-answering querying sources from a retriever."
name = "RetrievalQA"
def build_config(self):
return {
"llm": {"display_name": "LLM"},
"chain_type": {"display_name": "Chain Type", "options": ["Stuff", "Map Reduce", "Refine", "Map Rerank"]},
"retriever": {"display_name": "Retriever"},
"memory": {"display_name": "Memory", "required": False},
"input_key": {"display_name": "Input Key", "advanced": True},
"output_key": {"display_name": "Output Key", "advanced": True},
"return_source_documents": {"display_name": "Return Source Documents"},
"input_value": {
"display_name": "Input",
"input_types": ["Data", "Document"],
},
}
inputs = [
MultilineInput(
name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
),
DropdownInput(
name="chain_type",
display_name="Chain Type",
info="Chain type to use.",
options=["Stuff", "Map Reduce", "Refine", "Map Rerank"],
value="Stuff",
advanced=True,
),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"], required=True),
HandleInput(
name="memory",
display_name="Memory",
input_types=["BaseChatMemory"],
),
BoolInput(
name="return_source_documents",
display_name="Return Source Documents",
value=False,
),
]
def invoke_chain(self) -> Message:
chain_type = self.chain_type.lower().replace(" ", "_")
if self.memory:
self.memory.input_key = "query"
self.memory.output_key = "result"
def build(
self,
llm: LanguageModel,
chain_type: str,
retriever: BaseRetriever,
input_value: str = "",
memory: Optional[BaseMemory] = None,
input_key: str = "query",
output_key: str = "result",
return_source_documents: bool = True,
) -> Text:
chain_type = chain_type.lower().replace(" ", "_")
runnable = RetrievalQA.from_chain_type(
llm=llm,
llm=self.llm,
chain_type=chain_type,
retriever=retriever,
memory=memory,
input_key=input_key,
output_key=output_key,
return_source_documents=return_source_documents,
retriever=self.retriever,
memory=self.memory,
# always include to help debugging
#
return_source_documents=True,
)
if isinstance(input_value, Document):
input_value = input_value.page_content
if isinstance(input_value, Data):
input_value = input_value.get_text()
self.status = runnable
result = runnable.invoke({input_key: input_value})
result = result.content if hasattr(result, "content") else result
# Result is a dict with keys "query", "result" and "source_documents"
# for now we just return the result
data = self.to_data(result.get("source_documents"))
references_str = ""
if return_source_documents:
references_str = self.create_references_from_data(data)
result_str = result.get("result", "")
final_result = "\n".join([str(result_str), references_str])
self.status = final_result
return final_result # OK
result = runnable.invoke({"query": self.input_value})
source_docs = self.to_data(result.get("source_documents", []))
result_str = str(result.get("result", ""))
if self.return_source_documents and len(source_docs):
references_str = self.create_references_from_data(source_docs)
result_str = "\n".join([result_str, references_str])
# put the entire result to debug history, query and content
self.status = {**result, "source_documents": source_docs, "output": result_str}
return result_str

View file

@ -1,64 +0,0 @@
from typing import Optional
from langchain.chains import RetrievalQAWithSourcesChain
from langchain_core.documents import Document
from langflow.custom import CustomComponent
from langflow.field_typing import BaseMemory, BaseRetriever, LanguageModel, Text
class RetrievalQAWithSourcesChainComponent(CustomComponent):
display_name = "RetrievalQAWithSourcesChain"
description = "Question-answering with sources over an index."
name = "RetrievalQAWithSourcesChain"
def build_config(self):
return {
"llm": {"display_name": "LLM"},
"chain_type": {
"display_name": "Chain Type",
"options": ["Stuff", "Map Reduce", "Refine", "Map Rerank"],
"info": "The type of chain to use to combined Documents.",
},
"memory": {"display_name": "Memory"},
"return_source_documents": {"display_name": "Return Source Documents"},
"retriever": {"display_name": "Retriever"},
"input_value": {
"display_name": "Input Value",
"info": "The input value to pass to the chain.",
},
}
def build(
self,
input_value: Text,
retriever: BaseRetriever,
llm: LanguageModel,
chain_type: str,
memory: Optional[BaseMemory] = None,
return_source_documents: Optional[bool] = True,
) -> Text:
chain_type = chain_type.lower().replace(" ", "_")
runnable = RetrievalQAWithSourcesChain.from_chain_type(
llm=llm,
chain_type=chain_type,
memory=memory,
return_source_documents=return_source_documents,
retriever=retriever,
)
if isinstance(input_value, Document):
input_value = input_value.page_content
self.status = runnable
input_key = runnable.input_keys[0]
result = runnable.invoke({input_key: input_value})
result = result.content if hasattr(result, "content") else result
# Result is a dict with keys "query", "result" and "source_documents"
# for now we just return the result
data = self.to_data(result.get("source_documents"))
references_str = ""
if return_source_documents:
references_str = self.create_references_from_data(data)
result_str = str(result.get("answer", ""))
final_result = "\n".join([result_str, references_str])
self.status = final_result
return final_result

View file

@ -1,62 +1,49 @@
from typing import Optional
from langchain.chains import create_sql_query_chain
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable
from langflow.custom import CustomComponent
from langflow.field_typing import LanguageModel, Text
from langflow.base.chains.model import LCChainComponent
from langflow.field_typing import Message
from langflow.inputs import MultilineInput, HandleInput, IntInput
from langflow.template import Output
class SQLGeneratorComponent(CustomComponent):
class SQLGeneratorComponent(LCChainComponent):
display_name = "Natural Language to SQL"
description = "Generate SQL from natural language."
name = "SQLGenerator"
def build_config(self):
return {
"db": {"display_name": "Database"},
"llm": {"display_name": "LLM"},
"prompt": {
"display_name": "Prompt",
"info": "The prompt must contain `{question}`.",
},
"top_k": {
"display_name": "Top K",
"info": "The number of results per select statement to return. If 0, no limit.",
},
"input_value": {
"display_name": "Input Value",
"info": "The input value to pass to the chain.",
},
}
inputs = [
MultilineInput(
name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
),
HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
HandleInput(name="db", display_name="SQLDatabase", input_types=["SQLDatabase"], required=True),
IntInput(
name="top_k", display_name="Top K", info="The number of results per select statement to return.", value=5
),
MultilineInput(name="prompt", display_name="Prompt", info="The prompt must contain `{question}`."),
]
def build(
self,
input_value: Text,
db: SQLDatabase,
llm: LanguageModel,
top_k: int = 5,
prompt: Optional[Text] = None,
) -> Text:
if prompt:
prompt_template = PromptTemplate.from_template(template=prompt)
outputs = [Output(display_name="Text", name="text", method="invoke_chain")]
def invoke_chain(self) -> Message:
if self.prompt:
prompt_template = PromptTemplate.from_template(template=self.prompt)
else:
prompt_template = None
if top_k < 1:
if self.top_k < 1:
raise ValueError("Top K must be greater than 0.")
if not prompt_template:
sql_query_chain = create_sql_query_chain(llm=llm, db=db, k=top_k)
sql_query_chain = create_sql_query_chain(llm=self.llm, db=self.db, k=self.top_k)
else:
# Check if {question} is in the prompt
if "{question}" not in prompt_template.template or "question" not in prompt_template.input_variables:
raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.")
sql_query_chain = create_sql_query_chain(llm=llm, db=db, prompt=prompt_template, k=top_k)
sql_query_chain = create_sql_query_chain(llm=self.llm, db=self.db, prompt=prompt_template, k=self.top_k)
query_writer: Runnable = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()}
response = query_writer.invoke({"question": input_value})
response = query_writer.invoke({"question": self.input_value})
query = response.get("query")
self.status = query
return query

View file

@ -1,17 +1,13 @@
from .ConversationChain import ConversationChainComponent
from .LLMChain import LLMChainComponent
from .LLMCheckerChain import LLMCheckerChainComponent
from .LLMMathChain import LLMMathChainComponent
from .RetrievalQA import RetrievalQAComponent
from .RetrievalQAWithSourcesChain import RetrievalQAWithSourcesChainComponent
from .SQLGenerator import SQLGeneratorComponent
__all__ = [
"ConversationChainComponent",
"LLMChainComponent",
"LLMCheckerChainComponent",
"LLMMathChainComponent",
"RetrievalQAComponent",
"RetrievalQAWithSourcesChainComponent",
"SQLGeneratorComponent",
]

View file

@ -0,0 +1,60 @@
from langflow.base.memory.model import LCChatMemoryComponent
from langflow.inputs import MessageTextInput, StrInput, SecretStrInput
from langflow.field_typing import BaseChatMessageHistory
class AstraDBChatMemory(LCChatMemoryComponent):
display_name = "Astra DB Chat Memory"
description = "Retrieves and store chat messages from Astra DB."
name = "AstraDBChatMemory"
icon: str = "AstraDB"
inputs = [
StrInput(
name="collection_name",
display_name="Collection Name",
info="The name of the collection within Astra DB where the vectors will be stored.",
required=True,
),
SecretStrInput(
name="token",
display_name="Astra DB Application Token",
info="Authentication token for accessing Astra DB.",
value="ASTRA_DB_APPLICATION_TOKEN",
required=True,
),
SecretStrInput(
name="api_endpoint",
display_name="API Endpoint",
info="API endpoint URL for the Astra DB service.",
value="ASTRA_DB_API_ENDPOINT",
required=True,
),
StrInput(
name="namespace",
display_name="Namespace",
info="Optional namespace within Astra DB to use for the collection.",
advanced=True,
),
MessageTextInput(
name="session_id", display_name="Session ID", info="Session ID for the message.", advanced=True
),
]
def build_message_history(self) -> BaseChatMessageHistory:
try:
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
except ImportError:
raise ImportError(
"Could not import langchain Astra DB integration package. "
"Please install it with `pip install langchain-astradb`."
)
memory = AstraDBChatMessageHistory(
session_id=self.session_id,
collection_name=self.collection_name,
token=self.token,
api_endpoint=self.api_endpoint,
namespace=self.namespace or None,
)
return memory

View file

@ -1,97 +0,0 @@
from typing import Optional, cast
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.schema import Data
class AstraDBMessageReaderComponent(BaseMemoryComponent):
display_name = "Astra DB Message Reader"
description = "Retrieves stored chat messages from Astra DB."
name = "AstraDBMessageReader"
def build_config(self):
return {
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"collection_name": {
"display_name": "Collection Name",
"info": "Collection name for Astra DB.",
"input_types": ["Text"],
},
"token": {
"display_name": "Astra DB Application Token",
"info": "Token for the Astra DB instance.",
"password": True,
},
"api_endpoint": {
"display_name": "Astra DB API Endpoint",
"info": "API Endpoint for the Astra DB instance.",
"password": True,
},
"namespace": {
"display_name": "Namespace",
"info": "Namespace for the Astra DB instance.",
"input_types": ["Text"],
"advanced": True,
},
}
def get_messages(self, **kwargs) -> list[Data]:
"""
Retrieves messages from the AstraDBChatMessageHistory memory.
Args:
memory (AstraDBChatMessageHistory): The AstraDBChatMessageHistory instance to retrieve messages from.
Returns:
list[Data]: A list of Data objects representing the search results.
"""
try:
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
except ImportError:
raise ImportError(
"Could not import langchain Astra DB integration package. "
"Please install it with `pip install langchain-astradb`."
)
memory: AstraDBChatMessageHistory = cast(AstraDBChatMessageHistory, kwargs.get("memory"))
if not memory:
raise ValueError("AstraDBChatMessageHistory instance is required.")
# Get messages from the memory
messages = memory.messages
results = [Data.from_lc_message(message) for message in messages]
return list(results)
def build(
self,
session_id: str,
collection_name: str,
token: str,
api_endpoint: str,
namespace: Optional[str] = None,
) -> list[Data]:
try:
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
except ImportError:
raise ImportError(
"Could not import langchain Astra DB integration package. "
"Please install it with `pip install langchain-astradb`."
)
memory = AstraDBChatMessageHistory(
session_id=session_id,
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
)
data = self.get_messages(memory=memory)
self.status = data
return data

View file

@ -1,127 +0,0 @@
from typing import Optional
from langchain_core.messages import BaseMessage
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.schema import Data
class AstraDBMessageWriterComponent(BaseMemoryComponent):
display_name = "Astra DB Message Writer"
description = "Writes a message to Astra DB."
name = "AstraDBMessageWriter"
def build_config(self):
return {
"input_value": {
"display_name": "Input Data",
"info": "Data to write to Astra DB.",
},
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"collection_name": {
"display_name": "Collection Name",
"info": "Collection name for Astra DB.",
"input_types": ["Text"],
},
"token": {
"display_name": "Astra DB Application Token",
"info": "Token for the Astra DB instance.",
"password": True,
},
"api_endpoint": {
"display_name": "Astra DB API Endpoint",
"info": "API Endpoint for the Astra DB instance.",
"password": True,
},
"namespace": {
"display_name": "Namespace",
"info": "Namespace for the Astra DB instance.",
"input_types": ["Text"],
"advanced": True,
},
}
def add_message(
self,
sender: str,
sender_name: str,
text: str,
session_id: str,
metadata: Optional[dict] = None,
**kwargs,
):
"""
Adds a message to the AstraDBChatMessageHistory memory.
Args:
sender (str): The type of the message sender. Typically "ai" or "human".
sender_name (str): The name of the message sender.
text (str): The content of the message.
session_id (str): The session ID associated with the message.
metadata (dict | None, optional): Additional metadata for the message. Defaults to None.
**kwargs: Additional keyword arguments, including:
memory (AstraDBChatMessageHistory | None): The memory instance to add the message to.
Raises:
ValueError: If the AstraDBChatMessageHistory instance is not provided.
"""
try:
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
except ImportError:
raise ImportError(
"Could not import langchain Astra DB integration package. "
"Please install it with `pip install langchain-astradb`."
)
memory: AstraDBChatMessageHistory | None = kwargs.pop("memory", None)
if memory is None:
raise ValueError("AstraDBChatMessageHistory instance is required.")
text_list = [
BaseMessage(
content=text,
sender=sender,
sender_name=sender_name,
metadata=metadata,
session_id=session_id,
type=sender,
)
]
memory.add_messages(text_list)
def build(
self,
input_value: Data,
session_id: str,
collection_name: str,
token: str,
api_endpoint: str,
namespace: Optional[str] = None,
) -> Data:
try:
from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
except ImportError:
raise ImportError(
"Could not import langchain Astra DB integration package. "
"Please install it with `pip install langchain-astradb`."
)
memory = AstraDBChatMessageHistory(
session_id=session_id,
collection_name=collection_name,
token=token,
api_endpoint=api_endpoint,
namespace=namespace,
)
self.add_message(**input_value.data, memory=memory)
self.status = f"Added message to Astra DB memory for session {session_id}"
return input_value

View file

@ -0,0 +1,93 @@
from langflow.base.memory.model import LCChatMemoryComponent
from langflow.inputs import MessageTextInput, SecretStrInput, DictInput
from langflow.field_typing import BaseChatMessageHistory
class CassandraChatMemory(LCChatMemoryComponent):
display_name = "Cassandra Chat Memory"
description = "Retrieves and store chat messages from Apache Cassandra."
name = "CassandraChatMemory"
icon = "Cassandra"
inputs = [
MessageTextInput(
name="database_ref",
display_name="Contact Points / Astra Database ID",
info="Contact points for the database (or AstraDB database ID)",
required=True,
),
MessageTextInput(
name="username", display_name="Username", info="Username for the database (leave empty for AstraDB)."
),
SecretStrInput(
name="token",
display_name="Password / AstraDB Token",
info="User password for the database (or AstraDB token).",
required=True,
),
MessageTextInput(
name="keyspace",
display_name="Keyspace",
info="Table Keyspace (or AstraDB namespace).",
required=True,
),
MessageTextInput(
name="table_name",
display_name="Table Name",
info="The name of the table (or AstraDB collection) where vectors will be stored.",
required=True,
),
MessageTextInput(
name="session_id", display_name="Session ID", info="Session ID for the message.", advanced=True
),
DictInput(
name="cluster_kwargs",
display_name="Cluster arguments",
info="Optional dictionary of additional keyword arguments for the Cassandra cluster.",
advanced=True,
is_list=True,
),
]
def build_message_history(self) -> BaseChatMessageHistory:
from langchain_community.chat_message_histories import CassandraChatMessageHistory
try:
import cassio
except ImportError:
raise ImportError(
"Could not import cassio integration package. " "Please install it with `pip install cassio`."
)
from uuid import UUID
database_ref = self.database_ref
try:
UUID(self.database_ref)
is_astra = True
except ValueError:
is_astra = False
if "," in self.database_ref:
# use a copy because we can't change the type of the parameter
database_ref = self.database_ref.split(",")
if is_astra:
cassio.init(
database_id=database_ref,
token=self.token,
cluster_kwargs=self.cluster_kwargs,
)
else:
cassio.init(
contact_points=database_ref,
username=self.username,
password=self.token,
cluster_kwargs=self.cluster_kwargs,
)
return CassandraChatMessageHistory(
session_id=self.session_id,
table_name=self.table_name,
keyspace=self.keyspace,
)

View file

@ -1,87 +0,0 @@
from typing import Optional, cast
from langchain_community.chat_message_histories import CassandraChatMessageHistory
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.schema.data import Data
class CassandraMessageReaderComponent(BaseMemoryComponent):
display_name = "Cassandra Message Reader"
description = "Retrieves stored chat messages from a Cassandra table on Astra DB."
name = "CassandraMessageReader"
def build_config(self):
return {
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"database_id": {
"display_name": "Database ID",
"info": "The Astra database ID.",
},
"table_name": {
"display_name": "Table Name",
"info": "The name of the table where messages are stored.",
},
"token": {
"display_name": "Token",
"info": "Authentication token for accessing Cassandra on Astra DB.",
"password": True,
},
"keyspace": {
"display_name": "Keyspace",
"info": "Optional key space within Astra DB. The keyspace should already be created.",
"input_types": ["Text"],
"advanced": True,
},
}
def get_messages(self, **kwargs) -> list[Data]:
"""
Retrieves messages from the CassandraChatMessageHistory memory.
Args:
memory (CassandraChatMessageHistory): The CassandraChatMessageHistory instance to retrieve messages from.
Returns:
list[Data]: A list of Data objects representing the search results.
"""
memory: CassandraChatMessageHistory = cast(CassandraChatMessageHistory, kwargs.get("memory"))
if not memory:
raise ValueError("CassandraChatMessageHistory instance is required.")
# Get messages from the memory
messages = memory.messages
results = [Data.from_lc_message(message) for message in messages]
return list(results)
def build(
self,
session_id: str,
table_name: str,
token: str,
database_id: str,
keyspace: Optional[str] = None,
) -> list[Data]:
try:
import cassio
except ImportError:
raise ImportError(
"Could not import cassio integration package. " "Please install it with `pip install cassio`."
)
cassio.init(token=token, database_id=database_id)
memory = CassandraChatMessageHistory(
session_id=session_id,
table_name=table_name,
keyspace=keyspace,
)
data = self.get_messages(memory=memory)
self.status = data
return data

View file

@ -1,123 +0,0 @@
from typing import Optional
from langchain_community.chat_message_histories import CassandraChatMessageHistory
from langchain_core.messages import BaseMessage
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.schema.data import Data
class CassandraMessageWriterComponent(BaseMemoryComponent):
display_name = "Cassandra Message Writer"
description = "Writes a message to a Cassandra table on Astra DB."
name = "CassandraMessageWriter"
def build_config(self):
return {
"input_value": {
"display_name": "Input Data",
"info": "Data to write to Cassandra.",
},
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"database_id": {
"display_name": "Database ID",
"info": "The Astra database ID.",
},
"table_name": {
"display_name": "Table Name",
"info": "The name of the table where messages will be stored.",
},
"token": {
"display_name": "Token",
"info": "Authentication token for accessing Cassandra on Astra DB.",
"password": True,
},
"keyspace": {
"display_name": "Keyspace",
"info": "Optional key space within Astra DB. The keyspace should already be created.",
"input_types": ["Text"],
"advanced": True,
},
"ttl_seconds": {
"display_name": "TTL Seconds",
"info": "Optional time-to-live for the messages.",
"input_types": ["Number"],
"advanced": True,
},
}
def add_message(
self,
sender: str,
sender_name: str,
text: str,
session_id: str,
metadata: Optional[dict] = None,
**kwargs,
):
"""
Adds a message to the CassandraChatMessageHistory memory.
Args:
sender (str): The type of the message sender. Typically "ai" or "human".
sender_name (str): The name of the message sender.
text (str): The content of the message.
session_id (str): The session ID associated with the message.
metadata (dict | None, optional): Additional metadata for the message. Defaults to None.
**kwargs: Additional keyword arguments, including:
memory (CassandraChatMessageHistory | None): The memory instance to add the message to.
Raises:
ValueError: If the CassandraChatMessageHistory instance is not provided.
"""
memory: CassandraChatMessageHistory | None = kwargs.pop("memory", None)
if memory is None:
raise ValueError("CassandraChatMessageHistory instance is required.")
text_list = [
BaseMessage(
content=text,
sender=sender,
sender_name=sender_name,
metadata=metadata,
session_id=session_id,
)
]
memory.add_messages(text_list)
def build(
self,
input_value: Data,
session_id: str,
table_name: str,
token: str,
database_id: str,
keyspace: Optional[str] = None,
ttl_seconds: Optional[int] = None,
) -> Data:
try:
import cassio
except ImportError:
raise ImportError(
"Could not import cassio integration package. " "Please install it with `pip install cassio`."
)
cassio.init(token=token, database_id=database_id)
memory = CassandraChatMessageHistory(
session_id=session_id,
table_name=table_name,
keyspace=keyspace,
ttl_seconds=ttl_seconds,
)
self.add_message(**input_value.data, memory=memory)
self.status = f"Added message to Cassandra memory for session {session_id}"
return input_value

View file

@ -0,0 +1,43 @@
from langflow.base.memory.model import LCChatMemoryComponent
from langflow.inputs import MessageTextInput, SecretStrInput, DropdownInput
from langflow.field_typing import BaseChatMessageHistory
class ZepChatMemory(LCChatMemoryComponent):
display_name = "Zep Chat Memory"
description = "Retrieves and store chat messages from Zep."
name = "ZepChatMemory"
inputs = [
MessageTextInput(name="url", display_name="Zep URL", info="URL of the Zep instance."),
SecretStrInput(name="api_key", display_name="API Key", info="API Key for the Zep instance."),
DropdownInput(
name="api_base_path",
display_name="API Base Path",
options=["api/v1", "api/v2"],
value="api/v1",
advanced=True,
),
MessageTextInput(
name="session_id", display_name="Session ID", info="Session ID for the message.", advanced=True
),
]
def build_message_history(self) -> BaseChatMessageHistory:
try:
# Monkeypatch API_BASE_PATH to
# avoid 404
# This is a workaround for the local Zep instance
# cloud Zep works with v2
import zep_python.zep_client
from zep_python import ZepClient
from zep_python.langchain import ZepChatMessageHistory
zep_python.zep_client.API_BASE_PATH = self.api_base_path
except ImportError:
raise ImportError(
"Could not import zep-python package. " "Please install it with `pip install zep-python`."
)
zep_client = ZepClient(api_url=self.url, api_key=self.api_key)
return ZepChatMessageHistory(session_id=self.session_id, zep_client=zep_client)

View file

@ -1,151 +0,0 @@
from typing import Optional, cast
from langchain_community.chat_message_histories.zep import SearchScope, SearchType, ZepChatMessageHistory
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.field_typing import Text
from langflow.schema import Data
class ZepMessageReaderComponent(BaseMemoryComponent):
display_name = "Zep Message Reader"
description = "Retrieves stored chat messages from Zep."
name = "ZepMessageReader"
def build_config(self):
return {
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"url": {
"display_name": "Zep URL",
"info": "URL of the Zep instance.",
"input_types": ["Text"],
},
"api_key": {
"display_name": "Zep API Key",
"info": "API Key for the Zep instance.",
"password": True,
},
"query": {
"display_name": "Query",
"info": "Query to search for in the chat history.",
},
"metadata": {
"display_name": "Metadata",
"info": "Optional metadata to attach to the message.",
"advanced": True,
},
"search_scope": {
"options": ["Messages", "Summary"],
"display_name": "Search Scope",
"info": "Scope of the search.",
"advanced": True,
},
"search_type": {
"options": ["Similarity", "MMR"],
"display_name": "Search Type",
"info": "Type of search.",
"advanced": True,
},
"limit": {
"display_name": "Limit",
"info": "Limit of search results.",
"advanced": True,
},
"api_base_path": {
"display_name": "API Base Path",
"options": ["api/v1", "api/v2"],
},
}
def get_messages(self, **kwargs) -> list[Data]:
"""
Retrieves messages from the ZepChatMessageHistory memory.
If a query is provided, the search method is used to search for messages in the memory, otherwise all messages are returned.
Args:
memory (ZepChatMessageHistory): The ZepChatMessageHistory instance to retrieve messages from.
query (str, optional): The query string to search for messages. Defaults to None.
metadata (dict, optional): Additional metadata to filter the search results. Defaults to None.
search_scope (str, optional): The scope of the search. Can be 'messages' or 'summary'. Defaults to 'messages'.
search_type (str, optional): The type of search. Can be 'similarity' or 'exact'. Defaults to 'similarity'.
limit (int, optional): The maximum number of search results to return. Defaults to None.
Returns:
list[Data]: A list of Data objects representing the search results.
"""
memory: ZepChatMessageHistory = cast(ZepChatMessageHistory, kwargs.get("memory"))
if not memory:
raise ValueError("ZepChatMessageHistory instance is required.")
query = kwargs.get("query")
search_scope = kwargs.get("search_scope", SearchScope.messages).lower()
search_type = kwargs.get("search_type", SearchType.similarity).lower()
limit = kwargs.get("limit")
if query:
memory_search_results = memory.search(
query,
search_scope=search_scope,
search_type=search_type,
limit=limit,
)
# Get the messages from the search results if the search scope is messages
result_dicts = []
for result in memory_search_results:
result_dict = {}
if search_scope == SearchScope.messages:
result_dict["text"] = result.message
else:
result_dict["text"] = result.summary
result_dict["metadata"] = result.metadata
result_dict["score"] = result.score
result_dicts.append(result_dict)
results = [Data(data=result_dict) for result_dict in result_dicts]
else:
messages = memory.messages
results = [Data.from_lc_message(message) for message in messages]
return results
def build(
self,
session_id: Text,
api_base_path: str = "api/v1",
url: Optional[Text] = None,
api_key: Optional[Text] = None,
query: Optional[Text] = None,
search_scope: str = SearchScope.messages,
search_type: str = SearchType.similarity,
limit: Optional[int] = None,
) -> list[Data]:
try:
# Monkeypatch API_BASE_PATH to
# avoid 404
# This is a workaround for the local Zep instance
# cloud Zep works with v2
import zep_python.zep_client
from zep_python import ZepClient
from zep_python.langchain import ZepChatMessageHistory
zep_python.zep_client.API_BASE_PATH = api_base_path
except ImportError:
raise ImportError(
"Could not import zep-python package. " "Please install it with `pip install zep-python`."
)
if url == "":
url = None
zep_client = ZepClient(api_url=url, api_key=api_key)
memory = ZepChatMessageHistory(session_id=session_id, zep_client=zep_client)
data = self.get_messages(
memory=memory,
query=query,
search_scope=search_scope,
search_type=search_type,
limit=limit,
)
self.status = data
return data

View file

@ -1,109 +0,0 @@
from typing import TYPE_CHECKING, Optional
from langflow.base.memory.memory import BaseMemoryComponent
from langflow.field_typing import Text
from langflow.schema import Data
if TYPE_CHECKING:
from zep_python.langchain import ZepChatMessageHistory
class ZepMessageWriterComponent(BaseMemoryComponent):
display_name = "Zep Message Writer"
description = "Writes a message to Zep."
name = "ZepMessageWriter"
def build_config(self):
return {
"session_id": {
"display_name": "Session ID",
"info": "Session ID of the chat history.",
"input_types": ["Text"],
},
"url": {
"display_name": "Zep URL",
"info": "URL of the Zep instance.",
"input_types": ["Text"],
},
"api_key": {
"display_name": "Zep API Key",
"info": "API Key for the Zep instance.",
"password": True,
},
"limit": {
"display_name": "Limit",
"info": "Limit of search results.",
"advanced": True,
},
"input_value": {
"display_name": "Input Data",
"info": "Data to write to Zep.",
},
"api_base_path": {
"display_name": "API Base Path",
"options": ["api/v1", "api/v2"],
},
}
def add_message(
self, sender: Text, sender_name: Text, text: Text, session_id: Text, metadata: dict | None = None, **kwargs
):
"""
Adds a message to the ZepChatMessageHistory memory.
Args:
sender (Text): The type of the message sender. Valid values are "Machine" or "User".
sender_name (Text): The name of the message sender.
text (Text): The content of the message.
session_id (Text): The session ID associated with the message.
metadata (dict | None, optional): Additional metadata for the message. Defaults to None.
**kwargs: Additional keyword arguments.
Raises:
ValueError: If the ZepChatMessageHistory instance is not provided.
"""
memory: ZepChatMessageHistory | None = kwargs.pop("memory", None)
if memory is None:
raise ValueError("ZepChatMessageHistory instance is required.")
if metadata is None:
metadata = {}
metadata["sender_name"] = sender_name
metadata.update(kwargs)
if sender == "Machine":
memory.add_ai_message(text, metadata=metadata)
elif sender == "User":
memory.add_user_message(text, metadata=metadata)
else:
raise ValueError(f"Invalid sender type: {sender}")
def build(
self,
input_value: Data,
session_id: Text,
api_base_path: str = "api/v1",
url: Optional[Text] = None,
api_key: Optional[Text] = None,
) -> Data:
try:
# Monkeypatch API_BASE_PATH to
# avoid 404
# This is a workaround for the local Zep instance
# cloud Zep works with v2
import zep_python.zep_client
from zep_python import ZepClient
from zep_python.langchain import ZepChatMessageHistory
zep_python.zep_client.API_BASE_PATH = api_base_path
except ImportError:
raise ImportError(
"Could not import zep-python package. " "Please install it with `pip install zep-python`."
)
if url == "":
url = None
zep_client = ZepClient(api_url=url, api_key=api_key)
memory = ZepChatMessageHistory(session_id=session_id, zep_client=zep_client)
self.add_message(**input_value.data, memory=memory)
self.status = f"Added message to Zep memory for session {session_id}"
return input_value

View file

@ -3,6 +3,7 @@ from typing import Callable, Dict, Text, TypeAlias, TypeVar, Union
from langchain.agents.agent import AgentExecutor
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
@ -58,6 +59,7 @@ LANGCHAIN_BASE_TYPES = {
"BaseMemory": BaseMemory,
"BaseChatMemory": BaseChatMemory,
"BaseChatModel": BaseChatModel,
"BaseChatMessageHistory": BaseChatMessageHistory,
}
# Langchain base types plus Python base types
CUSTOM_COMPONENT_SUPPORTED_TYPES = {

View file

@ -1,6 +1,6 @@
import { expect, test } from "@playwright/test";
test("LLMChain - Tooltip", async ({ page }) => {
test("RetrievalQA - Tooltip", async ({ page }) => {
await page.goto("/");
await page.waitForTimeout(1000);
@ -26,11 +26,11 @@ test("LLMChain - Tooltip", async ({ page }) => {
});
await page.getByTestId("extended-disclosure").click();
await page.getByPlaceholder("Search").click();
await page.getByPlaceholder("Search").fill("llmchain");
await page.getByPlaceholder("Search").fill("retrievalqa");
await page.waitForTimeout(1000);
await page
.getByTestId("chainsLLMChain")
.getByTestId("chainsRetrieval QA")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.mouse.up();
await page.mouse.down();
@ -39,12 +39,12 @@ test("LLMChain - Tooltip", async ({ page }) => {
await page.getByTitle("zoom out").click();
await page.getByTitle("zoom out").click();
const llmChainOutputElements = await page
.getByTestId("handle-llmchain-shownode-text-right")
const outputElements = await page
.getByTestId("handle-retrievalqa-shownode-text-right")
.all();
let visibleElementHandle;
for (const element of llmChainOutputElements) {
for (const element of outputElements) {
if (await element.isVisible()) {
visibleElementHandle = element;
break;
@ -52,9 +52,15 @@ test("LLMChain - Tooltip", async ({ page }) => {
}
await visibleElementHandle.hover().then(async () => {
await expect(
page.getByTestId("available-output-inputs").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-chains").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-textsplitters").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-retrievers").first(),
).toBeVisible();
@ -62,13 +68,23 @@ test("LLMChain - Tooltip", async ({ page }) => {
page.getByTestId("available-output-prototypes").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-tools").first(),
page.getByTestId("available-output-embeddings").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-data").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-vectorstores").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-memories").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-toolkits").first(),
page.getByTestId("available-output-models").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-outputs").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-agents").first(),
@ -76,9 +92,6 @@ test("LLMChain - Tooltip", async ({ page }) => {
await expect(
page.getByTestId("available-output-helpers").first(),
).toBeVisible();
await expect(
page.getByTestId("available-output-langchain_utilities").first(),
).toBeVisible();
await page.getByTestId("icon-X").click();
await page.waitForTimeout(500);
@ -89,11 +102,11 @@ test("LLMChain - Tooltip", async ({ page }) => {
await page.getByTitle("zoom out").click();
await page.getByTitle("zoom out").click();
const llmChainInputElements1 = await page
.getByTestId("handle-llmchain-shownode-llm-left")
const rqaChainInputElements1 = await page
.getByTestId("handle-retrievalqa-shownode-language model-left")
.all();
for (const element of llmChainInputElements1) {
for (const element of rqaChainInputElements1) {
if (await element.isVisible()) {
visibleElementHandle = element;
break;
@ -115,11 +128,11 @@ test("LLMChain - Tooltip", async ({ page }) => {
await page.getByTitle("zoom out").click();
await page.getByTitle("zoom out").click();
const llmChainInputElements0 = await page
.getByTestId("handle-llmchain-shownode-template-left")
const rqaChainInputElements0 = await page
.getByTestId("handle-retrievalqa-shownode-retriever-left")
.all();
for (const element of llmChainInputElements0) {
for (const element of rqaChainInputElements0) {
if (await element.isVisible()) {
visibleElementHandle = element;
break;
@ -130,16 +143,10 @@ test("LLMChain - Tooltip", async ({ page }) => {
await page.waitForTimeout(2000);
await expect(
page.getByTestId("available-input-chains").first(),
page.getByTestId("available-input-retrievers").first(),
).toBeVisible();
await expect(
page.getByTestId("available-input-prototypes").first(),
).toBeVisible();
await expect(
page.getByTestId("available-input-agents").first(),
).toBeVisible();
await expect(
page.getByTestId("available-input-helpers").first(),
page.getByTestId("available-input-vectorstores").first(),
).toBeVisible();
await page.waitForTimeout(500);
@ -150,11 +157,11 @@ test("LLMChain - Tooltip", async ({ page }) => {
await page.getByTitle("zoom out").click();
await page.getByTitle("zoom out").click();
const llmChainInputElements2 = await page
.getByTestId("handle-llmchain-shownode-memory-left")
const rqaChainInputElements2 = await page
.getByTestId("handle-retrievalqa-shownode-memory-left")
.all();
for (const element of llmChainInputElements2) {
for (const element of rqaChainInputElements2) {
if (await element.isVisible()) {
visibleElementHandle = element;
break;
@ -163,7 +170,7 @@ test("LLMChain - Tooltip", async ({ page }) => {
await visibleElementHandle.hover().then(async () => {
await expect(
page.getByTestId("empty-tooltip-filter").first(),
page.getByTestId("available-input-memories").first(),
).toBeVisible();
});
});

View file

@ -1,6 +1,6 @@
import { expect, test } from "@playwright/test";
test("LLMChain - Filter", async ({ page }) => {
test("RetrievalQA - Filter", async ({ page }) => {
await page.goto("/");
await page.waitForTimeout(2000);
@ -31,11 +31,11 @@ test("LLMChain - Filter", async ({ page }) => {
});
await page.getByTestId("extended-disclosure").click();
await page.getByPlaceholder("Search").click();
await page.getByPlaceholder("Search").fill("llmchain");
await page.getByPlaceholder("Search").fill("retrievalqa");
await page.waitForTimeout(1000);
await page
.getByTestId("chainsLLMChain")
.getByTestId("chainsRetrieval QA")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.mouse.up();
await page.mouse.down();
@ -47,11 +47,11 @@ test("LLMChain - Filter", async ({ page }) => {
let visibleElementHandle;
const llmChainOutputElements = await page
.getByTestId("handle-llmchain-shownode-text-right")
const outputElements = await page
.getByTestId("handle-retrievalqa-shownode-text-right")
.all();
for (const element of llmChainOutputElements) {
for (const element of outputElements) {
if (await element.isVisible()) {
visibleElementHandle = element;
break;
@ -62,38 +62,47 @@ test("LLMChain - Filter", async ({ page }) => {
force: true,
});
await expect(page.getByTestId("disclosure-inputs")).toBeVisible();
await expect(page.getByTestId("disclosure-outputs")).toBeVisible();
await expect(page.getByTestId("disclosure-data")).toBeVisible();
await expect(page.getByTestId("disclosure-models")).toBeVisible();
await expect(page.getByTestId("disclosure-helpers")).toBeVisible();
await expect(page.getByTestId("disclosure-vector stores")).toBeVisible();
await expect(page.getByTestId("disclosure-embeddings")).toBeVisible();
await expect(page.getByTestId("disclosure-agents")).toBeVisible();
await expect(page.getByTestId("disclosure-chains")).toBeVisible();
await expect(page.getByTestId("disclosure-utilities")).toBeVisible();
await expect(page.getByTestId("disclosure-memories")).toBeVisible();
await expect(page.getByTestId("disclosure-prototypes")).toBeVisible();
await expect(page.getByTestId("disclosure-retrievers")).toBeVisible();
await expect(page.getByTestId("disclosure-toolkits")).toBeVisible();
await expect(page.getByTestId("disclosure-tools")).toBeVisible();
await expect(page.getByTestId("disclosure-text splitters")).toBeVisible();
await expect(page.getByTestId("helpersID Generator").first()).toBeVisible();
await expect(page.getByTestId("agentsCSVAgent").first()).toBeVisible();
await expect(page.getByTestId("chainsLLMChain").first()).toBeVisible();
await expect(page.getByTestId("inputsChat Input").first()).toBeVisible();
await expect(page.getByTestId("outputsChat Output").first()).toBeVisible();
await expect(page.getByTestId("dataAPI Request").first()).toBeVisible();
await expect(page.getByTestId("modelsAmazon Bedrock").first()).toBeVisible();
await expect(page.getByTestId("helpersChat Memory").first()).toBeVisible();
await expect(page.getByTestId("vectorstoresAstra DB").first()).toBeVisible();
await expect(
page.getByTestId("langchain_utilitiesSearchApi").first(),
page.getByTestId("embeddingsAmazon Bedrock Embeddings").first(),
).toBeVisible();
await expect(
page.getByTestId("memoriesAstra DB Message Reader").first(),
page.getByTestId("agentsTool Calling Agent").first(),
).toBeVisible();
await expect(
page.getByTestId("prototypesFlow as Tool").first(),
page.getByTestId("chainsConversationChain").first(),
).toBeVisible();
await expect(
page.getByTestId("retrieversAmazon Kendra Retriever").first(),
page.getByTestId("memoriesAstra DB Chat Memory").first(),
).toBeVisible();
await expect(
page.getByTestId("toolkitsVectorStoreInfo").first(),
page.getByTestId("prototypesConditional Router").first(),
).toBeVisible();
await expect(
page.getByTestId("retrieversSelf Query Retriever").first(),
).toBeVisible();
await expect(
page.getByTestId("textsplittersCharacterTextSplitter").first(),
).toBeVisible();
await expect(page.getByTestId("toolsSearchApi").first()).toBeVisible();
await page.getByPlaceholder("Search").click();
@ -110,11 +119,11 @@ test("LLMChain - Filter", async ({ page }) => {
await expect(page.getByTestId("model_specsChatOpenAI")).not.toBeVisible();
await expect(page.getByTestId("model_specsChatVertexAI")).not.toBeVisible();
const llmChainInputElements1 = await page
.getByTestId("handle-llmchain-shownode-llm-left")
const chainInputElements1 = await page
.getByTestId("handle-retrievalqa-shownode-llm-left")
.all();
for (const element of llmChainInputElements1) {
for (const element of chainInputElements1) {
if (await element.isVisible()) {
visibleElementHandle = element;
break;
@ -129,11 +138,11 @@ test("LLMChain - Filter", async ({ page }) => {
await expect(page.getByTestId("disclosure-models")).toBeVisible();
const llmChainInputElements0 = await page
.getByTestId("handle-llmchain-shownode-template-left")
const rqaChainInputElements0 = await page
.getByTestId("handle-retrievalqa-shownode-template-left")
.all();
for (const element of llmChainInputElements0) {
for (const element of rqaChainInputElements0) {
if (await element.isVisible()) {
visibleElementHandle = element;
break;

View file

@ -30,11 +30,11 @@ test("LLMChain - Filter", async ({ page }) => {
await page.getByTestId("extended-disclosure").click();
await page.getByPlaceholder("Search").click();
await page.getByPlaceholder("Search").fill("llmchain");
await page.getByPlaceholder("Search").fill("api request");
await page.waitForTimeout(1000);
await page.waitForTimeout(2000);
await page
.getByTestId("chainsLLMChain")
.getByTestId("dataAPI Request")
.dragTo(page.locator('//*[@id="react-flow-id"]'));
await page.mouse.up();
await page.mouse.down();
@ -44,65 +44,77 @@ test("LLMChain - Filter", async ({ page }) => {
await page.getByTitle("zoom out").click();
await page.waitForTimeout(500);
await page
.locator(
'//*[@id="react-flow-id"]/div/div[1]/div[1]/div/div[2]/div/div/div[2]/div[7]/button/div[1]',
)
.click();
await expect(page.getByTestId("helpersID Generator")).toBeVisible();
await expect(page.getByTestId("disclosure-agents")).toBeVisible();
await expect(page.getByTestId("chainsLLMChain").first()).toBeVisible();
await expect(
page.getByTestId("langchain_utilitiesSearchApi").first(),
).toBeVisible();
await expect(
page.getByTestId("memoriesAstra DB Message Reader").first(),
).toBeVisible();
await expect(
page.getByTestId("prototypesFlow as Tool").first(),
).toBeVisible();
await expect(
page.getByTestId("retrieversAmazon Kendra Retriever").first(),
).toBeVisible();
await expect(
page.getByTestId("toolkitsVectorStoreInfo").first(),
).toBeVisible();
await expect(page.getByTestId("toolsSearchApi").first()).toBeVisible();
await page.getByPlaceholder("Search").click();
await expect(page.getByTestId("model_specsVertexAI")).not.toBeVisible();
await expect(page.getByTestId("model_specsCTransformers")).not.toBeVisible();
await expect(page.getByTestId("model_specsAmazon Bedrock")).not.toBeVisible();
await expect(page.getByTestId("modelsAzure OpenAI")).not.toBeVisible();
await expect(
page.getByTestId("model_specsAzureChatOpenAI"),
).not.toBeVisible();
await expect(page.getByTestId("model_specsChatAnthropic")).not.toBeVisible();
await expect(page.getByTestId("model_specsChatLiteLLM")).not.toBeVisible();
await expect(page.getByTestId("model_specsChatOllama")).not.toBeVisible();
await expect(page.getByTestId("model_specsChatOpenAI")).not.toBeVisible();
await expect(page.getByTestId("model_specsChatVertexAI")).not.toBeVisible();
await page
.locator(
'//*[@id="react-flow-id"]/div/div[1]/div[1]/div/div[2]/div/div/div[2]/div[4]/div/button/div[1]',
)
.click();
await page.getByTestId("handle-apirequest-shownode-urls-left").click();
await expect(page.getByTestId("disclosure-inputs")).toBeVisible();
await expect(page.getByTestId("disclosure-outputs")).toBeVisible();
await expect(page.getByTestId("disclosure-prompts")).toBeVisible();
await expect(page.getByTestId("disclosure-models")).toBeVisible();
await page
.locator(
'//*[@id="react-flow-id"]/div/div[1]/div[1]/div/div[2]/div/div/div[2]/div[3]/div/button/div[1]',
)
.click();
await expect(page.getByTestId("disclosure-helpers")).toBeVisible();
await expect(page.getByTestId("disclosure-agents")).toBeVisible();
await expect(page.getByTestId("disclosure-chains")).toBeVisible();
await expect(page.getByTestId("disclosure-prototypes")).toBeVisible();
await expect(page.getByTestId("inputsChat Input")).toBeVisible();
await expect(page.getByTestId("outputsChat Output")).toBeVisible();
await expect(page.getByTestId("promptsPrompt")).toBeVisible();
await expect(page.getByTestId("modelsAmazon Bedrock")).toBeVisible();
await expect(page.getByTestId("helpersChat Memory")).toBeVisible();
await expect(page.getByTestId("agentsTool Calling Agent")).toBeVisible();
await expect(page.getByTestId("chainsConversationChain")).toBeVisible();
await expect(page.getByTestId("prototypesConditional Router")).toBeVisible();
await page.getByPlaceholder("Search").click();
await expect(page.getByTestId("inputsChat Input")).not.toBeVisible();
await expect(page.getByTestId("outputsChat Output")).not.toBeVisible();
await expect(page.getByTestId("promptsPrompt")).not.toBeVisible();
await expect(page.getByTestId("modelsAmazon Bedrock")).not.toBeVisible();
await expect(page.getByTestId("helpersChat Memory")).not.toBeVisible();
await expect(page.getByTestId("agentsTool Calling Agent")).not.toBeVisible();
await expect(page.getByTestId("chainsConversationChain")).not.toBeVisible();
await expect(
page.getByTestId("prototypesConditional Router"),
).not.toBeVisible();
await page.getByTestId("handle-apirequest-shownode-headers-left").click();
await expect(page.getByTestId("disclosure-data")).toBeVisible();
await expect(page.getByTestId("disclosure-helpers")).toBeVisible();
await expect(page.getByTestId("disclosure-vector stores")).toBeVisible();
await expect(page.getByTestId("disclosure-utilities")).toBeVisible();
await expect(page.getByTestId("disclosure-prototypes")).toBeVisible();
await expect(page.getByTestId("disclosure-retrievers")).toBeVisible();
await expect(page.getByTestId("disclosure-text splitters")).toBeVisible();
await expect(page.getByTestId("disclosure-tools")).toBeVisible();
await expect(page.getByTestId("dataAPI Request")).toBeVisible();
await expect(page.getByTestId("helpersChat Memory")).toBeVisible();
await expect(page.getByTestId("vectorstoresAstra DB")).toBeVisible();
await expect(page.getByTestId("langchain_utilitiesSearchApi")).toBeVisible();
await expect(page.getByTestId("prototypesSub Flow")).toBeVisible();
await expect(
page.getByTestId("retrieversSelf Query Retriever"),
).toBeVisible();
await expect(
page.getByTestId("textsplittersCharacterTextSplitter"),
).toBeVisible();
await expect(page.getByTestId("toolsSearchApi")).toBeVisible();
await page.getByPlaceholder("Search").click();
await expect(page.getByTestId("dataAPI Request")).not.toBeVisible();
await expect(page.getByTestId("helpersChat Memory")).not.toBeVisible();
await expect(page.getByTestId("vectorstoresAstra DB")).not.toBeVisible();
await expect(
page.getByTestId("langchain_utilitiesSearchApi"),
).not.toBeVisible();
await expect(page.getByTestId("prototypesSub Flow")).not.toBeVisible();
await expect(
page.getByTestId("retrieversSelf Query Retriever"),
).not.toBeVisible();
await expect(
page.getByTestId("textsplittersCharacterTextSplitter"),
).not.toBeVisible();
await expect(page.getByTestId("toolsSearchApi")).not.toBeVisible();
});