feat: migrate chains and memories to Component syntax (#2528)
* feat: migrate chains and memories to Component syntax
* use base class
* add classes
* [autofix.ci] apply automated fixes
* fix tests
* fix tests
* ✅ (filterSidebar.spec.ts): increase waitForTimeout from 1000ms to 2000ms to ensure elements are fully loaded before interaction
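For reviewers, a minimal sketch of the pattern this migration applies (the ExampleChainComponent below is illustrative only, not part of the diff): components drop CustomComponent's build_config()/build() pair and instead declare inputs/outputs on a Component subclass, with the output method returning a Message.

    # Illustrative only: a hypothetical component written in the new syntax.
    from langflow.base.chains.model import LCChainComponent
    from langflow.field_typing import Message
    from langflow.inputs import MultilineInput, HandleInput


    class ExampleChainComponent(LCChainComponent):
        display_name = "Example Chain"
        description = "Hypothetical component demonstrating the new syntax."

        inputs = [
            MultilineInput(name="input_value", display_name="Input", required=True),
            HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
        ]

        # The base class's default output ("text") points at this method.
        def invoke_chain(self) -> Message:
            result = self.llm.invoke(self.input_value)
            text = result.content if hasattr(result, "content") else str(result)
            self.status = text
            return Message(text=text)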
---------
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
Co-authored-by: cristhianzl <cristhian.lousa@gmail.com>
Co-authored-by: Cristhian Zanforlin Lousa <72977554+Cristhianzl@users.noreply.github.com>
This commit is contained in:
parent
5d9b29e2ae
commit
eb3420523e
24 changed files with 531 additions and 1096 deletions
0
src/backend/base/langflow/base/chains/__init__.py
Normal file
17
src/backend/base/langflow/base/chains/model.py
Normal file
@@ -0,0 +1,17 @@
+from langflow.custom import Component
+from langflow.template import Output
+
+
+class LCChainComponent(Component):
+    trace_type = "chain"
+
+    outputs = [Output(display_name="Text", name="text", method="invoke_chain")]
+
+    def _validate_outputs(self):
+        required_output_methods = ["invoke_chain"]
+        output_names = [output.name for output in self.outputs]
+        for method_name in required_output_methods:
+            if method_name not in output_names:
+                raise ValueError(f"Output with name '{method_name}' must be defined.")
+            elif not hasattr(self, method_name):
+                raise ValueError(f"Method '{method_name}' must be defined.")
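Read literally, the check above requires both an Output whose name matches each required method and a method of that name on the class. A hedged sketch of a subclass that satisfies it (EchoChainComponent is hypothetical, not part of this PR):

    # Sketch only: a minimal subclass satisfying _validate_outputs literally.
    from langflow.base.chains.model import LCChainComponent
    from langflow.field_typing import Message
    from langflow.template import Output


    class EchoChainComponent(LCChainComponent):
        display_name = "Echo Chain"  # hypothetical, not part of this PR

        # The check requires an Output *named* after the required method.
        outputs = [Output(display_name="Text", name="invoke_chain", method="invoke_chain")]

        def invoke_chain(self) -> Message:
            text = str(getattr(self, "input_value", ""))
            self.status = text
            return Message(text=text)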
35
src/backend/base/langflow/base/memory/model.py
Normal file
@@ -0,0 +1,35 @@
+from abc import abstractmethod
+
+from langflow.custom import Component
+from langflow.field_typing import BaseChatMessageHistory, BaseChatMemory
+from langflow.template import Output
+from langchain.memory import ConversationBufferMemory
+
+
+class LCChatMemoryComponent(Component):
+    trace_type = "chat_memory"
+    outputs = [
+        Output(
+            display_name="Memory",
+            name="base_memory",
+            method="build_base_memory",
+        )
+    ]
+
+    def _validate_outputs(self):
+        required_output_methods = ["build_base_memory"]
+        output_names = [output.name for output in self.outputs]
+        for method_name in required_output_methods:
+            if method_name not in output_names:
+                raise ValueError(f"Output with name '{method_name}' must be defined.")
+            elif not hasattr(self, method_name):
+                raise ValueError(f"Method '{method_name}' must be defined.")
+
+    def build_base_memory(self) -> BaseChatMemory:
+        return ConversationBufferMemory(chat_memory=self.build_message_history())
+
+    @abstractmethod
+    def build_message_history(self) -> BaseChatMessageHistory:
+        """
+        Builds the chat message history memory.
+        """
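Concrete memory components (the Astra DB, Cassandra, and Zep components further down) only implement build_message_history(); the inherited build_base_memory() wraps whatever history they return in a ConversationBufferMemory. A sketch with an in-memory backend (InMemoryChatMemoryComponent is hypothetical, not part of this PR):

    # Sketch only: the smallest possible LCChatMemoryComponent subclass.
    from langchain_community.chat_message_histories import ChatMessageHistory

    from langflow.base.memory.model import LCChatMemoryComponent
    from langflow.field_typing import BaseChatMessageHistory


    class InMemoryChatMemoryComponent(LCChatMemoryComponent):
        display_name = "In-Memory Chat Memory"  # hypothetical, not part of this PR

        def build_message_history(self) -> BaseChatMessageHistory:
            # build_base_memory() (inherited) will wrap this history in a
            # ConversationBufferMemory via its chat_memory argument.
            return ChatMessageHistory()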
@@ -1,41 +1,34 @@
-from typing import Optional
-
 from langchain.chains import ConversationChain

-from langflow.custom import CustomComponent
-from langflow.field_typing import BaseMemory, LanguageModel, Text
+from langflow.base.chains.model import LCChainComponent
+from langflow.field_typing import Message
+from langflow.inputs import MultilineInput, HandleInput


-class ConversationChainComponent(CustomComponent):
+class ConversationChainComponent(LCChainComponent):
     display_name = "ConversationChain"
     description = "Chain to have a conversation and load context from memory."
     name = "ConversationChain"

-    def build_config(self):
-        return {
-            "prompt": {"display_name": "Prompt"},
-            "llm": {"display_name": "LLM"},
-            "memory": {
-                "display_name": "Memory",
-                "info": "Memory to load context from. If none is provided, a ConversationBufferMemory will be used.",
-            },
-            "input_value": {
-                "display_name": "Input Value",
-                "info": "The input value to pass to the chain.",
-            },
-        }
+    inputs = [
+        MultilineInput(
+            name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
+        ),
+        HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
+        HandleInput(
+            name="memory",
+            display_name="Memory",
+            input_types=["BaseChatMemory"],
+        ),
+    ]

-    def build(
-        self,
-        input_value: Text,
-        llm: LanguageModel,
-        memory: Optional[BaseMemory] = None,
-    ) -> Text:
-        if memory is None:
-            chain = ConversationChain(llm=llm)
+    def invoke_chain(self) -> Message:
+        if not self.memory:
+            chain = ConversationChain(llm=self.llm)
         else:
-            chain = ConversationChain(llm=llm, memory=memory)
-        result = chain.invoke({"input": input_value})
+            chain = ConversationChain(llm=self.llm, memory=self.memory)
+
+        result = chain.invoke({"input": self.input_value})
         if isinstance(result, dict):
             result = result.get(chain.output_key, "")  # type: ignore
@@ -43,5 +36,6 @@ class ConversationChainComponent(CustomComponent):
             result = result
         else:
             result = result.get("response")
+        result = str(result)
         self.status = result
-        return str(result)
+        return Message(text=result)
@@ -1,34 +0,0 @@
-from typing import Optional
-
-from langchain.chains.llm import LLMChain
-from langchain_core.prompts import PromptTemplate
-
-from langflow.custom import CustomComponent
-from langflow.field_typing import BaseMemory, LanguageModel, Text
-
-
-class LLMChainComponent(CustomComponent):
-    display_name = "LLMChain"
-    description = "Chain to run queries against LLMs"
-    name = "LLMChain"
-
-    def build_config(self):
-        return {
-            "prompt": {"display_name": "Prompt"},
-            "llm": {"display_name": "LLM"},
-            "memory": {"display_name": "Memory"},
-        }
-
-    def build(
-        self,
-        template: Text,
-        llm: LanguageModel,
-        memory: Optional[BaseMemory] = None,
-    ) -> Text:
-        prompt = PromptTemplate.from_template(template)
-        runnable = LLMChain(prompt=prompt, llm=llm, memory=memory)
-        result_dict = runnable.invoke({})
-        output_key = runnable.output_key
-        result = result_dict[output_key]
-        self.status = result
-        return result
@@ -1,32 +1,27 @@
 from langchain.chains import LLMCheckerChain

-from langflow.custom import CustomComponent
-from langflow.field_typing import LanguageModel, Text
+from langflow.base.chains.model import LCChainComponent
+from langflow.field_typing import Message
+from langflow.inputs import MultilineInput, HandleInput


-class LLMCheckerChainComponent(CustomComponent):
+class LLMCheckerChainComponent(LCChainComponent):
     display_name = "LLMCheckerChain"
-    description = ""
+    description = "Chain for question-answering with self-verification."
     documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_checker"
     name = "LLMCheckerChain"

-    def build_config(self):
-        return {
-            "llm": {"display_name": "LLM"},
-            "input_value": {
-                "display_name": "Input Value",
-                "info": "The input value to pass to the chain.",
-            },
-        }
+    inputs = [
+        MultilineInput(
+            name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
+        ),
+        HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
+    ]

-    def build(
-        self,
-        input_value: Text,
-        llm: LanguageModel,
-    ) -> Text:
-        chain = LLMCheckerChain.from_llm(llm=llm)
-        response = chain.invoke({chain.input_key: input_value})
+    def invoke_chain(self) -> Message:
+        chain = LLMCheckerChain.from_llm(llm=self.llm)
+        response = chain.invoke({chain.input_key: self.input_value})
         result = response.get(chain.output_key, "")
-        result_str = str(result)
-        self.status = result_str
-        return result_str
+        result = str(result)
+        self.status = result
+        return Message(text=result)
@@ -1,48 +1,30 @@
-from typing import Optional
-
-from langchain.chains import LLMChain, LLMMathChain
+from langchain.chains import LLMMathChain

-from langflow.custom import CustomComponent
-from langflow.field_typing import BaseMemory, LanguageModel, Text
+from langflow.base.chains.model import LCChainComponent
+from langflow.field_typing import Message
+from langflow.inputs import MultilineInput, HandleInput
+from langflow.template import Output


-class LLMMathChainComponent(CustomComponent):
+class LLMMathChainComponent(LCChainComponent):
     display_name = "LLMMathChain"
     description = "Chain that interprets a prompt and executes python code to do math."
     documentation = "https://python.langchain.com/docs/modules/chains/additional/llm_math"
     name = "LLMMathChain"

-    def build_config(self):
-        return {
-            "llm": {"display_name": "LLM"},
-            "llm_chain": {"display_name": "LLM Chain"},
-            "memory": {"display_name": "Memory"},
-            "input_key": {"display_name": "Input Key"},
-            "output_key": {"display_name": "Output Key"},
-            "input_value": {
-                "display_name": "Input Value",
-                "info": "The input value to pass to the chain.",
-            },
-        }
+    inputs = [
+        MultilineInput(
+            name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
+        ),
+        HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
+    ]

-    def build(
-        self,
-        input_value: Text,
-        llm: LanguageModel,
-        llm_chain: LLMChain,
-        input_key: str = "question",
-        output_key: str = "answer",
-        memory: Optional[BaseMemory] = None,
-    ) -> Text:
-        chain = LLMMathChain(
-            llm=llm,
-            llm_chain=llm_chain,
-            input_key=input_key,
-            output_key=output_key,
-            memory=memory,
-        )
-        response = chain.invoke({input_key: input_value})
-        result = response.get(output_key)
-        result_str = str(result)
-        self.status = result_str
-        return result_str
+    outputs = [Output(display_name="Text", name="text", method="invoke_chain")]
+
+    def invoke_chain(self) -> Message:
+        chain = LLMMathChain.from_llm(llm=self.llm)
+        response = chain.invoke({chain.input_key: self.input_value})
+        result = response.get(chain.output_key, "")
+        result = str(result)
+        self.status = result
+        return Message(text=result)
@@ -1,69 +1,64 @@
-from typing import Optional
-
-from langchain.chains import RetrievalQA
-from langchain_core.documents import Document
-
-from langflow.custom import CustomComponent
-from langflow.field_typing import BaseMemory, BaseRetriever, LanguageModel, Text
-from langflow.schema import Data
+from langchain.chains.retrieval_qa.base import RetrievalQA
+
+from langflow.base.chains.model import LCChainComponent
+from langflow.field_typing import Message
+from langflow.inputs import HandleInput, MultilineInput, BoolInput, DropdownInput


-class RetrievalQAComponent(CustomComponent):
+class RetrievalQAComponent(LCChainComponent):
     display_name = "Retrieval QA"
-    description = "Chain for question-answering against an index."
+    description = "Chain for question-answering querying sources from a retriever."
     name = "RetrievalQA"

-    def build_config(self):
-        return {
-            "llm": {"display_name": "LLM"},
-            "chain_type": {"display_name": "Chain Type", "options": ["Stuff", "Map Reduce", "Refine", "Map Rerank"]},
-            "retriever": {"display_name": "Retriever"},
-            "memory": {"display_name": "Memory", "required": False},
-            "input_key": {"display_name": "Input Key", "advanced": True},
-            "output_key": {"display_name": "Output Key", "advanced": True},
-            "return_source_documents": {"display_name": "Return Source Documents"},
-            "input_value": {
-                "display_name": "Input",
-                "input_types": ["Data", "Document"],
-            },
-        }
+    inputs = [
+        MultilineInput(
+            name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
+        ),
+        DropdownInput(
+            name="chain_type",
+            display_name="Chain Type",
+            info="Chain type to use.",
+            options=["Stuff", "Map Reduce", "Refine", "Map Rerank"],
+            value="Stuff",
+            advanced=True,
+        ),
+        HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
+        HandleInput(name="retriever", display_name="Retriever", input_types=["Retriever"], required=True),
+        HandleInput(
+            name="memory",
+            display_name="Memory",
+            input_types=["BaseChatMemory"],
+        ),
+        BoolInput(
+            name="return_source_documents",
+            display_name="Return Source Documents",
+            value=False,
+        ),
+    ]

-    def build(
-        self,
-        llm: LanguageModel,
-        chain_type: str,
-        retriever: BaseRetriever,
-        input_value: str = "",
-        memory: Optional[BaseMemory] = None,
-        input_key: str = "query",
-        output_key: str = "result",
-        return_source_documents: bool = True,
-    ) -> Text:
-        chain_type = chain_type.lower().replace(" ", "_")
+    def invoke_chain(self) -> Message:
+        chain_type = self.chain_type.lower().replace(" ", "_")
+        if self.memory:
+            self.memory.input_key = "query"
+            self.memory.output_key = "result"
+
         runnable = RetrievalQA.from_chain_type(
-            llm=llm,
+            llm=self.llm,
             chain_type=chain_type,
-            retriever=retriever,
-            memory=memory,
-            input_key=input_key,
-            output_key=output_key,
-            return_source_documents=return_source_documents,
+            retriever=self.retriever,
+            memory=self.memory,
+            # always include to help debugging
+            return_source_documents=True,
         )
-        if isinstance(input_value, Document):
-            input_value = input_value.page_content
-        if isinstance(input_value, Data):
-            input_value = input_value.get_text()
-        self.status = runnable
-        result = runnable.invoke({input_key: input_value})
-        result = result.content if hasattr(result, "content") else result
-        # Result is a dict with keys "query", "result" and "source_documents"
-        # for now we just return the result
-        data = self.to_data(result.get("source_documents"))
-        references_str = ""
-        if return_source_documents:
-            references_str = self.create_references_from_data(data)
-        result_str = result.get("result", "")
-
-        final_result = "\n".join([str(result_str), references_str])
-        self.status = final_result
-        return final_result  # OK
+        result = runnable.invoke({"query": self.input_value})
+
+        source_docs = self.to_data(result.get("source_documents", []))
+        result_str = str(result.get("result", ""))
+        if self.return_source_documents and len(source_docs):
+            references_str = self.create_references_from_data(source_docs)
+            result_str = "\n".join([result_str, references_str])
+        # put the entire result to debug history, query and content
+        self.status = {**result, "source_documents": source_docs, "output": result_str}
+        return result_str
@@ -1,64 +0,0 @@
-from typing import Optional
-
-from langchain.chains import RetrievalQAWithSourcesChain
-from langchain_core.documents import Document
-
-from langflow.custom import CustomComponent
-from langflow.field_typing import BaseMemory, BaseRetriever, LanguageModel, Text
-
-
-class RetrievalQAWithSourcesChainComponent(CustomComponent):
-    display_name = "RetrievalQAWithSourcesChain"
-    description = "Question-answering with sources over an index."
-    name = "RetrievalQAWithSourcesChain"
-
-    def build_config(self):
-        return {
-            "llm": {"display_name": "LLM"},
-            "chain_type": {
-                "display_name": "Chain Type",
-                "options": ["Stuff", "Map Reduce", "Refine", "Map Rerank"],
-                "info": "The type of chain to use to combined Documents.",
-            },
-            "memory": {"display_name": "Memory"},
-            "return_source_documents": {"display_name": "Return Source Documents"},
-            "retriever": {"display_name": "Retriever"},
-            "input_value": {
-                "display_name": "Input Value",
-                "info": "The input value to pass to the chain.",
-            },
-        }
-
-    def build(
-        self,
-        input_value: Text,
-        retriever: BaseRetriever,
-        llm: LanguageModel,
-        chain_type: str,
-        memory: Optional[BaseMemory] = None,
-        return_source_documents: Optional[bool] = True,
-    ) -> Text:
-        chain_type = chain_type.lower().replace(" ", "_")
-        runnable = RetrievalQAWithSourcesChain.from_chain_type(
-            llm=llm,
-            chain_type=chain_type,
-            memory=memory,
-            return_source_documents=return_source_documents,
-            retriever=retriever,
-        )
-        if isinstance(input_value, Document):
-            input_value = input_value.page_content
-        self.status = runnable
-        input_key = runnable.input_keys[0]
-        result = runnable.invoke({input_key: input_value})
-        result = result.content if hasattr(result, "content") else result
-        # Result is a dict with keys "query", "result" and "source_documents"
-        # for now we just return the result
-        data = self.to_data(result.get("source_documents"))
-        references_str = ""
-        if return_source_documents:
-            references_str = self.create_references_from_data(data)
-        result_str = str(result.get("answer", ""))
-        final_result = "\n".join([result_str, references_str])
-        self.status = final_result
-        return final_result
@@ -1,62 +1,49 @@
-from typing import Optional
-
 from langchain.chains import create_sql_query_chain
-from langchain_community.utilities.sql_database import SQLDatabase
 from langchain_core.prompts import PromptTemplate
 from langchain_core.runnables import Runnable

-from langflow.custom import CustomComponent
-from langflow.field_typing import LanguageModel, Text
+from langflow.base.chains.model import LCChainComponent
+from langflow.field_typing import Message
+from langflow.inputs import MultilineInput, HandleInput, IntInput
+from langflow.template import Output


-class SQLGeneratorComponent(CustomComponent):
+class SQLGeneratorComponent(LCChainComponent):
     display_name = "Natural Language to SQL"
     description = "Generate SQL from natural language."
     name = "SQLGenerator"

-    def build_config(self):
-        return {
-            "db": {"display_name": "Database"},
-            "llm": {"display_name": "LLM"},
-            "prompt": {
-                "display_name": "Prompt",
-                "info": "The prompt must contain `{question}`.",
-            },
-            "top_k": {
-                "display_name": "Top K",
-                "info": "The number of results per select statement to return. If 0, no limit.",
-            },
-            "input_value": {
-                "display_name": "Input Value",
-                "info": "The input value to pass to the chain.",
-            },
-        }
+    inputs = [
+        MultilineInput(
+            name="input_value", display_name="Input", info="The input value to pass to the chain.", required=True
+        ),
+        HandleInput(name="llm", display_name="Language Model", input_types=["LanguageModel"], required=True),
+        HandleInput(name="db", display_name="SQLDatabase", input_types=["SQLDatabase"], required=True),
+        IntInput(
+            name="top_k", display_name="Top K", info="The number of results per select statement to return.", value=5
+        ),
+        MultilineInput(name="prompt", display_name="Prompt", info="The prompt must contain `{question}`."),
+    ]

-    def build(
-        self,
-        input_value: Text,
-        db: SQLDatabase,
-        llm: LanguageModel,
-        top_k: int = 5,
-        prompt: Optional[Text] = None,
-    ) -> Text:
-        if prompt:
-            prompt_template = PromptTemplate.from_template(template=prompt)
+    outputs = [Output(display_name="Text", name="text", method="invoke_chain")]
+
+    def invoke_chain(self) -> Message:
+        if self.prompt:
+            prompt_template = PromptTemplate.from_template(template=self.prompt)
         else:
             prompt_template = None

-        if top_k < 1:
+        if self.top_k < 1:
             raise ValueError("Top K must be greater than 0.")

         if not prompt_template:
-            sql_query_chain = create_sql_query_chain(llm=llm, db=db, k=top_k)
+            sql_query_chain = create_sql_query_chain(llm=self.llm, db=self.db, k=self.top_k)
         else:
             # Check if {question} is in the prompt
             if "{question}" not in prompt_template.template or "question" not in prompt_template.input_variables:
                 raise ValueError("Prompt must contain `{question}` to be used with Natural Language to SQL.")
-            sql_query_chain = create_sql_query_chain(llm=llm, db=db, prompt=prompt_template, k=top_k)
+            sql_query_chain = create_sql_query_chain(llm=self.llm, db=self.db, prompt=prompt_template, k=self.top_k)
         query_writer: Runnable = sql_query_chain | {"query": lambda x: x.replace("SQLQuery:", "").strip()}
-        response = query_writer.invoke({"question": input_value})
+        response = query_writer.invoke({"question": self.input_value})
         query = response.get("query")
         self.status = query
         return query
@@ -1,17 +1,13 @@
 from .ConversationChain import ConversationChainComponent
-from .LLMChain import LLMChainComponent
 from .LLMCheckerChain import LLMCheckerChainComponent
 from .LLMMathChain import LLMMathChainComponent
 from .RetrievalQA import RetrievalQAComponent
-from .RetrievalQAWithSourcesChain import RetrievalQAWithSourcesChainComponent
 from .SQLGenerator import SQLGeneratorComponent

 __all__ = [
     "ConversationChainComponent",
-    "LLMChainComponent",
     "LLMCheckerChainComponent",
     "LLMMathChainComponent",
     "RetrievalQAComponent",
-    "RetrievalQAWithSourcesChainComponent",
     "SQLGeneratorComponent",
 ]
@@ -0,0 +1,60 @@
+from langflow.base.memory.model import LCChatMemoryComponent
+from langflow.inputs import MessageTextInput, StrInput, SecretStrInput
+from langflow.field_typing import BaseChatMessageHistory
+
+
+class AstraDBChatMemory(LCChatMemoryComponent):
+    display_name = "Astra DB Chat Memory"
+    description = "Retrieves and store chat messages from Astra DB."
+    name = "AstraDBChatMemory"
+    icon: str = "AstraDB"
+
+    inputs = [
+        StrInput(
+            name="collection_name",
+            display_name="Collection Name",
+            info="The name of the collection within Astra DB where the vectors will be stored.",
+            required=True,
+        ),
+        SecretStrInput(
+            name="token",
+            display_name="Astra DB Application Token",
+            info="Authentication token for accessing Astra DB.",
+            value="ASTRA_DB_APPLICATION_TOKEN",
+            required=True,
+        ),
+        SecretStrInput(
+            name="api_endpoint",
+            display_name="API Endpoint",
+            info="API endpoint URL for the Astra DB service.",
+            value="ASTRA_DB_API_ENDPOINT",
+            required=True,
+        ),
+        StrInput(
+            name="namespace",
+            display_name="Namespace",
+            info="Optional namespace within Astra DB to use for the collection.",
+            advanced=True,
+        ),
+        MessageTextInput(
+            name="session_id", display_name="Session ID", info="Session ID for the message.", advanced=True
+        ),
+    ]
+
+    def build_message_history(self) -> BaseChatMessageHistory:
+        try:
+            from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
+        except ImportError:
+            raise ImportError(
+                "Could not import langchain Astra DB integration package. "
+                "Please install it with `pip install langchain-astradb`."
+            )
+
+        memory = AstraDBChatMessageHistory(
+            session_id=self.session_id,
+            collection_name=self.collection_name,
+            token=self.token,
+            api_endpoint=self.api_endpoint,
+            namespace=self.namespace or None,
+        )
+        return memory
@@ -1,97 +0,0 @@
-from typing import Optional, cast
-
-from langflow.base.memory.memory import BaseMemoryComponent
-from langflow.schema import Data
-
-
-class AstraDBMessageReaderComponent(BaseMemoryComponent):
-    display_name = "Astra DB Message Reader"
-    description = "Retrieves stored chat messages from Astra DB."
-    name = "AstraDBMessageReader"
-
-    def build_config(self):
-        return {
-            "session_id": {
-                "display_name": "Session ID",
-                "info": "Session ID of the chat history.",
-                "input_types": ["Text"],
-            },
-            "collection_name": {
-                "display_name": "Collection Name",
-                "info": "Collection name for Astra DB.",
-                "input_types": ["Text"],
-            },
-            "token": {
-                "display_name": "Astra DB Application Token",
-                "info": "Token for the Astra DB instance.",
-                "password": True,
-            },
-            "api_endpoint": {
-                "display_name": "Astra DB API Endpoint",
-                "info": "API Endpoint for the Astra DB instance.",
-                "password": True,
-            },
-            "namespace": {
-                "display_name": "Namespace",
-                "info": "Namespace for the Astra DB instance.",
-                "input_types": ["Text"],
-                "advanced": True,
-            },
-        }
-
-    def get_messages(self, **kwargs) -> list[Data]:
-        """
-        Retrieves messages from the AstraDBChatMessageHistory memory.
-
-        Args:
-            memory (AstraDBChatMessageHistory): The AstraDBChatMessageHistory instance to retrieve messages from.
-
-        Returns:
-            list[Data]: A list of Data objects representing the search results.
-        """
-        try:
-            from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
-        except ImportError:
-            raise ImportError(
-                "Could not import langchain Astra DB integration package. "
-                "Please install it with `pip install langchain-astradb`."
-            )
-
-        memory: AstraDBChatMessageHistory = cast(AstraDBChatMessageHistory, kwargs.get("memory"))
-        if not memory:
-            raise ValueError("AstraDBChatMessageHistory instance is required.")
-
-        # Get messages from the memory
-        messages = memory.messages
-        results = [Data.from_lc_message(message) for message in messages]
-
-        return list(results)
-
-    def build(
-        self,
-        session_id: str,
-        collection_name: str,
-        token: str,
-        api_endpoint: str,
-        namespace: Optional[str] = None,
-    ) -> list[Data]:
-        try:
-            from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
-        except ImportError:
-            raise ImportError(
-                "Could not import langchain Astra DB integration package. "
-                "Please install it with `pip install langchain-astradb`."
-            )
-
-        memory = AstraDBChatMessageHistory(
-            session_id=session_id,
-            collection_name=collection_name,
-            token=token,
-            api_endpoint=api_endpoint,
-            namespace=namespace,
-        )
-
-        data = self.get_messages(memory=memory)
-        self.status = data
-
-        return data
@@ -1,127 +0,0 @@
-from typing import Optional
-
-from langchain_core.messages import BaseMessage
-
-from langflow.base.memory.memory import BaseMemoryComponent
-from langflow.schema import Data
-
-
-class AstraDBMessageWriterComponent(BaseMemoryComponent):
-    display_name = "Astra DB Message Writer"
-    description = "Writes a message to Astra DB."
-    name = "AstraDBMessageWriter"
-
-    def build_config(self):
-        return {
-            "input_value": {
-                "display_name": "Input Data",
-                "info": "Data to write to Astra DB.",
-            },
-            "session_id": {
-                "display_name": "Session ID",
-                "info": "Session ID of the chat history.",
-                "input_types": ["Text"],
-            },
-            "collection_name": {
-                "display_name": "Collection Name",
-                "info": "Collection name for Astra DB.",
-                "input_types": ["Text"],
-            },
-            "token": {
-                "display_name": "Astra DB Application Token",
-                "info": "Token for the Astra DB instance.",
-                "password": True,
-            },
-            "api_endpoint": {
-                "display_name": "Astra DB API Endpoint",
-                "info": "API Endpoint for the Astra DB instance.",
-                "password": True,
-            },
-            "namespace": {
-                "display_name": "Namespace",
-                "info": "Namespace for the Astra DB instance.",
-                "input_types": ["Text"],
-                "advanced": True,
-            },
-        }
-
-    def add_message(
-        self,
-        sender: str,
-        sender_name: str,
-        text: str,
-        session_id: str,
-        metadata: Optional[dict] = None,
-        **kwargs,
-    ):
-        """
-        Adds a message to the AstraDBChatMessageHistory memory.
-
-        Args:
-            sender (str): The type of the message sender. Typically "ai" or "human".
-            sender_name (str): The name of the message sender.
-            text (str): The content of the message.
-            session_id (str): The session ID associated with the message.
-            metadata (dict | None, optional): Additional metadata for the message. Defaults to None.
-            **kwargs: Additional keyword arguments, including:
-                memory (AstraDBChatMessageHistory | None): The memory instance to add the message to.
-
-        Raises:
-            ValueError: If the AstraDBChatMessageHistory instance is not provided.
-
-        """
-        try:
-            from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
-        except ImportError:
-            raise ImportError(
-                "Could not import langchain Astra DB integration package. "
-                "Please install it with `pip install langchain-astradb`."
-            )
-
-        memory: AstraDBChatMessageHistory | None = kwargs.pop("memory", None)
-        if memory is None:
-            raise ValueError("AstraDBChatMessageHistory instance is required.")
-
-        text_list = [
-            BaseMessage(
-                content=text,
-                sender=sender,
-                sender_name=sender_name,
-                metadata=metadata,
-                session_id=session_id,
-                type=sender,
-            )
-        ]
-
-        memory.add_messages(text_list)
-
-    def build(
-        self,
-        input_value: Data,
-        session_id: str,
-        collection_name: str,
-        token: str,
-        api_endpoint: str,
-        namespace: Optional[str] = None,
-    ) -> Data:
-        try:
-            from langchain_astradb.chat_message_histories import AstraDBChatMessageHistory
-        except ImportError:
-            raise ImportError(
-                "Could not import langchain Astra DB integration package. "
-                "Please install it with `pip install langchain-astradb`."
-            )
-
-        memory = AstraDBChatMessageHistory(
-            session_id=session_id,
-            collection_name=collection_name,
-            token=token,
-            api_endpoint=api_endpoint,
-            namespace=namespace,
-        )
-
-        self.add_message(**input_value.data, memory=memory)
-        self.status = f"Added message to Astra DB memory for session {session_id}"
-
-        return input_value
@@ -0,0 +1,93 @@
+from langflow.base.memory.model import LCChatMemoryComponent
+from langflow.inputs import MessageTextInput, SecretStrInput, DictInput
+from langflow.field_typing import BaseChatMessageHistory
+
+
+class CassandraChatMemory(LCChatMemoryComponent):
+    display_name = "Cassandra Chat Memory"
+    description = "Retrieves and store chat messages from Apache Cassandra."
+    name = "CassandraChatMemory"
+    icon = "Cassandra"
+
+    inputs = [
+        MessageTextInput(
+            name="database_ref",
+            display_name="Contact Points / Astra Database ID",
+            info="Contact points for the database (or AstraDB database ID)",
+            required=True,
+        ),
+        MessageTextInput(
+            name="username", display_name="Username", info="Username for the database (leave empty for AstraDB)."
+        ),
+        SecretStrInput(
+            name="token",
+            display_name="Password / AstraDB Token",
+            info="User password for the database (or AstraDB token).",
+            required=True,
+        ),
+        MessageTextInput(
+            name="keyspace",
+            display_name="Keyspace",
+            info="Table Keyspace (or AstraDB namespace).",
+            required=True,
+        ),
+        MessageTextInput(
+            name="table_name",
+            display_name="Table Name",
+            info="The name of the table (or AstraDB collection) where vectors will be stored.",
+            required=True,
+        ),
+        MessageTextInput(
+            name="session_id", display_name="Session ID", info="Session ID for the message.", advanced=True
+        ),
+        DictInput(
+            name="cluster_kwargs",
+            display_name="Cluster arguments",
+            info="Optional dictionary of additional keyword arguments for the Cassandra cluster.",
+            advanced=True,
+            is_list=True,
+        ),
+    ]
+
+    def build_message_history(self) -> BaseChatMessageHistory:
+        from langchain_community.chat_message_histories import CassandraChatMessageHistory
+
+        try:
+            import cassio
+        except ImportError:
+            raise ImportError(
+                "Could not import cassio integration package. " "Please install it with `pip install cassio`."
+            )
+
+        from uuid import UUID
+
+        database_ref = self.database_ref
+
+        try:
+            UUID(self.database_ref)
+            is_astra = True
+        except ValueError:
+            is_astra = False
+            if "," in self.database_ref:
+                # use a copy because we can't change the type of the parameter
+                database_ref = self.database_ref.split(",")
+
+        if is_astra:
+            cassio.init(
+                database_id=database_ref,
+                token=self.token,
+                cluster_kwargs=self.cluster_kwargs,
+            )
+        else:
+            cassio.init(
+                contact_points=database_ref,
+                username=self.username,
+                password=self.token,
+                cluster_kwargs=self.cluster_kwargs,
+            )
+
+        return CassandraChatMessageHistory(
+            session_id=self.session_id,
+            table_name=self.table_name,
+            keyspace=self.keyspace,
+        )
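A note on the database_ref handling above: the component treats the value as an Astra DB database ID when it parses as a UUID, and as Cassandra contact points otherwise (splitting on commas for multiple hosts). A standalone sketch of that dispatch rule (the helper name is hypothetical):

    # Sketch only: the database_ref dispatch rule from CassandraChatMemory, in isolation.
    from uuid import UUID


    def classify_database_ref(database_ref: str):
        """Return ("astra", database_id) for a UUID, else ("cassandra", contact_points)."""
        try:
            UUID(database_ref)
            return ("astra", database_ref)
        except ValueError:
            points = database_ref.split(",") if "," in database_ref else database_ref
            return ("cassandra", points)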
@@ -1,87 +0,0 @@
-from typing import Optional, cast
-
-from langchain_community.chat_message_histories import CassandraChatMessageHistory
-
-from langflow.base.memory.memory import BaseMemoryComponent
-from langflow.schema.data import Data
-
-
-class CassandraMessageReaderComponent(BaseMemoryComponent):
-    display_name = "Cassandra Message Reader"
-    description = "Retrieves stored chat messages from a Cassandra table on Astra DB."
-    name = "CassandraMessageReader"
-
-    def build_config(self):
-        return {
-            "session_id": {
-                "display_name": "Session ID",
-                "info": "Session ID of the chat history.",
-                "input_types": ["Text"],
-            },
-            "database_id": {
-                "display_name": "Database ID",
-                "info": "The Astra database ID.",
-            },
-            "table_name": {
-                "display_name": "Table Name",
-                "info": "The name of the table where messages are stored.",
-            },
-            "token": {
-                "display_name": "Token",
-                "info": "Authentication token for accessing Cassandra on Astra DB.",
-                "password": True,
-            },
-            "keyspace": {
-                "display_name": "Keyspace",
-                "info": "Optional key space within Astra DB. The keyspace should already be created.",
-                "input_types": ["Text"],
-                "advanced": True,
-            },
-        }
-
-    def get_messages(self, **kwargs) -> list[Data]:
-        """
-        Retrieves messages from the CassandraChatMessageHistory memory.
-
-        Args:
-            memory (CassandraChatMessageHistory): The CassandraChatMessageHistory instance to retrieve messages from.
-
-        Returns:
-            list[Data]: A list of Data objects representing the search results.
-        """
-        memory: CassandraChatMessageHistory = cast(CassandraChatMessageHistory, kwargs.get("memory"))
-        if not memory:
-            raise ValueError("CassandraChatMessageHistory instance is required.")
-
-        # Get messages from the memory
-        messages = memory.messages
-        results = [Data.from_lc_message(message) for message in messages]
-
-        return list(results)
-
-    def build(
-        self,
-        session_id: str,
-        table_name: str,
-        token: str,
-        database_id: str,
-        keyspace: Optional[str] = None,
-    ) -> list[Data]:
-        try:
-            import cassio
-        except ImportError:
-            raise ImportError(
-                "Could not import cassio integration package. " "Please install it with `pip install cassio`."
-            )
-
-        cassio.init(token=token, database_id=database_id)
-        memory = CassandraChatMessageHistory(
-            session_id=session_id,
-            table_name=table_name,
-            keyspace=keyspace,
-        )
-
-        data = self.get_messages(memory=memory)
-        self.status = data
-
-        return data
@@ -1,123 +0,0 @@
-from typing import Optional
-
-from langchain_community.chat_message_histories import CassandraChatMessageHistory
-from langchain_core.messages import BaseMessage
-
-from langflow.base.memory.memory import BaseMemoryComponent
-from langflow.schema.data import Data
-
-
-class CassandraMessageWriterComponent(BaseMemoryComponent):
-    display_name = "Cassandra Message Writer"
-    description = "Writes a message to a Cassandra table on Astra DB."
-    name = "CassandraMessageWriter"
-
-    def build_config(self):
-        return {
-            "input_value": {
-                "display_name": "Input Data",
-                "info": "Data to write to Cassandra.",
-            },
-            "session_id": {
-                "display_name": "Session ID",
-                "info": "Session ID of the chat history.",
-                "input_types": ["Text"],
-            },
-            "database_id": {
-                "display_name": "Database ID",
-                "info": "The Astra database ID.",
-            },
-            "table_name": {
-                "display_name": "Table Name",
-                "info": "The name of the table where messages will be stored.",
-            },
-            "token": {
-                "display_name": "Token",
-                "info": "Authentication token for accessing Cassandra on Astra DB.",
-                "password": True,
-            },
-            "keyspace": {
-                "display_name": "Keyspace",
-                "info": "Optional key space within Astra DB. The keyspace should already be created.",
-                "input_types": ["Text"],
-                "advanced": True,
-            },
-            "ttl_seconds": {
-                "display_name": "TTL Seconds",
-                "info": "Optional time-to-live for the messages.",
-                "input_types": ["Number"],
-                "advanced": True,
-            },
-        }
-
-    def add_message(
-        self,
-        sender: str,
-        sender_name: str,
-        text: str,
-        session_id: str,
-        metadata: Optional[dict] = None,
-        **kwargs,
-    ):
-        """
-        Adds a message to the CassandraChatMessageHistory memory.
-
-        Args:
-            sender (str): The type of the message sender. Typically "ai" or "human".
-            sender_name (str): The name of the message sender.
-            text (str): The content of the message.
-            session_id (str): The session ID associated with the message.
-            metadata (dict | None, optional): Additional metadata for the message. Defaults to None.
-            **kwargs: Additional keyword arguments, including:
-                memory (CassandraChatMessageHistory | None): The memory instance to add the message to.
-
-        Raises:
-            ValueError: If the CassandraChatMessageHistory instance is not provided.
-
-        """
-        memory: CassandraChatMessageHistory | None = kwargs.pop("memory", None)
-        if memory is None:
-            raise ValueError("CassandraChatMessageHistory instance is required.")
-
-        text_list = [
-            BaseMessage(
-                content=text,
-                sender=sender,
-                sender_name=sender_name,
-                metadata=metadata,
-                session_id=session_id,
-            )
-        ]
-
-        memory.add_messages(text_list)
-
-    def build(
-        self,
-        input_value: Data,
-        session_id: str,
-        table_name: str,
-        token: str,
-        database_id: str,
-        keyspace: Optional[str] = None,
-        ttl_seconds: Optional[int] = None,
-    ) -> Data:
-        try:
-            import cassio
-        except ImportError:
-            raise ImportError(
-                "Could not import cassio integration package. " "Please install it with `pip install cassio`."
-            )
-
-        cassio.init(token=token, database_id=database_id)
-        memory = CassandraChatMessageHistory(
-            session_id=session_id,
-            table_name=table_name,
-            keyspace=keyspace,
-            ttl_seconds=ttl_seconds,
-        )
-
-        self.add_message(**input_value.data, memory=memory)
-        self.status = f"Added message to Cassandra memory for session {session_id}"
-
-        return input_value
@@ -0,0 +1,43 @@
+from langflow.base.memory.model import LCChatMemoryComponent
+from langflow.inputs import MessageTextInput, SecretStrInput, DropdownInput
+from langflow.field_typing import BaseChatMessageHistory
+
+
+class ZepChatMemory(LCChatMemoryComponent):
+    display_name = "Zep Chat Memory"
+    description = "Retrieves and store chat messages from Zep."
+    name = "ZepChatMemory"
+
+    inputs = [
+        MessageTextInput(name="url", display_name="Zep URL", info="URL of the Zep instance."),
+        SecretStrInput(name="api_key", display_name="API Key", info="API Key for the Zep instance."),
+        DropdownInput(
+            name="api_base_path",
+            display_name="API Base Path",
+            options=["api/v1", "api/v2"],
+            value="api/v1",
+            advanced=True,
+        ),
+        MessageTextInput(
+            name="session_id", display_name="Session ID", info="Session ID for the message.", advanced=True
+        ),
+    ]
+
+    def build_message_history(self) -> BaseChatMessageHistory:
+        try:
+            # Monkeypatch API_BASE_PATH to
+            # avoid 404
+            # This is a workaround for the local Zep instance
+            # cloud Zep works with v2
+            import zep_python.zep_client
+            from zep_python import ZepClient
+            from zep_python.langchain import ZepChatMessageHistory
+
+            zep_python.zep_client.API_BASE_PATH = self.api_base_path
+        except ImportError:
+            raise ImportError(
+                "Could not import zep-python package. " "Please install it with `pip install zep-python`."
+            )
+
+        zep_client = ZepClient(api_url=self.url, api_key=self.api_key)
+        return ZepChatMessageHistory(session_id=self.session_id, zep_client=zep_client)
@@ -1,151 +0,0 @@
-from typing import Optional, cast
-
-from langchain_community.chat_message_histories.zep import SearchScope, SearchType, ZepChatMessageHistory
-
-from langflow.base.memory.memory import BaseMemoryComponent
-from langflow.field_typing import Text
-from langflow.schema import Data
-
-
-class ZepMessageReaderComponent(BaseMemoryComponent):
-    display_name = "Zep Message Reader"
-    description = "Retrieves stored chat messages from Zep."
-    name = "ZepMessageReader"
-
-    def build_config(self):
-        return {
-            "session_id": {
-                "display_name": "Session ID",
-                "info": "Session ID of the chat history.",
-                "input_types": ["Text"],
-            },
-            "url": {
-                "display_name": "Zep URL",
-                "info": "URL of the Zep instance.",
-                "input_types": ["Text"],
-            },
-            "api_key": {
-                "display_name": "Zep API Key",
-                "info": "API Key for the Zep instance.",
-                "password": True,
-            },
-            "query": {
-                "display_name": "Query",
-                "info": "Query to search for in the chat history.",
-            },
-            "metadata": {
-                "display_name": "Metadata",
-                "info": "Optional metadata to attach to the message.",
-                "advanced": True,
-            },
-            "search_scope": {
-                "options": ["Messages", "Summary"],
-                "display_name": "Search Scope",
-                "info": "Scope of the search.",
-                "advanced": True,
-            },
-            "search_type": {
-                "options": ["Similarity", "MMR"],
-                "display_name": "Search Type",
-                "info": "Type of search.",
-                "advanced": True,
-            },
-            "limit": {
-                "display_name": "Limit",
-                "info": "Limit of search results.",
-                "advanced": True,
-            },
-            "api_base_path": {
-                "display_name": "API Base Path",
-                "options": ["api/v1", "api/v2"],
-            },
-        }
-
-    def get_messages(self, **kwargs) -> list[Data]:
-        """
-        Retrieves messages from the ZepChatMessageHistory memory.
-
-        If a query is provided, the search method is used to search for messages in the memory, otherwise all messages are returned.
-
-        Args:
-            memory (ZepChatMessageHistory): The ZepChatMessageHistory instance to retrieve messages from.
-            query (str, optional): The query string to search for messages. Defaults to None.
-            metadata (dict, optional): Additional metadata to filter the search results. Defaults to None.
-            search_scope (str, optional): The scope of the search. Can be 'messages' or 'summary'. Defaults to 'messages'.
-            search_type (str, optional): The type of search. Can be 'similarity' or 'exact'. Defaults to 'similarity'.
-            limit (int, optional): The maximum number of search results to return. Defaults to None.
-
-        Returns:
-            list[Data]: A list of Data objects representing the search results.
-        """
-        memory: ZepChatMessageHistory = cast(ZepChatMessageHistory, kwargs.get("memory"))
-        if not memory:
-            raise ValueError("ZepChatMessageHistory instance is required.")
-        query = kwargs.get("query")
-        search_scope = kwargs.get("search_scope", SearchScope.messages).lower()
-        search_type = kwargs.get("search_type", SearchType.similarity).lower()
-        limit = kwargs.get("limit")
-
-        if query:
-            memory_search_results = memory.search(
-                query,
-                search_scope=search_scope,
-                search_type=search_type,
-                limit=limit,
-            )
-            # Get the messages from the search results if the search scope is messages
-            result_dicts = []
-            for result in memory_search_results:
-                result_dict = {}
-                if search_scope == SearchScope.messages:
-                    result_dict["text"] = result.message
-                else:
-                    result_dict["text"] = result.summary
-                result_dict["metadata"] = result.metadata
-                result_dict["score"] = result.score
-                result_dicts.append(result_dict)
-            results = [Data(data=result_dict) for result_dict in result_dicts]
-        else:
-            messages = memory.messages
-            results = [Data.from_lc_message(message) for message in messages]
-        return results
-
-    def build(
-        self,
-        session_id: Text,
-        api_base_path: str = "api/v1",
-        url: Optional[Text] = None,
-        api_key: Optional[Text] = None,
-        query: Optional[Text] = None,
-        search_scope: str = SearchScope.messages,
-        search_type: str = SearchType.similarity,
-        limit: Optional[int] = None,
-    ) -> list[Data]:
-        try:
-            # Monkeypatch API_BASE_PATH to
-            # avoid 404
-            # This is a workaround for the local Zep instance
-            # cloud Zep works with v2
-            import zep_python.zep_client
-            from zep_python import ZepClient
-            from zep_python.langchain import ZepChatMessageHistory
-
-            zep_python.zep_client.API_BASE_PATH = api_base_path
-        except ImportError:
-            raise ImportError(
-                "Could not import zep-python package. " "Please install it with `pip install zep-python`."
-            )
-        if url == "":
-            url = None
-
-        zep_client = ZepClient(api_url=url, api_key=api_key)
-        memory = ZepChatMessageHistory(session_id=session_id, zep_client=zep_client)
-        data = self.get_messages(
-            memory=memory,
-            query=query,
-            search_scope=search_scope,
-            search_type=search_type,
-            limit=limit,
-        )
-        self.status = data
-        return data
@@ -1,109 +0,0 @@
-from typing import TYPE_CHECKING, Optional
-
-from langflow.base.memory.memory import BaseMemoryComponent
-from langflow.field_typing import Text
-from langflow.schema import Data
-
-if TYPE_CHECKING:
-    from zep_python.langchain import ZepChatMessageHistory
-
-
-class ZepMessageWriterComponent(BaseMemoryComponent):
-    display_name = "Zep Message Writer"
-    description = "Writes a message to Zep."
-    name = "ZepMessageWriter"
-
-    def build_config(self):
-        return {
-            "session_id": {
-                "display_name": "Session ID",
-                "info": "Session ID of the chat history.",
-                "input_types": ["Text"],
-            },
-            "url": {
-                "display_name": "Zep URL",
-                "info": "URL of the Zep instance.",
-                "input_types": ["Text"],
-            },
-            "api_key": {
-                "display_name": "Zep API Key",
-                "info": "API Key for the Zep instance.",
-                "password": True,
-            },
-            "limit": {
-                "display_name": "Limit",
-                "info": "Limit of search results.",
-                "advanced": True,
-            },
-            "input_value": {
-                "display_name": "Input Data",
-                "info": "Data to write to Zep.",
-            },
-            "api_base_path": {
-                "display_name": "API Base Path",
-                "options": ["api/v1", "api/v2"],
-            },
-        }
-
-    def add_message(
-        self, sender: Text, sender_name: Text, text: Text, session_id: Text, metadata: dict | None = None, **kwargs
-    ):
-        """
-        Adds a message to the ZepChatMessageHistory memory.
-
-        Args:
-            sender (Text): The type of the message sender. Valid values are "Machine" or "User".
-            sender_name (Text): The name of the message sender.
-            text (Text): The content of the message.
-            session_id (Text): The session ID associated with the message.
-            metadata (dict | None, optional): Additional metadata for the message. Defaults to None.
-            **kwargs: Additional keyword arguments.
-
-        Raises:
-            ValueError: If the ZepChatMessageHistory instance is not provided.
-
-        """
-        memory: ZepChatMessageHistory | None = kwargs.pop("memory", None)
-        if memory is None:
-            raise ValueError("ZepChatMessageHistory instance is required.")
-        if metadata is None:
-            metadata = {}
-        metadata["sender_name"] = sender_name
-        metadata.update(kwargs)
-        if sender == "Machine":
-            memory.add_ai_message(text, metadata=metadata)
-        elif sender == "User":
-            memory.add_user_message(text, metadata=metadata)
-        else:
-            raise ValueError(f"Invalid sender type: {sender}")
-
-    def build(
-        self,
-        input_value: Data,
-        session_id: Text,
-        api_base_path: str = "api/v1",
-        url: Optional[Text] = None,
-        api_key: Optional[Text] = None,
-    ) -> Data:
-        try:
-            # Monkeypatch API_BASE_PATH to
-            # avoid 404
-            # This is a workaround for the local Zep instance
-            # cloud Zep works with v2
-            import zep_python.zep_client
-            from zep_python import ZepClient
-            from zep_python.langchain import ZepChatMessageHistory
-
-            zep_python.zep_client.API_BASE_PATH = api_base_path
-        except ImportError:
-            raise ImportError(
-                "Could not import zep-python package. " "Please install it with `pip install zep-python`."
-            )
-        if url == "":
-            url = None
-
-        zep_client = ZepClient(api_url=url, api_key=api_key)
-        memory = ZepChatMessageHistory(session_id=session_id, zep_client=zep_client)
-        self.add_message(**input_value.data, memory=memory)
-        self.status = f"Added message to Zep memory for session {session_id}"
-        return input_value
@@ -3,6 +3,7 @@ from typing import Callable, Dict, Text, TypeAlias, TypeVar, Union
 from langchain.agents.agent import AgentExecutor
 from langchain.chains.base import Chain
 from langchain.memory.chat_memory import BaseChatMemory
+from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.document_loaders import BaseLoader
 from langchain_core.documents import Document
 from langchain_core.embeddings import Embeddings
@@ -58,6 +59,7 @@ LANGCHAIN_BASE_TYPES = {
     "BaseMemory": BaseMemory,
     "BaseChatMemory": BaseChatMemory,
     "BaseChatModel": BaseChatModel,
+    "BaseChatMessageHistory": BaseChatMessageHistory,
 }
 # Langchain base types plus Python base types
 CUSTOM_COMPONENT_SUPPORTED_TYPES = {
@@ -1,6 +1,6 @@
 import { expect, test } from "@playwright/test";
 
-test("LLMChain - Tooltip", async ({ page }) => {
+test("RetrievalQA - Tooltip", async ({ page }) => {
   await page.goto("/");
   await page.waitForTimeout(1000);
 
@@ -26,11 +26,11 @@ test("LLMChain - Tooltip", async ({ page }) => {
   });
   await page.getByTestId("extended-disclosure").click();
   await page.getByPlaceholder("Search").click();
-  await page.getByPlaceholder("Search").fill("llmchain");
+  await page.getByPlaceholder("Search").fill("retrievalqa");
 
   await page.waitForTimeout(1000);
   await page
-    .getByTestId("chainsLLMChain")
+    .getByTestId("chainsRetrieval QA")
     .dragTo(page.locator('//*[@id="react-flow-id"]'));
   await page.mouse.up();
   await page.mouse.down();
@@ -39,12 +39,12 @@ test("LLMChain - Tooltip", async ({ page }) => {
   await page.getByTitle("zoom out").click();
   await page.getByTitle("zoom out").click();
 
-  const llmChainOutputElements = await page
-    .getByTestId("handle-llmchain-shownode-text-right")
+  const outputElements = await page
+    .getByTestId("handle-retrievalqa-shownode-text-right")
     .all();
   let visibleElementHandle;
 
-  for (const element of llmChainOutputElements) {
+  for (const element of outputElements) {
     if (await element.isVisible()) {
       visibleElementHandle = element;
       break;
@@ -52,9 +52,15 @@ test("LLMChain - Tooltip", async ({ page }) => {
   }
 
   await visibleElementHandle.hover().then(async () => {
     await expect(
       page.getByTestId("available-output-inputs").first(),
     ).toBeVisible();
     await expect(
       page.getByTestId("available-output-chains").first(),
     ).toBeVisible();
+    await expect(
+      page.getByTestId("available-output-textsplitters").first(),
+    ).toBeVisible();
+    await expect(
+      page.getByTestId("available-output-retrievers").first(),
+    ).toBeVisible();
@@ -62,13 +68,23 @@ test("LLMChain - Tooltip", async ({ page }) => {
       page.getByTestId("available-output-prototypes").first(),
     ).toBeVisible();
     await expect(
-      page.getByTestId("available-output-tools").first(),
+      page.getByTestId("available-output-embeddings").first(),
     ).toBeVisible();
+    await expect(
+      page.getByTestId("available-output-data").first(),
+    ).toBeVisible();
+    await expect(
+      page.getByTestId("available-output-vectorstores").first(),
+    ).toBeVisible();
+    await expect(
+      page.getByTestId("available-output-memories").first(),
+    ).toBeVisible();
     await expect(
-      page.getByTestId("available-output-toolkits").first(),
+      page.getByTestId("available-output-models").first(),
     ).toBeVisible();
 
+    await expect(
+      page.getByTestId("available-output-outputs").first(),
+    ).toBeVisible();
     await expect(
       page.getByTestId("available-output-agents").first(),
@@ -76,9 +92,6 @@ test("LLMChain - Tooltip", async ({ page }) => {
     await expect(
       page.getByTestId("available-output-helpers").first(),
     ).toBeVisible();
-    await expect(
-      page.getByTestId("available-output-langchain_utilities").first(),
-    ).toBeVisible();
 
   await page.getByTestId("icon-X").click();
   await page.waitForTimeout(500);
@@ -89,11 +102,11 @@ test("LLMChain - Tooltip", async ({ page }) => {
   await page.getByTitle("zoom out").click();
   await page.getByTitle("zoom out").click();
 
-  const llmChainInputElements1 = await page
-    .getByTestId("handle-llmchain-shownode-llm-left")
+  const rqaChainInputElements1 = await page
+    .getByTestId("handle-retrievalqa-shownode-language model-left")
    .all();
 
-  for (const element of llmChainInputElements1) {
+  for (const element of rqaChainInputElements1) {
     if (await element.isVisible()) {
       visibleElementHandle = element;
       break;
@@ -115,11 +128,11 @@ test("LLMChain - Tooltip", async ({ page }) => {
   await page.getByTitle("zoom out").click();
   await page.getByTitle("zoom out").click();
 
-  const llmChainInputElements0 = await page
-    .getByTestId("handle-llmchain-shownode-template-left")
+  const rqaChainInputElements0 = await page
+    .getByTestId("handle-retrievalqa-shownode-retriever-left")
    .all();
 
-  for (const element of llmChainInputElements0) {
+  for (const element of rqaChainInputElements0) {
     if (await element.isVisible()) {
       visibleElementHandle = element;
       break;
@@ -130,16 +143,10 @@ test("LLMChain - Tooltip", async ({ page }) => {
   await page.waitForTimeout(2000);
 
   await expect(
-    page.getByTestId("available-input-chains").first(),
+    page.getByTestId("available-input-retrievers").first(),
   ).toBeVisible();
-  await expect(
-    page.getByTestId("available-input-prototypes").first(),
-  ).toBeVisible();
-  await expect(
-    page.getByTestId("available-input-agents").first(),
-  ).toBeVisible();
   await expect(
-    page.getByTestId("available-input-helpers").first(),
+    page.getByTestId("available-input-vectorstores").first(),
   ).toBeVisible();
 
   await page.waitForTimeout(500);
@@ -150,11 +157,11 @@ test("LLMChain - Tooltip", async ({ page }) => {
   await page.getByTitle("zoom out").click();
   await page.getByTitle("zoom out").click();
 
-  const llmChainInputElements2 = await page
-    .getByTestId("handle-llmchain-shownode-memory-left")
+  const rqaChainInputElements2 = await page
+    .getByTestId("handle-retrievalqa-shownode-memory-left")
    .all();
 
-  for (const element of llmChainInputElements2) {
+  for (const element of rqaChainInputElements2) {
     if (await element.isVisible()) {
       visibleElementHandle = element;
       break;
@@ -163,7 +170,7 @@ test("LLMChain - Tooltip", async ({ page }) => {
 
   await visibleElementHandle.hover().then(async () => {
     await expect(
-      page.getByTestId("empty-tooltip-filter").first(),
+      page.getByTestId("available-input-memories").first(),
     ).toBeVisible();
   });
 });
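Note: the tooltip spec above repeats one pattern several times — collect every element carrying a handle test id, pick the first visible one, hover it, then assert the tooltip categories. Below is a minimal TypeScript sketch of that pattern; the helper name is hypothetical and not part of this PR, while the test-id formats "handle-<node>-shownode-<field>-<side>" and "available-output-<category>" are taken from the diff above.

import { expect, type Page } from "@playwright/test";

// Hypothetical helper, not part of this PR: hover the first visible
// handle for a test id and assert the tooltip lists the given categories.
async function hoverHandleAndExpectOutputCategories(
  page: Page,
  handleTestId: string,
  categories: string[],
) {
  // A node renders several handle elements; only one is visible at a time.
  const handles = await page.getByTestId(handleTestId).all();
  let visibleHandle;
  for (const handle of handles) {
    if (await handle.isVisible()) {
      visibleHandle = handle;
      break;
    }
  }
  if (!visibleHandle) {
    throw new Error(`No visible handle found for ${handleTestId}`);
  }
  await visibleHandle.hover();
  for (const category of categories) {
    await expect(
      page.getByTestId(`available-output-${category}`).first(),
    ).toBeVisible();
  }
}

// Usage mirroring the diff above (input handles would use the
// "available-input-<category>" prefix instead):
// await hoverHandleAndExpectOutputCategories(
//   page,
//   "handle-retrievalqa-shownode-text-right",
//   ["inputs", "chains", "textsplitters", "retrievers"],
// );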
@@ -1,6 +1,6 @@
 import { expect, test } from "@playwright/test";
 
-test("LLMChain - Filter", async ({ page }) => {
+test("RetrievalQA - Filter", async ({ page }) => {
   await page.goto("/");
   await page.waitForTimeout(2000);
 
@@ -31,11 +31,11 @@ test("LLMChain - Filter", async ({ page }) => {
   });
   await page.getByTestId("extended-disclosure").click();
   await page.getByPlaceholder("Search").click();
-  await page.getByPlaceholder("Search").fill("llmchain");
+  await page.getByPlaceholder("Search").fill("retrievalqa");
 
   await page.waitForTimeout(1000);
   await page
-    .getByTestId("chainsLLMChain")
+    .getByTestId("chainsRetrieval QA")
     .dragTo(page.locator('//*[@id="react-flow-id"]'));
   await page.mouse.up();
   await page.mouse.down();
@@ -47,11 +47,11 @@ test("LLMChain - Filter", async ({ page }) => {
 
   let visibleElementHandle;
 
-  const llmChainOutputElements = await page
-    .getByTestId("handle-llmchain-shownode-text-right")
+  const outputElements = await page
+    .getByTestId("handle-retrievalqa-shownode-text-right")
    .all();
 
-  for (const element of llmChainOutputElements) {
+  for (const element of outputElements) {
     if (await element.isVisible()) {
       visibleElementHandle = element;
       break;
@@ -62,38 +62,47 @@ test("LLMChain - Filter", async ({ page }) => {
     force: true,
   });
 
   await expect(page.getByTestId("disclosure-inputs")).toBeVisible();
   await expect(page.getByTestId("disclosure-outputs")).toBeVisible();
   await expect(page.getByTestId("disclosure-data")).toBeVisible();
   await expect(page.getByTestId("disclosure-models")).toBeVisible();
   await expect(page.getByTestId("disclosure-helpers")).toBeVisible();
   await expect(page.getByTestId("disclosure-vector stores")).toBeVisible();
   await expect(page.getByTestId("disclosure-embeddings")).toBeVisible();
   await expect(page.getByTestId("disclosure-agents")).toBeVisible();
   await expect(page.getByTestId("disclosure-chains")).toBeVisible();
   await expect(page.getByTestId("disclosure-utilities")).toBeVisible();
   await expect(page.getByTestId("disclosure-memories")).toBeVisible();
   await expect(page.getByTestId("disclosure-prototypes")).toBeVisible();
   await expect(page.getByTestId("disclosure-retrievers")).toBeVisible();
   await expect(page.getByTestId("disclosure-toolkits")).toBeVisible();
   await expect(page.getByTestId("disclosure-tools")).toBeVisible();
   await expect(page.getByTestId("disclosure-text splitters")).toBeVisible();
 
-  await expect(page.getByTestId("helpersID Generator").first()).toBeVisible();
-  await expect(page.getByTestId("agentsCSVAgent").first()).toBeVisible();
-  await expect(page.getByTestId("chainsLLMChain").first()).toBeVisible();
+  await expect(page.getByTestId("inputsChat Input").first()).toBeVisible();
+  await expect(page.getByTestId("outputsChat Output").first()).toBeVisible();
+  await expect(page.getByTestId("dataAPI Request").first()).toBeVisible();
+  await expect(page.getByTestId("modelsAmazon Bedrock").first()).toBeVisible();
+  await expect(page.getByTestId("helpersChat Memory").first()).toBeVisible();
+  await expect(page.getByTestId("vectorstoresAstra DB").first()).toBeVisible();
   await expect(
-    page.getByTestId("langchain_utilitiesSearchApi").first(),
+    page.getByTestId("embeddingsAmazon Bedrock Embeddings").first(),
   ).toBeVisible();
   await expect(
-    page.getByTestId("memoriesAstra DB Message Reader").first(),
+    page.getByTestId("agentsTool Calling Agent").first(),
   ).toBeVisible();
   await expect(
-    page.getByTestId("prototypesFlow as Tool").first(),
+    page.getByTestId("chainsConversationChain").first(),
  ).toBeVisible();
   await expect(
-    page.getByTestId("retrieversAmazon Kendra Retriever").first(),
+    page.getByTestId("memoriesAstra DB Chat Memory").first(),
   ).toBeVisible();
 
   await expect(
-    page.getByTestId("toolkitsVectorStoreInfo").first(),
+    page.getByTestId("prototypesConditional Router").first(),
   ).toBeVisible();
+  await expect(
+    page.getByTestId("retrieversSelf Query Retriever").first(),
+  ).toBeVisible();
+  await expect(
+    page.getByTestId("textsplittersCharacterTextSplitter").first(),
+  ).toBeVisible();
   await expect(page.getByTestId("toolsSearchApi").first()).toBeVisible();
 
   await page.getByPlaceholder("Search").click();
 
@@ -110,11 +119,11 @@ test("LLMChain - Filter", async ({ page }) => {
   await expect(page.getByTestId("model_specsChatOpenAI")).not.toBeVisible();
   await expect(page.getByTestId("model_specsChatVertexAI")).not.toBeVisible();
 
-  const llmChainInputElements1 = await page
-    .getByTestId("handle-llmchain-shownode-llm-left")
+  const chainInputElements1 = await page
+    .getByTestId("handle-retrievalqa-shownode-llm-left")
    .all();
 
-  for (const element of llmChainInputElements1) {
+  for (const element of chainInputElements1) {
     if (await element.isVisible()) {
       visibleElementHandle = element;
       break;
@@ -129,11 +138,11 @@ test("LLMChain - Filter", async ({ page }) => {
 
   await expect(page.getByTestId("disclosure-models")).toBeVisible();
 
-  const llmChainInputElements0 = await page
-    .getByTestId("handle-llmchain-shownode-template-left")
+  const rqaChainInputElements0 = await page
+    .getByTestId("handle-retrievalqa-shownode-template-left")
    .all();
 
-  for (const element of llmChainInputElements0) {
+  for (const element of rqaChainInputElements0) {
     if (await element.isVisible()) {
       visibleElementHandle = element;
       break;
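Note: throughout these specs, sidebar entries are located by a test id that concatenates the category id with the component's display name (e.g. "chainsRetrieval QA", "memoriesAstra DB Chat Memory", both visible in the diff above). A small TypeScript sketch of that convention follows; the helper names are hypothetical, not code from this PR.

import { type Locator, type Page } from "@playwright/test";

// Hypothetical helpers, not part of this PR. Sidebar test ids in these
// specs follow "<category id><display name>", e.g. "chainsRetrieval QA".
function sidebarItem(
  page: Page,
  category: string,
  displayName: string,
): Locator {
  return page.getByTestId(`${category}${displayName}`).first();
}

// Drag a sidebar component onto the React Flow canvas, as the tests above do.
async function dragToCanvas(page: Page, category: string, displayName: string) {
  await sidebarItem(page, category, displayName).dragTo(
    page.locator('//*[@id="react-flow-id"]'),
  );
}

// Usage mirroring the diff above:
// await dragToCanvas(page, "chains", "Retrieval QA");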
@@ -30,11 +30,11 @@ test("LLMChain - Filter", async ({ page }) => {
 
   await page.getByTestId("extended-disclosure").click();
   await page.getByPlaceholder("Search").click();
-  await page.getByPlaceholder("Search").fill("llmchain");
+  await page.getByPlaceholder("Search").fill("api request");
 
-  await page.waitForTimeout(1000);
+  await page.waitForTimeout(2000);
   await page
-    .getByTestId("chainsLLMChain")
+    .getByTestId("dataAPI Request")
     .dragTo(page.locator('//*[@id="react-flow-id"]'));
   await page.mouse.up();
   await page.mouse.down();
@@ -44,65 +44,77 @@ test("LLMChain - Filter", async ({ page }) => {
   await page.getByTitle("zoom out").click();
   await page.waitForTimeout(500);
 
-  await page
-    .locator(
-      '//*[@id="react-flow-id"]/div/div[1]/div[1]/div/div[2]/div/div/div[2]/div[7]/button/div[1]',
-    )
-    .click();
-
-  await expect(page.getByTestId("helpersID Generator")).toBeVisible();
+  await expect(page.getByTestId("disclosure-agents")).toBeVisible();
 
-  await expect(page.getByTestId("chainsLLMChain").first()).toBeVisible();
-  await expect(
-    page.getByTestId("langchain_utilitiesSearchApi").first(),
-  ).toBeVisible();
-  await expect(
-    page.getByTestId("memoriesAstra DB Message Reader").first(),
-  ).toBeVisible();
-  await expect(
-    page.getByTestId("prototypesFlow as Tool").first(),
-  ).toBeVisible();
-  await expect(
-    page.getByTestId("retrieversAmazon Kendra Retriever").first(),
-  ).toBeVisible();
-
-  await expect(
-    page.getByTestId("toolkitsVectorStoreInfo").first(),
-  ).toBeVisible();
-  await expect(page.getByTestId("toolsSearchApi").first()).toBeVisible();
-
-  await page.getByPlaceholder("Search").click();
-
-  await expect(page.getByTestId("model_specsVertexAI")).not.toBeVisible();
-  await expect(page.getByTestId("model_specsCTransformers")).not.toBeVisible();
-  await expect(page.getByTestId("model_specsAmazon Bedrock")).not.toBeVisible();
-  await expect(page.getByTestId("modelsAzure OpenAI")).not.toBeVisible();
-  await expect(
-    page.getByTestId("model_specsAzureChatOpenAI"),
-  ).not.toBeVisible();
-  await expect(page.getByTestId("model_specsChatAnthropic")).not.toBeVisible();
-  await expect(page.getByTestId("model_specsChatLiteLLM")).not.toBeVisible();
-  await expect(page.getByTestId("model_specsChatOllama")).not.toBeVisible();
-  await expect(page.getByTestId("model_specsChatOpenAI")).not.toBeVisible();
-  await expect(page.getByTestId("model_specsChatVertexAI")).not.toBeVisible();
-
-  await page
-    .locator(
-      '//*[@id="react-flow-id"]/div/div[1]/div[1]/div/div[2]/div/div/div[2]/div[4]/div/button/div[1]',
-    )
-    .click();
+  await page.getByTestId("handle-apirequest-shownode-urls-left").click();
 
   await expect(page.getByTestId("disclosure-inputs")).toBeVisible();
   await expect(page.getByTestId("disclosure-outputs")).toBeVisible();
   await expect(page.getByTestId("disclosure-prompts")).toBeVisible();
   await expect(page.getByTestId("disclosure-models")).toBeVisible();
 
-  await page
-    .locator(
-      '//*[@id="react-flow-id"]/div/div[1]/div[1]/div/div[2]/div/div/div[2]/div[3]/div/button/div[1]',
-    )
-    .click();
-
+  await expect(page.getByTestId("disclosure-helpers")).toBeVisible();
+  await expect(page.getByTestId("disclosure-agents")).toBeVisible();
+  await expect(page.getByTestId("disclosure-chains")).toBeVisible();
+  await expect(page.getByTestId("disclosure-prototypes")).toBeVisible();
 
+  await expect(page.getByTestId("inputsChat Input")).toBeVisible();
+  await expect(page.getByTestId("outputsChat Output")).toBeVisible();
+  await expect(page.getByTestId("promptsPrompt")).toBeVisible();
+  await expect(page.getByTestId("modelsAmazon Bedrock")).toBeVisible();
+  await expect(page.getByTestId("helpersChat Memory")).toBeVisible();
+  await expect(page.getByTestId("agentsTool Calling Agent")).toBeVisible();
+  await expect(page.getByTestId("chainsConversationChain")).toBeVisible();
+  await expect(page.getByTestId("prototypesConditional Router")).toBeVisible();
 
   await page.getByPlaceholder("Search").click();
 
+  await expect(page.getByTestId("inputsChat Input")).not.toBeVisible();
+  await expect(page.getByTestId("outputsChat Output")).not.toBeVisible();
+  await expect(page.getByTestId("promptsPrompt")).not.toBeVisible();
+  await expect(page.getByTestId("modelsAmazon Bedrock")).not.toBeVisible();
+  await expect(page.getByTestId("helpersChat Memory")).not.toBeVisible();
+  await expect(page.getByTestId("agentsTool Calling Agent")).not.toBeVisible();
+  await expect(page.getByTestId("chainsConversationChain")).not.toBeVisible();
+  await expect(
+    page.getByTestId("prototypesConditional Router"),
+  ).not.toBeVisible();
 
+  await page.getByTestId("handle-apirequest-shownode-headers-left").click();
 
+  await expect(page.getByTestId("disclosure-data")).toBeVisible();
+  await expect(page.getByTestId("disclosure-helpers")).toBeVisible();
+  await expect(page.getByTestId("disclosure-vector stores")).toBeVisible();
+  await expect(page.getByTestId("disclosure-utilities")).toBeVisible();
+  await expect(page.getByTestId("disclosure-prototypes")).toBeVisible();
+  await expect(page.getByTestId("disclosure-retrievers")).toBeVisible();
+  await expect(page.getByTestId("disclosure-text splitters")).toBeVisible();
+  await expect(page.getByTestId("disclosure-tools")).toBeVisible();
 
+  await expect(page.getByTestId("dataAPI Request")).toBeVisible();
+  await expect(page.getByTestId("helpersChat Memory")).toBeVisible();
+  await expect(page.getByTestId("vectorstoresAstra DB")).toBeVisible();
+  await expect(page.getByTestId("langchain_utilitiesSearchApi")).toBeVisible();
+  await expect(page.getByTestId("prototypesSub Flow")).toBeVisible();
+  await expect(
+    page.getByTestId("retrieversSelf Query Retriever"),
+  ).toBeVisible();
+  await expect(
+    page.getByTestId("textsplittersCharacterTextSplitter"),
+  ).toBeVisible();
+  await expect(page.getByTestId("toolsSearchApi")).toBeVisible();
 
   await page.getByPlaceholder("Search").click();
 
+  await expect(page.getByTestId("dataAPI Request")).not.toBeVisible();
+  await expect(page.getByTestId("helpersChat Memory")).not.toBeVisible();
+  await expect(page.getByTestId("vectorstoresAstra DB")).not.toBeVisible();
+  await expect(
+    page.getByTestId("langchain_utilitiesSearchApi"),
+  ).not.toBeVisible();
+  await expect(page.getByTestId("prototypesSub Flow")).not.toBeVisible();
+  await expect(
+    page.getByTestId("retrieversSelf Query Retriever"),
+  ).not.toBeVisible();
+  await expect(
+    page.getByTestId("textsplittersCharacterTextSplitter"),
+  ).not.toBeVisible();
+  await expect(page.getByTestId("toolsSearchApi")).not.toBeVisible();
 });
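Note: the spec above asserts long lists of "disclosure-<category>" sections after each handle click. If those assertions were ever factored out, a loop over category names would express the same checks; the TypeScript sketch below is hypothetical, not code from this PR.

import { expect, type Page } from "@playwright/test";

// Hypothetical helper, not part of this PR: assert that the sidebar shows
// the expected "disclosure-<category>" sections after filtering by a handle.
async function expectVisibleCategories(page: Page, categories: string[]) {
  for (const category of categories) {
    await expect(page.getByTestId(`disclosure-${category}`)).toBeVisible();
  }
}

// Usage mirroring the "headers" handle assertions in the diff above:
// await page.getByTestId("handle-apirequest-shownode-headers-left").click();
// await expectVisibleCategories(page, [
//   "data", "helpers", "vector stores", "utilities",
//   "prototypes", "retrievers", "text splitters", "tools",
// ]);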