From 6ad4de86550ec24dd84de31c02df63347e69cf59 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 27 Feb 2024 18:12:58 -0300 Subject: [PATCH] Add updated_raw_params flag and INPUT_FIELD_NAME constant --- src/backend/langflow/graph/edge/base.py | 50 +++++++++++++++++------ src/backend/langflow/graph/graph/base.py | 38 ++++------------- src/backend/langflow/graph/schema.py | 2 + src/backend/langflow/graph/vertex/base.py | 17 +++++--- 4 files changed, 60 insertions(+), 47 deletions(-) diff --git a/src/backend/langflow/graph/edge/base.py b/src/backend/langflow/graph/edge/base.py index 6706156a0..cfcd33dd1 100644 --- a/src/backend/langflow/graph/edge/base.py +++ b/src/backend/langflow/graph/edge/base.py @@ -4,6 +4,7 @@ from loguru import logger from pydantic import BaseModel, Field from langflow.graph.edge.utils import build_clean_params +from langflow.graph.schema import INPUT_FIELD_NAME from langflow.services.deps import get_monitor_service from langflow.services.monitor.utils import log_message @@ -12,7 +13,9 @@ if TYPE_CHECKING: class SourceHandle(BaseModel): - baseClasses: List[str] = Field(..., description="List of base classes for the source handle.") + baseClasses: List[str] = Field( + ..., description="List of base classes for the source handle." + ) dataType: str = Field(..., description="Data type for the source handle.") id: str = Field(..., description="Unique identifier for the source handle.") @@ -20,7 +23,9 @@ class SourceHandle(BaseModel): class TargetHandle(BaseModel): fieldName: str = Field(..., description="Field name for the target handle.") id: str = Field(..., description="Unique identifier for the target handle.") - inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.") + inputTypes: Optional[List[str]] = Field( + None, description="List of input types for the target handle." + ) type: str = Field(..., description="Type of the target handle.") @@ -49,16 +54,24 @@ class Edge: def validate_handles(self, source, target) -> None: if self.target_handle.inputTypes is None: - self.valid_handles = self.target_handle.type in self.source_handle.baseClasses + self.valid_handles = ( + self.target_handle.type in self.source_handle.baseClasses + ) else: self.valid_handles = ( - any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses) + any( + baseClass in self.target_handle.inputTypes + for baseClass in self.source_handle.baseClasses + ) or self.target_handle.type in self.source_handle.baseClasses ) if not self.valid_handles: logger.debug(self.source_handle) logger.debug(self.target_handle) - raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles") + raise ValueError( + f"Edge between {source.vertex_type} and {target.vertex_type} " + f"has invalid handles" + ) def __setstate__(self, state): self.source_id = state["source_id"] @@ -75,7 +88,11 @@ class Edge: # Both lists contain strings and sometimes a string contains the value we are # looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"] # so we need to check if any of the strings in source_types is in target_reqs - self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs) + self.valid = any( + output in target_req + for output in self.source_types + for target_req in self.target_reqs + ) # Get what type of input the target node is expecting self.matched_type = next( @@ -86,7 +103,10 @@ class Edge: if no_matched_type: logger.debug(self.source_types) logger.debug(self.target_reqs) - raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type") + raise ValueError( + f"Edge between {source.vertex_type} and {target.vertex_type} " + f"has no matched type" + ) def __repr__(self) -> str: return ( @@ -98,7 +118,11 @@ class Edge: return hash(self.__repr__()) def __eq__(self, __value: object) -> bool: - return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False + return ( + self.__repr__() == __value.__repr__() + if isinstance(__value, Edge) + else False + ) class ContractEdge(Edge): @@ -137,15 +161,15 @@ class ContractEdge(Edge): log_transaction(self, source, target, "success") # If the target vertex is a power component we log messages if target.vertex_type == "ChatOutput" and ( - isinstance(target.params.get("input_value"), str) - or isinstance(target.params.get("input_value"), dict) + isinstance(target.params.get(INPUT_FIELD_NAME), str) + or isinstance(target.params.get(INPUT_FIELD_NAME), dict) ): if target.params.get("message") == "": return self.result await log_message( sender=target.params.get("sender", ""), sender_name=target.params.get("sender_name", ""), - message=target.params.get("input_value", {}), + message=target.params.get(INPUT_FIELD_NAME, {}), session_id=target.params.get("session_id", ""), artifacts=target.artifacts, ) @@ -155,7 +179,9 @@ class ContractEdge(Edge): return f"{self.source_id} -[{self.target_param}]-> {self.target_id}" -def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None): +def log_transaction( + edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None +): try: monitor_service = get_monitor_service() clean_params = build_clean_params(target) diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 341a4729c..0ffc9825b 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -33,8 +33,6 @@ class Graph: edges: List[Dict[str, str]], flow_id: Optional[str] = None, ) -> None: - self.inputs = [] - self.outputs = [] self._vertices = nodes self._edges = edges self.raw_graph_data = {"nodes": nodes, "edges": edges} @@ -77,7 +75,7 @@ class Graph: async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]: """Runs the graph with the given inputs.""" - for vertex_id in self.inputs: + for vertex_id in self._is_input_vertices: vertex = self.get_vertex(vertex_id) if vertex is None: raise ValueError(f"Vertex {vertex_id} not found") @@ -89,7 +87,7 @@ class Graph: logger.exception(exc) raise ValueError(f"Error running graph: {exc}") from exc outputs = [] - for vertex_id in self.outputs: + for vertex_id in self._is_output_vertices: vertex = self.get_vertex(vertex_id) if vertex is None: raise ValueError(f"Vertex {vertex_id} not found") @@ -104,11 +102,11 @@ class Graph: # of the vertices that are inputs # if the value is a list, we need to run multiple times outputs = [] - inputs_values = inputs.get("input_value") + inputs_values = inputs.get(INPUT_FIELD_NAME) if not isinstance(inputs_values, list): inputs_values = [inputs_values] for input_value in inputs_values: - run_outputs = await self._run({"input_value": input_value}) + run_outputs = await self._run({INPUT_FIELD_NAME: input_value}) logger.debug(f"Run outputs: {run_outputs}") outputs.extend(run_outputs) return outputs @@ -317,28 +315,6 @@ class Graph: # Now that we have the vertices and edges # We need to map the vertices that are connected to # to ChatVertex instances - self._map_chat_vertices() - - def _map_chat_vertices(self) -> None: - """Maps the vertices that are connected to ChatVertex instances.""" - # For each edge, we need to check if the source or target vertex is a ChatVertex - # If it is, we need to update the other vertex `is_external` attribute - # and store the id of the ChatVertex in the attributes self.inputs and self.outputs - for edge in self.edges: - source_vertex = self.get_vertex(edge.source_id) - target_vertex = self.get_vertex(edge.target_id) - if isinstance(source_vertex, ChatVertex): - # The source vertex is a ChatVertex - # thus the target vertex is an external vertex - # and the source vertex is an input - target_vertex.has_external_input = True - self.inputs.append(source_vertex.id) - if isinstance(target_vertex, ChatVertex): - # The target vertex is a ChatVertex - # thus the source vertex is an external vertex - # and the target vertex is an output - source_vertex.has_external_output = True - self.outputs.append(target_vertex.id) def remove_vertex(self, vertex_id: str) -> None: """Removes a vertex from the graph.""" @@ -443,13 +419,15 @@ class Graph: async def _execute_tasks(self, tasks): """Executes tasks in parallel, handling exceptions for each task.""" results = [] - for task in asyncio.as_completed(tasks): + for i, task in enumerate(asyncio.as_completed(tasks)): try: result = await task results.append(result) except Exception as e: # Log the exception along with the task name for easier debugging - task_name = task.get_name() + # task_name = task.get_name() + # coroutine has not attribute get_name + task_name = tasks[i].get_name() logger.error(f"Task {task_name} failed with exception: {e}") return results diff --git a/src/backend/langflow/graph/schema.py b/src/backend/langflow/graph/schema.py index d41e0544a..028b8db9f 100644 --- a/src/backend/langflow/graph/schema.py +++ b/src/backend/langflow/graph/schema.py @@ -35,3 +35,5 @@ OUTPUT_COMPONENTS = [ InterfaceComponentTypes.ChatOutput, InterfaceComponentTypes.TextOutput, ] + +INPUT_FIELD_NAME = "input_value" diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 3e1133491..dd308f9f1 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -44,7 +44,7 @@ class Vertex: ) -> None: # is_external means that the Vertex send or receives data from # an external source (e.g the chat) - + self.updated_raw_params = False self.id: str = data["id"] self.is_input = any( input_component_name in self.id for input_component_name in INPUT_COMPONENTS @@ -285,6 +285,10 @@ class Vertex: if self.graph is None: raise ValueError("Graph not found") + if self.updated_raw_params: + self.updated_raw_params = False + return + template_dict = { key: value for key, value in self.data["node"]["template"].items() @@ -386,10 +390,11 @@ class Vertex: Raises: ValueError: If any key in new_params is not found in self._raw_params. """ - for key in new_params: - if key not in self._raw_params: - raise ValueError(f"Key {key} not found in raw params") + # First check if the input_value in _raw_params is not a vertex + if any(isinstance(self._raw_params.get(key), Vertex) for key in new_params): + return self._raw_params.update(new_params) + self.updated_raw_params = True async def _build(self, user_id=None): """ @@ -451,6 +456,8 @@ class Vertex: await self._build_node_and_update_params(key, value, user_id) elif isinstance(value, list) and self._is_list_of_nodes(value): await self._build_list_of_nodes_and_update_params(key, value, user_id) + elif key not in self.params: + self.params[key] = value def _is_node(self, value): """ @@ -586,7 +593,7 @@ class Vertex: logger.warning(message) - def _reset(self): + def _reset(self, params_update: Optional[Dict[str, Any]] = None): self._built = False self._built_object = UnbuiltObject() self._built_result = UnbuiltResult()