Refactor vertex class and update build process
This commit is contained in:
parent
9beadd70f1
commit
04de488ede
2 changed files with 51 additions and 23 deletions
|
|
@ -6,7 +6,12 @@ from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, List, Optional
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from langflow.graph.schema import InterfaceComponentTypes
|
||||
from langflow.graph.schema import (
|
||||
INPUT_COMPONENTS,
|
||||
OUTPUT_COMPONENTS,
|
||||
InterfaceComponentTypes,
|
||||
ResultData,
|
||||
)
|
||||
from langflow.graph.utils import UnbuiltObject, UnbuiltResult
|
||||
from langflow.graph.vertex.utils import generate_result
|
||||
from langflow.interface.initialize import loading
|
||||
|
|
@ -16,7 +21,6 @@ from langflow.utils.constants import DIRECT_TYPES
|
|||
from langflow.utils.util import sync_to_async
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langflow.api.v1.schemas import ResultData
|
||||
from langflow.graph.edge.base import ContractEdge
|
||||
from langflow.graph.graph.base import Graph
|
||||
|
||||
|
|
@ -40,11 +44,19 @@ class Vertex:
|
|||
) -> None:
|
||||
# is_external means that the Vertex send or receives data from
|
||||
# an external source (e.g the chat)
|
||||
|
||||
self.id: str = data["id"]
|
||||
self.is_input = any(
|
||||
input_component_name in self.id for input_component_name in INPUT_COMPONENTS
|
||||
)
|
||||
self.is_output = any(
|
||||
output_component_name in self.id
|
||||
for output_component_name in OUTPUT_COMPONENTS
|
||||
)
|
||||
self._custom_component = None
|
||||
self.has_external_input = False
|
||||
self.has_external_output = False
|
||||
self.graph = graph
|
||||
self.id: str = data["id"]
|
||||
self._data = data
|
||||
self.base_type: Optional[str] = base_type
|
||||
self._parse_data()
|
||||
|
|
@ -61,7 +73,7 @@ class Vertex:
|
|||
self.parent_is_top_level = False
|
||||
self.layer = None
|
||||
self.should_run = True
|
||||
self.result: Optional["ResultData"] = None
|
||||
self.result: Optional[ResultData] = None
|
||||
try:
|
||||
self.is_interface_component = InterfaceComponentTypes(self.vertex_type)
|
||||
except ValueError:
|
||||
|
|
@ -116,7 +128,7 @@ class Vertex:
|
|||
)
|
||||
return edge_results
|
||||
|
||||
def set_result(self, result: "ResultData") -> None:
|
||||
def set_result(self, result: ResultData) -> None:
|
||||
self.result = result
|
||||
|
||||
def get_built_result(self):
|
||||
|
|
@ -203,6 +215,8 @@ class Vertex:
|
|||
self.display_name = self.data["node"]["display_name"]
|
||||
self.pinned = self.data["node"].get("pinned", False)
|
||||
self.selected_output_type = self.data["node"].get("selected_output_type")
|
||||
self.is_input = self.data["node"].get("is_input") or self.is_input
|
||||
self.is_output = self.data["node"].get("is_output") or self.is_output
|
||||
template_dicts = {
|
||||
key: value
|
||||
for key, value in self.data["node"]["template"].items()
|
||||
|
|
@ -359,6 +373,21 @@ class Vertex:
|
|||
self.params = params
|
||||
self._raw_params = params.copy()
|
||||
|
||||
def update_raw_params(self, new_params: Dict[str, str]):
|
||||
"""
|
||||
Update the raw parameters of the vertex with the given new parameters.
|
||||
|
||||
Args:
|
||||
new_params (Dict[str, Any]): The new parameters to update.
|
||||
|
||||
Raises:
|
||||
ValueError: If any key in new_params is not found in self._raw_params.
|
||||
"""
|
||||
for key in new_params:
|
||||
if key not in self._raw_params:
|
||||
raise ValueError(f"Key {key} not found in raw params")
|
||||
self._raw_params.update(new_params)
|
||||
|
||||
async def _build(self, user_id=None):
|
||||
"""
|
||||
Initiate the build process.
|
||||
|
|
@ -370,6 +399,18 @@ class Vertex:
|
|||
|
||||
self._built = True
|
||||
|
||||
def _finalize_build(self):
|
||||
result_dict = self.get_built_result()
|
||||
# We need to set the artifacts to pass information
|
||||
# to the frontend
|
||||
self.set_artifacts()
|
||||
artifacts = self.artifacts
|
||||
result_dict = ResultData(
|
||||
results=result_dict,
|
||||
artifacts=artifacts,
|
||||
)
|
||||
self.set_result(result_dict)
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
user_id: str,
|
||||
|
|
@ -501,17 +542,13 @@ class Vertex:
|
|||
if self.base_type is None:
|
||||
raise ValueError(f"Base type for node {self.display_name} not found")
|
||||
try:
|
||||
outgoing_edges = self.graph.get_vertex_edges(
|
||||
self.id, is_source=True, is_target=False
|
||||
)
|
||||
|
||||
result = await loading.instantiate_class(
|
||||
node_type=self.vertex_type,
|
||||
base_type=self.base_type,
|
||||
params=self.params,
|
||||
user_id=user_id,
|
||||
outgoing_edges=outgoing_edges,
|
||||
selected_output_type=self.selected_output_type,
|
||||
vertex=self,
|
||||
)
|
||||
self._update_built_object_and_artifacts(result)
|
||||
except Exception as exc:
|
||||
|
|
@ -584,6 +621,8 @@ class Vertex:
|
|||
step(user_id=user_id, **kwargs)
|
||||
self.steps_ran.append(step)
|
||||
|
||||
self._finalize_build()
|
||||
|
||||
return await self.get_requester_result(requester)
|
||||
|
||||
async def get_requester_result(self, requester: Optional["Vertex"]):
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
import ast
|
||||
import json
|
||||
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
|
||||
from typing import (AsyncIterator, Callable, Dict, Iterator, List, Optional,
|
||||
Union)
|
||||
|
||||
import yaml
|
||||
from langchain_core.messages import AIMessage
|
||||
from loguru import logger
|
||||
|
||||
from langflow.graph.utils import UnbuiltObject, flatten_list
|
||||
from langflow.graph.vertex.base import StatefulVertex, StatelessVertex
|
||||
|
|
@ -344,17 +344,6 @@ class ChatVertex(StatelessVertex):
|
|||
def build_stream_url(self):
|
||||
return f"/api/v1/build/{self.graph.flow_id}/{self.id}/stream"
|
||||
|
||||
async def _build(self, user_id=None):
|
||||
"""
|
||||
Initiate the build process.
|
||||
"""
|
||||
logger.debug(f"Building {self.vertex_type}")
|
||||
await self._build_each_node_in_params_dict(user_id)
|
||||
await self._get_and_instantiate_class(user_id)
|
||||
self._validate_built_object()
|
||||
|
||||
self._built = True
|
||||
|
||||
def _built_object_repr(self):
|
||||
if self.task_id and self.is_task:
|
||||
if task := self.get_task():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue