diff --git a/src/backend/langflow/api/v1/base.py b/src/backend/langflow/api/v1/base.py index 28be40ae0..420e1645f 100644 --- a/src/backend/langflow/api/v1/base.py +++ b/src/backend/langflow/api/v1/base.py @@ -14,7 +14,7 @@ class Code(BaseModel): class FrontendNodeRequest(FrontendNode): - template: dict + template: dict # type: ignore class ValidatePromptRequest(BaseModel): diff --git a/src/backend/langflow/chat/utils.py b/src/backend/langflow/chat/utils.py index d070a7457..7db65b8e3 100644 --- a/src/backend/langflow/chat/utils.py +++ b/src/backend/langflow/chat/utils.py @@ -21,6 +21,10 @@ async def process_graph( # Generate result and thought try: + if not chat_inputs.message: + logger.debug("No message provided") + raise ValueError("No message provided") + logger.debug("Generating result and thought") result, intermediate_steps = await get_result_and_steps( langchain_object, chat_inputs.message, websocket=websocket diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index 032a2f049..f4f83301a 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -153,6 +153,7 @@ memories: documentation: "https://python.langchain.com/docs/modules/memory/how_to/vectorstore_retriever_memory" MongoDBChatMessageHistory: documentation: "https://python.langchain.com/docs/modules/memory/integrations/mongodb_chat_message_history" +prompts: ChatMessagePromptTemplate: documentation: "https://python.langchain.com/docs/modules/model_io/prompts/prompt_templates/msg_prompt_templates" HumanMessagePromptTemplate: @@ -161,7 +162,6 @@ memories: documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/how_to/prompts" ChatPromptTemplate: documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/how_to/prompts" -prompts: PromptTemplate: documentation: "https://python.langchain.com/docs/modules/model_io/prompts/prompt_templates/" textsplitters: @@ -280,6 +280,8 @@ vectorstores: wrappers: RequestsWrapper: documentation: "" + SQLDatabase: + documentation: "" output_parsers: StructuredOutputParser: documentation: "https://python.langchain.com/docs/modules/model_io/output_parsers/structured" diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index bbafb4526..58ef1b508 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -23,6 +23,7 @@ CUSTOM_NODES = { }, "memories": { "PostgresChatMessageHistory": frontend_node.memories.PostgresChatMessageHistoryFrontendNode(), + "MongoDBChatMessageHistory": frontend_node.memories.MongoDBChatMessageHistoryFrontendNode(), }, "chains": { "SeriesCharacterChain": frontend_node.chains.SeriesCharacterChainNode(), diff --git a/src/backend/langflow/interface/importing/utils.py b/src/backend/langflow/interface/importing/utils.py index 3c7f89b5b..ccfd8d5dd 100644 --- a/src/backend/langflow/interface/importing/utils.py +++ b/src/backend/langflow/interface/importing/utils.py @@ -10,6 +10,7 @@ from langchain.chains.base import Chain from langchain.chat_models.base import BaseChatModel from langchain.tools import BaseTool from langflow.utils import validate +from langflow.interface.wrappers.base import wrapper_creator def import_module(module_path: str) -> Any: @@ -96,7 +97,11 @@ def import_prompt(prompt: str) -> Type[PromptTemplate]: def import_wrapper(wrapper: str) -> Any: """Import wrapper from wrapper name""" - return import_module(f"from langchain.requests import {wrapper}") + if ( + isinstance(wrapper_creator.type_dict, dict) + and wrapper in wrapper_creator.type_dict + ): + return wrapper_creator.type_dict.get(wrapper) def import_toolkit(toolkit: str) -> Any: diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 3c3616dd8..25149cd4b 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -20,6 +20,7 @@ from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.chains.base import chain_creator from langflow.interface.output_parsers.base import output_parser_creator from langflow.interface.retrievers.base import retriever_creator +from langflow.interface.wrappers.base import wrapper_creator from langflow.interface.utils import load_file_into_dict from langflow.utils import validate from langchain.chains.base import Chain @@ -89,10 +90,21 @@ def instantiate_based_on_type(class_object, base_type, node_type, params): return instantiate_retriever(node_type, class_object, params) elif base_type == "memory": return instantiate_memory(node_type, class_object, params) + elif base_type == "wrappers": + return instantiate_wrapper(node_type, class_object, params) else: return class_object(**params) +def instantiate_wrapper(node_type, class_object, params): + if node_type in wrapper_creator.from_method_nodes: + method = wrapper_creator.from_method_nodes[node_type] + if class_method := getattr(class_object, method, None): + return class_method(**params) + raise ValueError(f"Method {method} not found in {class_object}") + return class_object(**params) + + def instantiate_output_parser(node_type, class_object, params): if node_type in output_parser_creator.from_method_nodes: method = output_parser_creator.from_method_nodes[node_type] @@ -115,7 +127,7 @@ def instantiate_memory(node_type, class_object, params): # process input_key and output_key to remove them if # they are empty strings for key in ["input_key", "output_key"]: - if key in params and not params[key]: + if key in params and (params[key] == "" or not params[key]): params.pop(key) try: @@ -193,7 +205,7 @@ def instantiate_prompt(node_type, class_object, params: Dict): prompt = class_object(**params) - format_kwargs = {} + format_kwargs: Dict[str, Any] = {} for input_variable in prompt.input_variables: if input_variable in params: variable = params[input_variable] diff --git a/src/backend/langflow/interface/wrappers/base.py b/src/backend/langflow/interface/wrappers/base.py index f5773d07a..77e38f921 100644 --- a/src/backend/langflow/interface/wrappers/base.py +++ b/src/backend/langflow/interface/wrappers/base.py @@ -1,25 +1,36 @@ from typing import Dict, List, Optional -from langchain import requests +from langchain import requests, sql_database from langflow.interface.base import LangChainTypeCreator from langflow.utils.logger import logger -from langflow.utils.util import build_template_from_class +from langflow.utils.util import build_template_from_class, build_template_from_method class WrapperCreator(LangChainTypeCreator): type_name: str = "wrappers" + from_method_nodes = {"SQLDatabase": "from_uri"} + @property def type_to_loader_dict(self) -> Dict: if self.type_dict is None: self.type_dict = { - wrapper.__name__: wrapper for wrapper in [requests.TextRequestsWrapper] + wrapper.__name__: wrapper + for wrapper in [requests.TextRequestsWrapper, sql_database.SQLDatabase] } return self.type_dict def get_signature(self, name: str) -> Optional[Dict]: try: + if name in self.from_method_nodes: + return build_template_from_method( + name, + type_to_cls_dict=self.type_to_loader_dict, + add_function=True, + method_name=self.from_method_nodes[name], + ) + return build_template_from_class(name, self.type_to_loader_dict) except ValueError as exc: raise ValueError("Wrapper not found") from exc diff --git a/src/backend/langflow/processing/base.py b/src/backend/langflow/processing/base.py index b39ad4af1..478b98816 100644 --- a/src/backend/langflow/processing/base.py +++ b/src/backend/langflow/processing/base.py @@ -1,3 +1,4 @@ +from typing import Union from langflow.api.v1.callback import ( AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler, @@ -6,7 +7,7 @@ from langflow.processing.process import fix_memory_inputs, format_actions from langflow.utils.logger import logger -async def get_result_and_steps(langchain_object, inputs: dict, **kwargs): +async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwargs): """Get result and thought from extracted json""" try: diff --git a/src/backend/langflow/template/frontend_node/agents.py b/src/backend/langflow/template/frontend_node/agents.py index 5cae6a346..02aea78b9 100644 --- a/src/backend/langflow/template/frontend_node/agents.py +++ b/src/backend/langflow/template/frontend_node/agents.py @@ -19,8 +19,7 @@ class AgentFrontendNode(FrontendNode): if field.name in ["suffix", "prefix"]: field.show = True if field.name == "Tools" and name == "ZeroShotAgent": - # field. - field.type_name = "BaseTool" + field.field_type = "BaseTool" field.is_list = True diff --git a/src/backend/langflow/template/frontend_node/memories.py b/src/backend/langflow/template/frontend_node/memories.py index 6d490212f..d98a322ff 100644 --- a/src/backend/langflow/template/frontend_node/memories.py +++ b/src/backend/langflow/template/frontend_node/memories.py @@ -4,6 +4,10 @@ from langflow.template.field.base import TemplateField from langflow.template.frontend_node.base import FrontendNode from langflow.template.template.base import Template from langchain.memory.chat_message_histories.postgres import DEFAULT_CONNECTION_STRING +from langchain.memory.chat_message_histories.mongodb import ( + DEFAULT_COLLECTION_NAME, + DEFAULT_DBNAME, +) class MemoryFrontendNode(FrontendNode): @@ -120,3 +124,56 @@ class PostgresChatMessageHistoryFrontendNode(MemoryFrontendNode): ) description: str = "Memory store with Postgres" base_classes: list[str] = ["PostgresChatMessageHistory", "BaseChatMessageHistory"] + + +class MongoDBChatMessageHistoryFrontendNode(MemoryFrontendNode): + name: str = "MongoDBChatMessageHistory" + template: Template = Template( + # langchain/memory/chat_message_histories/mongodb.py + # connection_string: str, + # session_id: str, + # database_name: str = DEFAULT_DBNAME, + # collection_name: str = DEFAULT_COLLECTION_NAME, + type_name="MongoDBChatMessageHistory", + fields=[ + TemplateField( + field_type="str", + required=True, + placeholder="", + is_list=False, + show=True, + multiline=False, + name="session_id", + ), + TemplateField( + field_type="str", + required=True, + show=True, + name="connection_string", + value="", + info="MongoDB connection string (e.g mongodb://mongo_user:password123@mongo:27017)", + ), + TemplateField( + field_type="str", + required=True, + placeholder="", + is_list=False, + show=True, + multiline=False, + value=DEFAULT_DBNAME, + name="database_name", + ), + TemplateField( + field_type="str", + required=True, + placeholder="", + is_list=False, + show=True, + multiline=False, + value=DEFAULT_COLLECTION_NAME, + name="collection_name", + ), + ], + ) + description: str = "Memory store with MongoDB" + base_classes: list[str] = ["MongoDBChatMessageHistory", "BaseChatMessageHistory"] diff --git a/src/frontend/index.html b/src/frontend/index.html index 50bdae647..426983565 100644 --- a/src/frontend/index.html +++ b/src/frontend/index.html @@ -5,6 +5,7 @@ +