Add updated_raw_params flag and INPUT_FIELD_NAME constant
This commit is contained in:
parent
d6963b5812
commit
6ad4de8655
4 changed files with 60 additions and 47 deletions
|
|
@ -4,6 +4,7 @@ from loguru import logger
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from langflow.graph.edge.utils import build_clean_params
|
||||
from langflow.graph.schema import INPUT_FIELD_NAME
|
||||
from langflow.services.deps import get_monitor_service
|
||||
from langflow.services.monitor.utils import log_message
|
||||
|
||||
|
|
@ -12,7 +13,9 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class SourceHandle(BaseModel):
|
||||
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
|
||||
baseClasses: List[str] = Field(
|
||||
..., description="List of base classes for the source handle."
|
||||
)
|
||||
dataType: str = Field(..., description="Data type for the source handle.")
|
||||
id: str = Field(..., description="Unique identifier for the source handle.")
|
||||
|
||||
|
|
@ -20,7 +23,9 @@ class SourceHandle(BaseModel):
|
|||
class TargetHandle(BaseModel):
|
||||
fieldName: str = Field(..., description="Field name for the target handle.")
|
||||
id: str = Field(..., description="Unique identifier for the target handle.")
|
||||
inputTypes: Optional[List[str]] = Field(None, description="List of input types for the target handle.")
|
||||
inputTypes: Optional[List[str]] = Field(
|
||||
None, description="List of input types for the target handle."
|
||||
)
|
||||
type: str = Field(..., description="Type of the target handle.")
|
||||
|
||||
|
||||
|
|
@ -49,16 +54,24 @@ class Edge:
|
|||
|
||||
def validate_handles(self, source, target) -> None:
|
||||
if self.target_handle.inputTypes is None:
|
||||
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
|
||||
self.valid_handles = (
|
||||
self.target_handle.type in self.source_handle.baseClasses
|
||||
)
|
||||
else:
|
||||
self.valid_handles = (
|
||||
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
|
||||
any(
|
||||
baseClass in self.target_handle.inputTypes
|
||||
for baseClass in self.source_handle.baseClasses
|
||||
)
|
||||
or self.target_handle.type in self.source_handle.baseClasses
|
||||
)
|
||||
if not self.valid_handles:
|
||||
logger.debug(self.source_handle)
|
||||
logger.debug(self.target_handle)
|
||||
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has invalid handles")
|
||||
raise ValueError(
|
||||
f"Edge between {source.vertex_type} and {target.vertex_type} "
|
||||
f"has invalid handles"
|
||||
)
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.source_id = state["source_id"]
|
||||
|
|
@ -75,7 +88,11 @@ class Edge:
|
|||
# Both lists contain strings and sometimes a string contains the value we are
|
||||
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
|
||||
# so we need to check if any of the strings in source_types is in target_reqs
|
||||
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
|
||||
self.valid = any(
|
||||
output in target_req
|
||||
for output in self.source_types
|
||||
for target_req in self.target_reqs
|
||||
)
|
||||
# Get what type of input the target node is expecting
|
||||
|
||||
self.matched_type = next(
|
||||
|
|
@ -86,7 +103,10 @@ class Edge:
|
|||
if no_matched_type:
|
||||
logger.debug(self.source_types)
|
||||
logger.debug(self.target_reqs)
|
||||
raise ValueError(f"Edge between {source.vertex_type} and {target.vertex_type} " f"has no matched type")
|
||||
raise ValueError(
|
||||
f"Edge between {source.vertex_type} and {target.vertex_type} "
|
||||
f"has no matched type"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
|
@ -98,7 +118,11 @@ class Edge:
|
|||
return hash(self.__repr__())
|
||||
|
||||
def __eq__(self, __value: object) -> bool:
|
||||
return self.__repr__() == __value.__repr__() if isinstance(__value, Edge) else False
|
||||
return (
|
||||
self.__repr__() == __value.__repr__()
|
||||
if isinstance(__value, Edge)
|
||||
else False
|
||||
)
|
||||
|
||||
|
||||
class ContractEdge(Edge):
|
||||
|
|
@ -137,15 +161,15 @@ class ContractEdge(Edge):
|
|||
log_transaction(self, source, target, "success")
|
||||
# If the target vertex is a power component we log messages
|
||||
if target.vertex_type == "ChatOutput" and (
|
||||
isinstance(target.params.get("input_value"), str)
|
||||
or isinstance(target.params.get("input_value"), dict)
|
||||
isinstance(target.params.get(INPUT_FIELD_NAME), str)
|
||||
or isinstance(target.params.get(INPUT_FIELD_NAME), dict)
|
||||
):
|
||||
if target.params.get("message") == "":
|
||||
return self.result
|
||||
await log_message(
|
||||
sender=target.params.get("sender", ""),
|
||||
sender_name=target.params.get("sender_name", ""),
|
||||
message=target.params.get("input_value", {}),
|
||||
message=target.params.get(INPUT_FIELD_NAME, {}),
|
||||
session_id=target.params.get("session_id", ""),
|
||||
artifacts=target.artifacts,
|
||||
)
|
||||
|
|
@ -155,7 +179,9 @@ class ContractEdge(Edge):
|
|||
return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
|
||||
|
||||
|
||||
def log_transaction(edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None):
|
||||
def log_transaction(
|
||||
edge: ContractEdge, source: "Vertex", target: "Vertex", status, error=None
|
||||
):
|
||||
try:
|
||||
monitor_service = get_monitor_service()
|
||||
clean_params = build_clean_params(target)
|
||||
|
|
|
|||
|
|
@ -33,8 +33,6 @@ class Graph:
|
|||
edges: List[Dict[str, str]],
|
||||
flow_id: Optional[str] = None,
|
||||
) -> None:
|
||||
self.inputs = []
|
||||
self.outputs = []
|
||||
self._vertices = nodes
|
||||
self._edges = edges
|
||||
self.raw_graph_data = {"nodes": nodes, "edges": edges}
|
||||
|
|
@ -77,7 +75,7 @@ class Graph:
|
|||
|
||||
async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
for vertex_id in self.inputs:
|
||||
for vertex_id in self._is_input_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
|
|
@ -89,7 +87,7 @@ class Graph:
|
|||
logger.exception(exc)
|
||||
raise ValueError(f"Error running graph: {exc}") from exc
|
||||
outputs = []
|
||||
for vertex_id in self.outputs:
|
||||
for vertex_id in self._is_output_vertices:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
|
|
@ -104,11 +102,11 @@ class Graph:
|
|||
# of the vertices that are inputs
|
||||
# if the value is a list, we need to run multiple times
|
||||
outputs = []
|
||||
inputs_values = inputs.get("input_value")
|
||||
inputs_values = inputs.get(INPUT_FIELD_NAME)
|
||||
if not isinstance(inputs_values, list):
|
||||
inputs_values = [inputs_values]
|
||||
for input_value in inputs_values:
|
||||
run_outputs = await self._run({"input_value": input_value})
|
||||
run_outputs = await self._run({INPUT_FIELD_NAME: input_value})
|
||||
logger.debug(f"Run outputs: {run_outputs}")
|
||||
outputs.extend(run_outputs)
|
||||
return outputs
|
||||
|
|
@ -317,28 +315,6 @@ class Graph:
|
|||
# Now that we have the vertices and edges
|
||||
# We need to map the vertices that are connected to
|
||||
# to ChatVertex instances
|
||||
self._map_chat_vertices()
|
||||
|
||||
def _map_chat_vertices(self) -> None:
|
||||
"""Maps the vertices that are connected to ChatVertex instances."""
|
||||
# For each edge, we need to check if the source or target vertex is a ChatVertex
|
||||
# If it is, we need to update the other vertex `is_external` attribute
|
||||
# and store the id of the ChatVertex in the attributes self.inputs and self.outputs
|
||||
for edge in self.edges:
|
||||
source_vertex = self.get_vertex(edge.source_id)
|
||||
target_vertex = self.get_vertex(edge.target_id)
|
||||
if isinstance(source_vertex, ChatVertex):
|
||||
# The source vertex is a ChatVertex
|
||||
# thus the target vertex is an external vertex
|
||||
# and the source vertex is an input
|
||||
target_vertex.has_external_input = True
|
||||
self.inputs.append(source_vertex.id)
|
||||
if isinstance(target_vertex, ChatVertex):
|
||||
# The target vertex is a ChatVertex
|
||||
# thus the source vertex is an external vertex
|
||||
# and the target vertex is an output
|
||||
source_vertex.has_external_output = True
|
||||
self.outputs.append(target_vertex.id)
|
||||
|
||||
def remove_vertex(self, vertex_id: str) -> None:
|
||||
"""Removes a vertex from the graph."""
|
||||
|
|
@ -443,13 +419,15 @@ class Graph:
|
|||
async def _execute_tasks(self, tasks):
|
||||
"""Executes tasks in parallel, handling exceptions for each task."""
|
||||
results = []
|
||||
for task in asyncio.as_completed(tasks):
|
||||
for i, task in enumerate(asyncio.as_completed(tasks)):
|
||||
try:
|
||||
result = await task
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
# Log the exception along with the task name for easier debugging
|
||||
task_name = task.get_name()
|
||||
# task_name = task.get_name()
|
||||
# coroutine has not attribute get_name
|
||||
task_name = tasks[i].get_name()
|
||||
logger.error(f"Task {task_name} failed with exception: {e}")
|
||||
return results
|
||||
|
||||
|
|
|
|||
|
|
@ -35,3 +35,5 @@ OUTPUT_COMPONENTS = [
|
|||
InterfaceComponentTypes.ChatOutput,
|
||||
InterfaceComponentTypes.TextOutput,
|
||||
]
|
||||
|
||||
INPUT_FIELD_NAME = "input_value"
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class Vertex:
|
|||
) -> None:
|
||||
# is_external means that the Vertex send or receives data from
|
||||
# an external source (e.g the chat)
|
||||
|
||||
self.updated_raw_params = False
|
||||
self.id: str = data["id"]
|
||||
self.is_input = any(
|
||||
input_component_name in self.id for input_component_name in INPUT_COMPONENTS
|
||||
|
|
@ -285,6 +285,10 @@ class Vertex:
|
|||
if self.graph is None:
|
||||
raise ValueError("Graph not found")
|
||||
|
||||
if self.updated_raw_params:
|
||||
self.updated_raw_params = False
|
||||
return
|
||||
|
||||
template_dict = {
|
||||
key: value
|
||||
for key, value in self.data["node"]["template"].items()
|
||||
|
|
@ -386,10 +390,11 @@ class Vertex:
|
|||
Raises:
|
||||
ValueError: If any key in new_params is not found in self._raw_params.
|
||||
"""
|
||||
for key in new_params:
|
||||
if key not in self._raw_params:
|
||||
raise ValueError(f"Key {key} not found in raw params")
|
||||
# First check if the input_value in _raw_params is not a vertex
|
||||
if any(isinstance(self._raw_params.get(key), Vertex) for key in new_params):
|
||||
return
|
||||
self._raw_params.update(new_params)
|
||||
self.updated_raw_params = True
|
||||
|
||||
async def _build(self, user_id=None):
|
||||
"""
|
||||
|
|
@ -451,6 +456,8 @@ class Vertex:
|
|||
await self._build_node_and_update_params(key, value, user_id)
|
||||
elif isinstance(value, list) and self._is_list_of_nodes(value):
|
||||
await self._build_list_of_nodes_and_update_params(key, value, user_id)
|
||||
elif key not in self.params:
|
||||
self.params[key] = value
|
||||
|
||||
def _is_node(self, value):
|
||||
"""
|
||||
|
|
@ -586,7 +593,7 @@ class Vertex:
|
|||
|
||||
logger.warning(message)
|
||||
|
||||
def _reset(self):
|
||||
def _reset(self, params_update: Optional[Dict[str, Any]] = None):
|
||||
self._built = False
|
||||
self._built_object = UnbuiltObject()
|
||||
self._built_result = UnbuiltResult()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue