diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index d31edfbce..8addb1ac7 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -2,8 +2,7 @@ import ast import inspect import types from enum import Enum -from typing import (TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, - Optional) +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional from loguru import logger @@ -41,6 +40,7 @@ class Vertex: ) -> None: # is_external means that the Vertex send or receives data from # an external source (e.g the chat) + self._custom_component = None self.has_external_input = False self.has_external_output = False self.graph = graph @@ -202,6 +202,7 @@ class Vertex: self.output = self.data["node"]["base_classes"] self.display_name = self.data["node"]["display_name"] self.pinned = self.data["node"].get("pinned", False) + self.selected_output_type = self.data["node"].get("selected_output_type") template_dicts = { key: value for key, value in self.data["node"]["template"].items() @@ -500,11 +501,17 @@ class Vertex: if self.base_type is None: raise ValueError(f"Base type for node {self.display_name} not found") try: + outgoing_edges = self.graph.get_vertex_edges( + self.id, is_source=True, is_target=False + ) + result = await loading.instantiate_class( node_type=self.vertex_type, base_type=self.base_type, params=self.params, user_id=user_id, + outgoing_edges=outgoing_edges, + selected_output_type=self.selected_output_type, ) self._update_built_object_and_artifacts(result) except Exception as exc: @@ -518,7 +525,10 @@ class Vertex: Updates the built object and its artifacts. """ if isinstance(result, tuple): - self._built_object, self.artifacts = result + if len(result) == 2: + self._built_object, self.artifacts = result + elif len(result) == 3: + self._custom_component, self._built_object, self.artifacts = result else: self._built_object = result diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index 3534baf2a..7d8794878 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -1,6 +1,15 @@ import operator from pathlib import Path -from typing import Any, Callable, ClassVar, List, Optional, Sequence, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + List, + Optional, + Sequence, + Union, +) from uuid import UUID import yaml @@ -24,6 +33,9 @@ from langflow.services.deps import ( from langflow.services.storage.service import StorageService from langflow.utils import validate +if TYPE_CHECKING: + from langflow.graph.edge.base import ContractEdge + class CustomComponent(Component): display_name: Optional[str] = None @@ -40,6 +52,12 @@ class CustomComponent(Component): """The field order of the component. Defaults to an empty list.""" pinned: Optional[bool] = False """The default pinned state of the component. Defaults to False.""" + build_parameters: Optional[dict] = None + """The build parameters of the component. Defaults to None.""" + selected_output_type: Optional[str] = None + """The selected output type of the component. Defaults to None.""" + outgoing_edges: Optional[List["ContractEdge"]] = None + """The edge target parameter of the component. Defaults to None.""" code_class_base_inheritance: ClassVar[str] = "CustomComponent" function_entrypoint_name: ClassVar[str] = "build" function: Optional[Callable] = None @@ -88,7 +106,9 @@ class CustomComponent(Component): def tree(self): return self.get_code_tree(self.code or "") - def to_records(self, data: Any, text_key: str = "text", data_key: str = "data") -> List[dict]: + def to_records( + self, data: Any, text_key: str = "text", data_key: str = "data" + ) -> List[dict]: """ Convert data into a list of records. @@ -115,7 +135,9 @@ class CustomComponent(Component): return records - def create_references_from_records(self, records: List[dict], include_data: bool = False) -> str: + def create_references_from_records( + self, records: List[dict], include_data: bool = False + ) -> str: """ Create references from a list of records. @@ -150,7 +172,8 @@ class CustomComponent(Component): detail={ "error": "Type hint Error", "traceback": ( - "Prompt type is not supported in the build method." " Try using PromptTemplate instead." + "Prompt type is not supported in the build method." + " Try using PromptTemplate instead." ), }, ) @@ -164,14 +187,20 @@ class CustomComponent(Component): if not self.code: return {} - component_classes = [cls for cls in self.tree["classes"] if self.code_class_base_inheritance in cls["bases"]] + component_classes = [ + cls + for cls in self.tree["classes"] + if self.code_class_base_inheritance in cls["bases"] + ] if not component_classes: return {} # Assume the first Component class is the one we're interested in component_class = component_classes[0] build_methods = [ - method for method in component_class["methods"] if method["name"] == self.function_entrypoint_name + method + for method in component_class["methods"] + if method["name"] == self.function_entrypoint_name ] return build_methods[0] if build_methods else {} @@ -228,7 +257,9 @@ class CustomComponent(Component): # Retrieve and decrypt the credential by name for the current user db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session) + return credential_service.get_credential( + user_id=self._user_id or "", name=name, session=session + ) return get_credential @@ -238,7 +269,9 @@ class CustomComponent(Component): credential_service = get_credential_service() db_service = get_db_service() with session_getter(db_service) as session: - return credential_service.list_credentials(user_id=self._user_id, session=session) + return credential_service.list_credentials( + user_id=self._user_id, session=session + ) def index(self, value: int = 0): """Returns a function that returns the value at the given index in the iterable.""" @@ -289,7 +322,11 @@ class CustomComponent(Component): if flow_id: flow = session.query(Flow).get(flow_id) elif flow_name: - flow = (session.query(Flow).filter(Flow.name == flow_name).filter(Flow.user_id == self.user_id)).first() + flow = ( + session.query(Flow) + .filter(Flow.name == flow_name) + .filter(Flow.user_id == self.user_id) + ).first() else: raise ValueError("Either flow_name or flow_id must be provided") diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 83bd67321..6694da26a 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -1,6 +1,6 @@ import inspect import json -from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Type +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Type import orjson from langchain.agents import agent as agent_module @@ -34,16 +34,27 @@ from langflow.utils import validate if TYPE_CHECKING: from langflow import CustomComponent + from langflow.graph.edge.base import ContractEdge def build_vertex_in_params(params: Dict) -> Dict: from langflow.graph.vertex.base import Vertex # If any of the values in params is a Vertex, we will build it - return {key: value.build() if isinstance(value, Vertex) else value for key, value in params.items()} + return { + key: value.build() if isinstance(value, Vertex) else value + for key, value in params.items() + } -async def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None) -> Any: +async def instantiate_class( + node_type: str, + base_type: str, + params: Dict, + user_id=None, + outgoing_edges: Optional[List["ContractEdge"]] = None, + selected_output_type: Optional[str] = None, +) -> Any: """Instantiate class from module type and key, and params""" params = convert_params_to_sets(params) params = convert_kwargs(params) @@ -55,7 +66,15 @@ async def instantiate_class(node_type: str, base_type: str, params: Dict, user_i return custom_node(**params) logger.debug(f"Instantiating {node_type} of type {base_type}") class_object = import_by_type(_type=base_type, name=node_type) - return await instantiate_based_on_type(class_object, base_type, node_type, params, user_id=user_id) + return await instantiate_based_on_type( + class_object, + base_type, + node_type, + params, + user_id=user_id, + outgoing_edges=outgoing_edges, + selected_output_type=selected_output_type, + ) def convert_params_to_sets(params): @@ -82,7 +101,15 @@ def convert_kwargs(params): return params -async def instantiate_based_on_type(class_object, base_type, node_type, params, user_id): +async def instantiate_based_on_type( + class_object, + base_type, + node_type, + params, + user_id, + outgoing_edges, + selected_output_type, +): if base_type == "agents": return instantiate_agent(node_type, class_object, params) elif base_type == "prompts": @@ -116,17 +143,33 @@ async def instantiate_based_on_type(class_object, base_type, node_type, params, elif base_type == "memory": return instantiate_memory(node_type, class_object, params) elif base_type == "custom_components": - return await instantiate_custom_component(node_type, class_object, params, user_id) + return await instantiate_custom_component( + node_type, + class_object, + params, + user_id, + outgoing_edges, + selected_output_type, + ) elif base_type == "wrappers": return instantiate_wrapper(node_type, class_object, params) else: return class_object(**params) -async def instantiate_custom_component(node_type, class_object, params, user_id): +async def instantiate_custom_component( + node_type, class_object, params, user_id, outgoing_edges, selected_output_type +): params_copy = params.copy() - class_object: Type["CustomComponent"] = eval_custom_component_code(params_copy.pop("code")) - custom_component: "CustomComponent" = class_object(user_id=user_id) + class_object: Type["CustomComponent"] = eval_custom_component_code( + params_copy.pop("code") + ) + custom_component: "CustomComponent" = class_object( + user_id=user_id, + parameters=params_copy, + outgoing_edges=outgoing_edges, + selected_output_type=selected_output_type, + ) if "retriever" in params_copy and hasattr(params_copy["retriever"], "as_retriever"): params_copy["retriever"] = params_copy["retriever"].as_retriever() @@ -141,7 +184,7 @@ async def instantiate_custom_component(node_type, class_object, params, user_id) # Call the build method directly if it's sync build_result = custom_component.build(**params_copy) - return build_result, {"repr": custom_component.custom_repr()} + return custom_component, build_result, {"repr": custom_component.custom_repr()} def instantiate_wrapper(node_type, class_object, params): @@ -194,7 +237,9 @@ def instantiate_memory(node_type, class_object, params): # I want to catch a specific attribute error that happens # when the object does not have a cursor attribute except Exception as exc: - if "object has no attribute 'cursor'" in str(exc) or 'object has no field "conn"' in str(exc): + if "object has no attribute 'cursor'" in str( + exc + ) or 'object has no field "conn"' in str(exc): raise AttributeError( ( "Failed to build connection to database." @@ -237,7 +282,9 @@ def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params: if class_method := getattr(class_object, method, None): agent = class_method(**params) tools = params.get("tools", []) - return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, handle_parsing_errors=True) + return AgentExecutor.from_agent_and_tools( + agent=agent, tools=tools, handle_parsing_errors=True + ) return load_agent_executor(class_object, params) @@ -293,7 +340,11 @@ def instantiate_embedding(node_type, class_object, params: Dict): try: return class_object(**params) except ValidationError: - params = {key: value for key, value in params.items() if key in class_object.model_fields} + params = { + key: value + for key, value in params.items() + if key in class_object.model_fields + } return class_object(**params) @@ -305,7 +356,9 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict): if "texts" in params: params["documents"] = params.pop("texts") if "documents" in params: - params["documents"] = [doc for doc in params["documents"] if isinstance(doc, Document)] + params["documents"] = [ + doc for doc in params["documents"] if isinstance(doc, Document) + ] if initializer := vecstore_initializer.get(class_object.__name__): vecstore = initializer(class_object, params) else: @@ -320,7 +373,9 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict): return vecstore -def instantiate_documentloader(node_type: str, class_object: Type[BaseLoader], params: Dict): +def instantiate_documentloader( + node_type: str, class_object: Type[BaseLoader], params: Dict +): if "file_filter" in params: # file_filter will be a string but we need a function # that will be used to filter the files using file_filter @@ -329,13 +384,17 @@ def instantiate_documentloader(node_type: str, class_object: Type[BaseLoader], p # in x and if it is, we will return True file_filter = params.pop("file_filter") extensions = file_filter.split(",") - params["file_filter"] = lambda x: any(extension.strip() in x for extension in extensions) + params["file_filter"] = lambda x: any( + extension.strip() in x for extension in extensions + ) metadata = params.pop("metadata", None) if metadata and isinstance(metadata, str): try: metadata = orjson.loads(metadata) except json.JSONDecodeError as exc: - raise ValueError("The metadata you provided is not a valid JSON string.") from exc + raise ValueError( + "The metadata you provided is not a valid JSON string." + ) from exc if node_type == "WebBaseLoader": if web_path := params.pop("web_path", None): @@ -368,12 +427,16 @@ def instantiate_textsplitter( "Try changing the chunk_size of the Text Splitter." ) from exc - if ("separator_type" in params and params["separator_type"] == "Text") or "separator_type" not in params: + if ( + "separator_type" in params and params["separator_type"] == "Text" + ) or "separator_type" not in params: params.pop("separator_type", None) # separators might come in as an escaped string like \\n # so we need to convert it to a string if "separators" in params: - params["separators"] = params["separators"].encode().decode("unicode-escape") + params["separators"] = ( + params["separators"].encode().decode("unicode-escape") + ) text_splitter = class_object(**params) else: from langchain.text_splitter import Language @@ -400,7 +463,8 @@ def replace_zero_shot_prompt_with_prompt_template(nodes): tools = [ tool for tool in nodes - if tool["type"] != "chatOutputNode" and "Tool" in tool["data"]["node"]["base_classes"] + if tool["type"] != "chatOutputNode" + and "Tool" in tool["data"]["node"]["base_classes"] ] node["data"] = build_prompt_template(prompt=node["data"], tools=tools) break @@ -414,7 +478,9 @@ def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs) # agent has hidden args for memory. might need to be support # memory = params["memory"] # if allowed_tools is not a list or set, make it a list - if not isinstance(allowed_tools, (list, set)) and isinstance(allowed_tools, BaseTool): + if not isinstance(allowed_tools, (list, set)) and isinstance( + allowed_tools, BaseTool + ): allowed_tools = [allowed_tools] tool_names = [tool.name for tool in allowed_tools] # Agent class requires an output_parser but Agent classes @@ -442,7 +508,10 @@ def build_prompt_template(prompt, tools): format_instructions = prompt["node"]["template"]["format_instructions"]["value"] tool_strings = "\n".join( - [f"{tool['data']['node']['name']}: {tool['data']['node']['description']}" for tool in tools] + [ + f"{tool['data']['node']['name']}: {tool['data']['node']['description']}" + for tool in tools + ] ) tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools]) format_instructions = format_instructions.format(tool_names=tool_names)