Add updated_raw_params flag and INPUT_FIELD_NAME constant

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-27 18:12:58 -03:00
commit 6ad4de8655
4 changed files with 60 additions and 47 deletions

View file

@ -4,6 +4,7 @@ from loguru import logger
from pydantic import BaseModel, Field
from langflow.graph.edge.utils import build_clean_params
from langflow.graph.schema import INPUT_FIELD_NAME
from langflow.services.deps import get_monitor_service
from langflow.services.monitor.utils import log_message
@ -12,7 +13,9 @@ if TYPE_CHECKING:
class SourceHandle(BaseModel):
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
baseClasses: List[str] = Field(
..., description="List of base classes for the source handle."
)
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
@ -20,7 +23,9 @@ class SourceHandle(BaseModel):
class TargetHandle(BaseModel):
fieldName: str = Field(..., description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
inputTypes: Optional[List[str]] = Field(
None, description="List of input types for the target handle."
)
type: str = Field(..., description="Type of the target handle.")
@ -49,16 +54,24 @@ class Edge:
def validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
self.valid_handles = (
self.target_handle.type in self.source_handle.baseClasses
)
else:
self.valid_handles = (
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
any(
baseClass in self.target_handle.inputTypes
for baseClass in self.source_handle.baseClasses
)
or self.target_handle.type in self.source_handle.baseClasses
)
if not self.valid_handles:
logger.debug(self.source_handle)
logger.debug(self.target_handle)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has invalid handles"
)
def __setstate__(self, state):
self.source_id = state["source_id"]
@ -75,7 +88,11 @@ class Edge:
# Both lists contain strings and sometimes a string contains the value we are
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
self.valid = any(
output in target_req
for output in self.source_types
for target_req in self.target_reqs
)
# Get what type of input the target node is expecting
self.matched_type = next(
@ -86,7 +103,10 @@ class Edge:
if no_matched_type:
logger.debug(self.source_types)
logger.debug(self.target_reqs)
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
raise ValueError(
f"Edge between {source.vertex_type} and {target.vertex_type} "
f"has no matched type"
)
def __repr__(self) -> str:
return (
@ -98,7 +118,11 @@ class Edge:
return hash(self.__repr__())
def __eq__(self, __value: object) -> bool:
return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False
return (
self.__repr__() == __value.__repr__()
if isinstance(__value, Edge)
else False
)
class ContractEdge(Edge):
@ -137,15 +161,15 @@ class ContractEdge(Edge):
log_transaction(self, source, target, "success")
# If the target vertex is a power component we log messages
if target.vertex_type == "ChatOutput" and (
isinstance(target.params.get("input_value"), str)
or isinstance(target.params.get("input_value"), dict)
isinstance(target.params.get(INPUT_FIELD_NAME), str)
or isinstance(target.params.get(INPUT_FIELD_NAME), dict)
):
if target.params.get("message") == "":
return self.result
await log_message(
sender=target.params.get("sender", ""),
sender_name=target.params.get("sender_name", ""),
message=target.params.get("input_value", {}),
message=target.params.get(INPUT_FIELD_NAME, {}),
session_id=target.params.get("session_id", ""),
artifacts=target.artifacts,
)
@ -155,7 +179,9 @@ class ContractEdge(Edge):
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
def log_transaction(
edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None
):
try:
monitor_service = get_monitor_service()
clean_params = build_clean_params(target)

View file

@ -33,8 +33,6 @@ class Graph:
edges: List[Dict[str, str]],
flow_id: Optional[str] = None,
) -> None:
self.inputs = []
self.outputs = []
self._vertices = nodes
self._edges = edges
self.raw_graph_data = {"nodes": nodes, "edges": edges}
@ -77,7 +75,7 @@ class Graph:
async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]:
"""Runs the graph with the given inputs."""
for vertex_id in self.inputs:
for vertex_id in self._is_input_vertices:
vertex = self.get_vertex(vertex_id)
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
@ -89,7 +87,7 @@ class Graph:
logger.exception(exc)
raise ValueError(f"Error running graph: {exc}") from exc
outputs = []
for vertex_id in self.outputs:
for vertex_id in self._is_output_vertices:
vertex = self.get_vertex(vertex_id)
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
@ -104,11 +102,11 @@ class Graph:
# of the vertices that are inputs
# if the value is a list, we need to run multiple times
outputs = []
inputs_values = inputs.get("input_value")
inputs_values = inputs.get(INPUT_FIELD_NAME)
if not isinstance(inputs_values, list):
inputs_values = [inputs_values]
for input_value in inputs_values:
run_outputs = await self._run({"input_value": input_value})
run_outputs = await self._run({INPUT_FIELD_NAME: input_value})
logger.debug(f"Run outputs: {run_outputs}")
outputs.extend(run_outputs)
return outputs
@ -317,28 +315,6 @@ class Graph:
# Now that we have the vertices and edges
# We need to map the vertices that are connected to
# to ChatVertex instances
self._map_chat_vertices()
def _map_chat_vertices(self) -> None:
"""Maps the vertices that are connected to ChatVertex instances."""
# For each edge, we need to check if the source or target vertex is a ChatVertex
# If it is, we need to update the other vertex `is_external` attribute
# and store the id of the ChatVertex in the attributes self.inputs and self.outputs
for edge in self.edges:
source_vertex = self.get_vertex(edge.source_id)
target_vertex = self.get_vertex(edge.target_id)
if isinstance(source_vertex, ChatVertex):
# The source vertex is a ChatVertex
# thus the target vertex is an external vertex
# and the source vertex is an input
target_vertex.has_external_input = True
self.inputs.append(source_vertex.id)
if isinstance(target_vertex, ChatVertex):
# The target vertex is a ChatVertex
# thus the source vertex is an external vertex
# and the target vertex is an output
source_vertex.has_external_output = True
self.outputs.append(target_vertex.id)
def remove_vertex(self, vertex_id: str) -> None:
"""Removes a vertex from the graph."""
@ -443,13 +419,15 @@ class Graph:
async def _execute_tasks(self, tasks):
"""Executes tasks in parallel, handling exceptions for each task."""
results = []
for task in asyncio.as_completed(tasks):
for i, task in enumerate(asyncio.as_completed(tasks)):
try:
result = await task
results.append(result)
except Exception as e:
# Log the exception along with the task name for easier debugging
task_name = task.get_name()
# task_name = task.get_name()
# coroutine has not attribute get_name
task_name = tasks[i].get_name()
logger.error(f"Task {task_name} failed with exception: {e}")
return results

