Refactor vertex class and update build process

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-27 11:37:44 -03:00
commit 04de488ede
2 changed files with 51 additions and 23 deletions

View file

@ -6,7 +6,12 @@ from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional
from loguru import logger
from langflow.graph.schema import InterfaceComponentTypes
from langflow.graph.schema import (
INPUT_COMPONENTS,
OUTPUT_COMPONENTS,
InterfaceComponentTypes,
ResultData,
)
from langflow.graph.utils import UnbuiltObject, UnbuiltResult
from langflow.graph.vertex.utils import generate_result
from langflow.interface.initialize import loading
@ -16,7 +21,6 @@ from langflow.utils.constants import DIRECT_TYPES
from langflow.utils.util import sync_to_async
if TYPE_CHECKING:
from langflow.api.v1.schemas import ResultData
from langflow.graph.edge.base import ContractEdge
from langflow.graph.graph.base import Graph
@ -40,11 +44,19 @@ class Vertex:
) -> None:
# is_external means that the Vertex send or receives data from
# an external source (e.g the chat)
self.id: str = data["id"]
self.is_input = any(
input_component_name in self.id for input_component_name in INPUT_COMPONENTS
)
self.is_output = any(
output_component_name in self.id
for output_component_name in OUTPUT_COMPONENTS
)
self._custom_component = None
self.has_external_input = False
self.has_external_output = False
self.graph = graph
self.id: str = data["id"]
self._data = data
self.base_type: Optional[str] = base_type
self._parse_data()
@ -61,7 +73,7 @@ class Vertex:
self.parent_is_top_level = False
self.layer = None
self.should_run = True
self.result: Optional["ResultData"] = None
self.result: Optional[ResultData] = None
try:
self.is_interface_component = InterfaceComponentTypes(self.vertex_type)
except ValueError:
@ -116,7 +128,7 @@ class Vertex:
)
return edge_results
def set_result(self, result: "ResultData") -> None:
def set_result(self, result: ResultData) -> None:
self.result = result
def get_built_result(self):
@ -203,6 +215,8 @@ class Vertex:
self.display_name = self.data["node"]["display_name"]
self.pinned = self.data["node"].get("pinned", False)
self.selected_output_type = self.data["node"].get("selected_output_type")
self.is_input = self.data["node"].get("is_input") or self.is_input
self.is_output = self.data["node"].get("is_output") or self.is_output
template_dicts = {
key: value
for key, value in self.data["node"]["template"].items()
@ -359,6 +373,21 @@ class Vertex:
self.params = params
self._raw_params = params.copy()
def update_raw_params(self, new_params: Dict[str, str]):
"""
Update the raw parameters of the vertex with the given new parameters.
Args:
new_params (Dict[str, Any]): The new parameters to update.
Raises:
ValueError: If any key in new_params is not found in self._raw_params.
"""
for key in new_params:
if key not in self._raw_params:
raise ValueError(f"Key {key} not found in raw params")
self._raw_params.update(new_params)
async def _build(self, user_id=None):
"""
Initiate the build process.
@ -370,6 +399,18 @@ class Vertex:
self._built = True
def _finalize_build(self):
result_dict = self.get_built_result()
# We need to set the artifacts to pass information
# to the frontend
self.set_artifacts()
artifacts = self.artifacts
result_dict = ResultData(
results=result_dict,
artifacts=artifacts,
)
self.set_result(result_dict)
async def _run(
self,
user_id: str,
@ -501,17 +542,13 @@ class Vertex:
if self.base_type is None:
raise ValueError(f"Base type for node {self.display_name} not found")
try:
outgoing_edges = self.graph.get_vertex_edges(
self.id, is_source=True, is_target=False
)
result = await loading.instantiate_class(
node_type=self.vertex_type,
base_type=self.base_type,
params=self.params,
user_id=user_id,
outgoing_edges=outgoing_edges,
selected_output_type=self.selected_output_type,
vertex=self,
)
self._update_built_object_and_artifacts(result)
except Exception as exc:
@ -584,6 +621,8 @@ class Vertex:
step(user_id=user_id, **kwargs)
self.steps_ran.append(step)
self._finalize_build()
return await self.get_requester_result(requester)
async def get_requester_result(self, requester: Optional["Vertex"]):

View file

@ -1,10 +1,10 @@
import ast
import json
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
from typing import (AsyncIterator, Callable, Dict, Iterator, List, Optional,
Union)
import yaml
from langchain_core.messages import AIMessage
from loguru import logger
from langflow.graph.utils import UnbuiltObject, flatten_list
from langflow.graph.vertex.base import StatefulVertex, StatelessVertex
@ -344,17 +344,6 @@ class ChatVertex(StatelessVertex):
def build_stream_url(self):
return f"/api/v1/build/{self.graph.flow_id}/{self.id}/stream"
async def _build(self, user_id=None):
"""
Initiate the build process.
"""
logger.debug(f"Building {self.vertex_type}")
await self._build_each_node_in_params_dict(user_id)
await self._get_and_instantiate_class(user_id)
self._validate_built_object()
self._built = True
def _built_object_repr(self):
if self.task_id and self.is_task:
if task := self.get_task():