diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index c9d12f4b9..f6da0edc7 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -1,17 +1,15 @@ -import contextlib import json -from typing import Any, Callable, Dict, List, Sequence, Type +from typing import Any, Callable, Dict, Sequence, Type -from langchain.agents import ZeroShotAgent from langchain.agents import agent as agent_module from langchain.agents.agent import AgentExecutor from langchain.agents.agent_toolkits.base import BaseToolkit from langchain.agents.tools import BaseTool from langflow.interface.initialize.llm import initialize_vertexai +from langflow.interface.initialize.utils import handle_format_kwargs, handle_node_type from langflow.interface.initialize.vector_store import vecstore_initializer -from langchain.schema import Document, BaseOutputParser from pydantic import ValidationError from langflow.interface.importing.utils import ( @@ -212,68 +210,8 @@ def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params: def instantiate_prompt(node_type, class_object, params: Dict): - if node_type == "ZeroShotPrompt": - if "tools" not in params: - params["tools"] = [] - return ZeroShotAgent.create_prompt(**params) - elif "MessagePromptTemplate" in node_type: - # Then we only need the template - from_template_params = { - "template": params.pop("prompt", params.pop("template", "")) - } - - if not from_template_params.get("template"): - raise ValueError("Prompt template is required") - prompt = class_object.from_template(**from_template_params) - - elif node_type == "ChatPromptTemplate": - prompt = class_object.from_messages(**params) - else: - prompt = class_object(**params) - - format_kwargs: Dict[str, Any] = {} - for input_variable in prompt.input_variables: - if input_variable in params: - variable = params[input_variable] - if isinstance(variable, str): - format_kwargs[input_variable] = variable - elif isinstance(variable, BaseOutputParser) and hasattr( - variable, "get_format_instructions" - ): - format_kwargs[input_variable] = variable.get_format_instructions() - elif isinstance(variable, List) and all( - isinstance(item, Document) for item in variable - ): - # Format document to contain page_content and metadata - # as one string separated by a newline - if len(variable) > 1: - content = "\n".join( - [item.page_content for item in variable if item.page_content] - ) - else: - content = variable[0].page_content - # content could be a json list of strings - with contextlib.suppress(json.JSONDecodeError): - content = json.loads(content) - if isinstance(content, list): - content = ",".join([str(item) for item in content]) - format_kwargs[input_variable] = content - # handle_keys will be a list but it does not exist yet - # so we need to create it - - if ( - isinstance(variable, List) - and all(isinstance(item, Document) for item in variable) - ) or ( - isinstance(variable, BaseOutputParser) - and hasattr(variable, "get_format_instructions") - ): - if "handle_keys" not in format_kwargs: - format_kwargs["handle_keys"] = [] - - # Add the handle_keys to the list - format_kwargs["handle_keys"].append(input_variable) - + params, prompt = handle_node_type(node_type, class_object, params) + format_kwargs = handle_format_kwargs(prompt, params) return prompt, format_kwargs diff --git a/src/backend/langflow/interface/initialize/utils.py b/src/backend/langflow/interface/initialize/utils.py new file mode 100644 index 000000000..0e8cf7bee --- /dev/null +++ b/src/backend/langflow/interface/initialize/utils.py @@ -0,0 +1,107 @@ +import contextlib +import json +from typing import Any, Dict, List + +from langchain.agents import ZeroShotAgent + + +from langchain.schema import Document, BaseOutputParser + + +def handle_node_type(node_type, class_object, params: Dict): + if node_type == "ZeroShotPrompt": + params = check_tools_in_params(params) + prompt = ZeroShotAgent.create_prompt(**params) + elif "MessagePromptTemplate" in node_type: + prompt = instantiate_from_template(class_object, params) + elif node_type == "ChatPromptTemplate": + prompt = class_object.from_messages(**params) + else: + prompt = class_object(**params) + return params, prompt + + +def check_tools_in_params(params: Dict): + if "tools" not in params: + params["tools"] = [] + return params + + +def instantiate_from_template(class_object, params: Dict): + from_template_params = { + "template": params.pop("prompt", params.pop("template", "")) + } + if not from_template_params.get("template"): + raise ValueError("Prompt template is required") + return class_object.from_template(**from_template_params) + + +def handle_format_kwargs(prompt, params: Dict): + format_kwargs: Dict[str, Any] = {} + for input_variable in prompt.input_variables: + if input_variable in params: + format_kwargs = handle_variable(params, input_variable, format_kwargs) + return format_kwargs + + +def handle_variable(params: Dict, input_variable: str, format_kwargs: Dict): + variable = params[input_variable] + if isinstance(variable, str): + format_kwargs[input_variable] = variable + elif isinstance(variable, BaseOutputParser) and hasattr( + variable, "get_format_instructions" + ): + format_kwargs[input_variable] = variable.get_format_instructions() + elif is_instance_of_list_or_document(variable): + format_kwargs = format_document(variable, input_variable, format_kwargs) + if needs_handle_keys(variable): + format_kwargs = add_handle_keys(input_variable, format_kwargs) + return format_kwargs + + +def is_instance_of_list_or_document(variable): + return ( + isinstance(variable, List) + and all(isinstance(item, Document) for item in variable) + or isinstance(variable, Document) + ) + + +def format_document(variable, input_variable: str, format_kwargs: Dict): + variable = variable if isinstance(variable, List) else [variable] + content = format_content(variable) + format_kwargs[input_variable] = content + return format_kwargs + + +def format_content(variable): + if len(variable) > 1: + content = "\n".join( + [item.page_content for item in variable if item.page_content] + ) + else: + content = variable[0].page_content + content = try_to_load_json(content) + return content + + +def try_to_load_json(content): + with contextlib.suppress(json.JSONDecodeError): + content = json.loads(content) + if isinstance(content, list): + content = ",".join([str(item) for item in content]) + return content + + +def needs_handle_keys(variable): + return is_instance_of_list_or_document(variable) or ( + isinstance(variable, BaseOutputParser) + and hasattr(variable, "get_format_instructions") + ) + + +def add_handle_keys(input_variable: str, format_kwargs: Dict): + if "handle_keys" not in format_kwargs: + format_kwargs["handle_keys"] = [] + format_kwargs["handle_keys"].append(input_variable) + return format_kwargs