View file

@ -35,3 +35,5 @@ OUTPUT_COMPONENTS = [
InterfaceComponentTypes.ChatOutput,
InterfaceComponentTypes.TextOutput,
]
INPUT_FIELD_NAME = "input_value"

View file

@ -44,7 +44,7 @@ class Vertex:
) -> None:
# is_external means that the Vertex send or receives data from
# an external source (e.g the chat)
self.updated_raw_params = False
self.id: str = data["id"]
self.is_input = any(
input_component_name in self.id for input_component_name in INPUT_COMPONENTS
@ -285,6 +285,10 @@ class Vertex:
if self.graph is None:
raise ValueError("Graph not found")
if self.updated_raw_params:
self.updated_raw_params = False
return
template_dict = {
key: value
for key, value in self.data["node"]["template"].items()
@ -386,10 +390,11 @@ class Vertex:
Raises:
ValueError: If any key in new_params is not found in self._raw_params.
"""
for key in new_params:
if key not in self._raw_params:
raise ValueError(f"Key {key} not found in raw params")
# First check if the input_value in _raw_params is not a vertex
if any(isinstance(self._raw_params.get(key), Vertex) for key in new_params):
return
self._raw_params.update(new_params)
self.updated_raw_params = True
async def _build(self, user_id=None):
"""
@ -451,6 +456,8 @@ class Vertex:
await self._build_node_and_update_params(key, value, user_id)
elif isinstance(value, list) and self._is_list_of_nodes(value):
await self._build_list_of_nodes_and_update_params(key, value, user_id)
elif key not in self.params:
self.params[key] = value
def _is_node(self, value):
"""
@ -586,7 +593,7 @@ class Vertex:
logger.warning(message)
def _reset(self):
def _reset(self, params_update: Optional[Dict[str, Any]] = None):
self._built = False
self._built_object = UnbuiltObject()
self._built_result = UnbuiltResult()