From a1e34143f4153daa6cfe385e8254b9cd2d6feafa Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sat, 23 Mar 2024 18:51:16 -0300 Subject: [PATCH] Fix flow loading and running issues --- src/backend/langflow/graph/graph/base.py | 35 +++++++++------ src/backend/langflow/graph/vertex/base.py | 45 ++++++++++--------- src/backend/langflow/helpers/flow.py | 41 ++++++++--------- .../custom_component/custom_component.py | 2 +- .../services/database/models/flow/model.py | 2 +- 5 files changed, 69 insertions(+), 56 deletions(-) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 41cc94aed..2fb1aa71c 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -15,8 +15,6 @@ from langflow.graph.vertex.types import ChatVertex, FileToolVertex, LLMVertex, R from langflow.interface.tools.constants import FILE_TOOLS from langflow.schema import Record from langflow.schema.schema import INPUT_FIELD_NAME -from langflow.schema import Record -from langflow.schema.schema import INPUT_FIELD_NAME if TYPE_CHECKING: from langflow.graph.schema import ResultData @@ -226,18 +224,21 @@ class Graph: if not isinstance(inputs.get(INPUT_FIELD_NAME, ""), str): raise ValueError(f"Invalid input value: {inputs.get(INPUT_FIELD_NAME)}. Expected string") - for vertex_id in self._is_input_vertices: - vertex = self.get_vertex(vertex_id) - # If the vertex is not in the input_components list - if input_components and (vertex_id not in input_components or vertex.display_name not in input_components): - continue - # If the input_type is not any and the input_type is not in the vertex id - # Example: input_type = "chat" and vertex.id = "OpenAI-19ddn" - elif input_type != "any" and input_type not in vertex.id.lower(): - continue - if vertex is None: - raise ValueError(f"Vertex {vertex_id} not found") - vertex.update_raw_params(inputs, overwrite=True) + if inputs: + for vertex_id in self._is_input_vertices: + vertex = self.get_vertex(vertex_id) + # If the vertex is not in the input_components list + if input_components and ( + vertex_id not in input_components or vertex.display_name not in input_components + ): + continue + # If the input_type is not any and the input_type is not in the vertex id + # Example: input_type = "chat" and vertex.id = "OpenAI-19ddn" + elif input_type != "any" and input_type not in vertex.id.lower(): + continue + if vertex is None: + raise ValueError(f"Vertex {vertex_id} not found") + vertex.update_raw_params(inputs, overwrite=True) # Update all the vertices with the session_id for vertex_id in self._has_session_id_vertices: vertex = self.get_vertex(vertex_id) @@ -333,6 +334,12 @@ class Graph: inputs = [inputs] elif not inputs: inputs = [{}] + # Length of all should be the as inputs length + # just add empty lists to complete the length + for _ in range(len(inputs) - len(inputs_components)): + inputs_components.append([]) + for _ in range(len(inputs) - len(types)): + types.append("any") for run_inputs, components, input_type in zip(inputs, inputs_components, types): run_outputs = await self._run( inputs=run_inputs, diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 408deb1e9..d2bac3e4e 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -4,19 +4,18 @@ import inspect import types from enum import Enum from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, Iterator, List, Optional + from loguru import logger -from langflow.graph.schema import ( - INPUT_COMPONENTS, - OUTPUT_COMPONENTS, - InterfaceComponentTypes, - ResultData, -) +from langflow.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData from langflow.graph.utils import UnbuiltObject, UnbuiltResult from langflow.graph.vertex.utils import generate_result from langflow.interface.initialize import loading from langflow.interface.listing import lazy_load_dict from langflow.schema.schema import INPUT_FIELD_NAME +from langflow.services.deps import get_storage_service +from langflow.utils.constants import DIRECT_TYPES +from langflow.utils.schemas import ChatOutputResponse from langflow.utils.util import sync_to_async, unescape_string if TYPE_CHECKING: @@ -193,6 +192,7 @@ class Vertex: self.data = self._data["data"] self.output = self.data["node"]["base_classes"] self.display_name = self.data["node"].get("display_name", self.id.split("-")[0]) + self.description = self.data["node"].get("description", "") self.frozen = self.data["node"].get("frozen", False) self.selected_output_type = self.data["node"].get("selected_output_type") self.is_input = self.data["node"].get("is_input") or self.is_input @@ -296,8 +296,14 @@ class Vertex: # value.get('value') is the file name if file_path := value.get("file_path"): storage_service = get_storage_service() - flow_id, file_name = file_path.split("/") - full_path = storage_service.build_full_path(flow_id, file_name) + try: + flow_id, file_name = file_path.split("/") + full_path = storage_service.build_full_path(flow_id, file_name) + except ValueError as e: + if "too many values to unpack" in str(e): + full_path = file_path + else: + raise e params[key] = full_path elif value.get("required"): raise ValueError(f"File path not found for {self.display_name}") @@ -388,19 +394,18 @@ class Vertex: Returns: List[str]: The extracted messages. """ - messages = [] - for key, artifact in artifacts.items(): - if not isinstance(artifact, dict): - continue - if "message" in artifact: - chat_output_response = ChatOutputResponse( - message=artifact["message"], - sender=artifact.get("sender"), - sender_name=artifact.get("sender_name"), - session_id=artifact.get("session_id"), + try: + messages = [ + ChatOutputResponse( + message=artifacts["message"], + sender=artifacts.get("sender"), + sender_name=artifacts.get("sender_name"), + session_id=artifacts.get("session_id"), component_id=self.id, - ) - messages.append(chat_output_response.model_dump(exclude_none=True)) + ).model_dump(exclude_none=True) + ] + except KeyError: + messages = [] return messages diff --git a/src/backend/langflow/helpers/flow.py b/src/backend/langflow/helpers/flow.py index f81e68915..eda5d3116 100644 --- a/src/backend/langflow/helpers/flow.py +++ b/src/backend/langflow/helpers/flow.py @@ -19,16 +19,25 @@ def list_flows(*, user_id: Optional[str] = None) -> List[Record]: select(Flow).where(Flow.user_id == user_id).where(Flow.is_component == False) # noqa ).all() - flows_records = [flow.to_record() for flow in flows] - return flows_records + flows_records = [flow.to_record() for flow in flows] + return flows_records except Exception as e: raise ValueError(f"Error listing flows: {e}") -async def load_flow(flow_id: str, tweaks: Optional[dict] = None) -> "Graph": +async def load_flow( + user_id: str, flow_id: Optional[str] = None, flow_name: Optional[str] = None, tweaks: Optional[dict] = None +) -> "Graph": from langflow.graph.graph.base import Graph from langflow.processing.process import process_tweaks + if not flow_id and not flow_name: + raise ValueError("Flow ID or Flow Name is required") + if not flow_id and flow_name: + flow_id = find_flow(flow_name, user_id) + if not flow_id: + raise ValueError(f"Flow {flow_name} not found") + with session_scope() as session: graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None if not graph_data: @@ -39,28 +48,20 @@ async def load_flow(flow_id: str, tweaks: Optional[dict] = None) -> "Graph": return graph +def find_flow(flow_name: str, user_id: str) -> Optional[str]: + with session_scope() as session: + flow = session.exec(select(Flow).where(Flow.name == flow_name).where(Flow.user_id == user_id)).first() + return flow.id if flow else None + + async def run_flow( inputs: Union[dict, List[dict]] = None, + tweaks: Optional[dict] = None, flow_id: Optional[str] = None, flow_name: Optional[str] = None, - tweaks: Optional[dict] = None, - flows_records: Optional[List[Record]] = None, + user_id: Optional[str] = None, ) -> Any: - if not flow_id and not flow_name: - raise ValueError("Flow ID or Flow Name is required") - if not flows_records: - flows_records = list_flows() - if not flow_id and flows_records: - flow_ids = [flow.data["id"] for flow in flows_records if flow.data["name"] == flow_name] - if not flow_ids: - raise ValueError(f"Flow {flow_name} not found") - elif len(flow_ids) > 1: - raise ValueError(f"Multiple flows found with the name {flow_name}") - flow_id = flow_ids[0] - - if not flow_id: - raise ValueError(f"Flow {flow_name} not found") - graph = await load_flow(flow_id, tweaks) + graph = await load_flow(user_id, flow_id, flow_name, tweaks) if inputs is None: inputs = [] diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index 42d837157..87dbc2c0f 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -326,7 +326,7 @@ class CustomComponent(Component): flow_name: Optional[str] = None, tweaks: Optional[dict] = None, ) -> Any: - return await run_flow(inputs=inputs, flow_id=flow_id, flow_name=flow_name, tweaks=tweaks) + return await run_flow(inputs=inputs, flow_id=flow_id, flow_name=flow_name, tweaks=tweaks, user_id=self._user_id) def list_flows(self) -> List[Record]: if not self._user_id: diff --git a/src/backend/langflow/services/database/models/flow/model.py b/src/backend/langflow/services/database/models/flow/model.py index 8a97e97a1..95311909c 100644 --- a/src/backend/langflow/services/database/models/flow/model.py +++ b/src/backend/langflow/services/database/models/flow/model.py @@ -108,7 +108,7 @@ class Flow(FlowBase, table=True): "description": serialized.pop("description"), "updated_at": serialized.pop("updated_at"), } - record = Record(text=data.get("name"), data=data) + record = Record(data=data) return record