diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index 4af5f65b8..c734ad2e2 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -1,21 +1,27 @@ import inspect -from typing import Any, Callable, ClassVar, List, Optional, Union, get_type_hints +from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Union, get_type_hints from uuid import UUID import nanoid # type: ignore import yaml from pydantic import BaseModel +from langflow.graph.edge.schema import EdgeData from langflow.helpers.custom import format_type from langflow.inputs.inputs import InputTypes from langflow.schema.artifact import get_artifact_type, post_process_raw from langflow.schema.data import Data from langflow.schema.message import Message from langflow.services.tracing.schema import Log -from langflow.template.field.base import UNDEFINED, Output +from langflow.template.field.base import UNDEFINED, Input, Output +from langflow.template.frontend_node.custom_components import ComponentFrontendNode +from langflow.utils.async_helpers import run_until_complete from .custom_component import CustomComponent +if TYPE_CHECKING: + from langflow.graph.vertex.base import Vertex + BACKWARDS_COMPATIBLE_ATTRIBUTES = ["user_id", "vertex", "tracing_service"] @@ -40,6 +46,7 @@ class Component(CustomComponent): self._results: dict[str, Any] = {} self._attributes: dict[str, Any] = {} self._parameters = inputs or {} + self._edges: list[EdgeData] = [] self._components: list[Component] = [] self.set_attributes(self._parameters) self._output_logs = {} @@ -55,47 +62,68 @@ class Component(CustomComponent): self.map_inputs(self.inputs) if self.outputs is not None: self.map_outputs(self.outputs) + # Set output types self._set_output_types() - def __getattr__(self, name: str) -> Any: - if "_attributes" in self.__dict__ and name in self.__dict__["_attributes"]: - return self.__dict__["_attributes"][name] - if "_inputs" in self.__dict__ and name in self.__dict__["_inputs"]: - return self.__dict__["_inputs"][name].value - if name in BACKWARDS_COMPATIBLE_ATTRIBUTES: - return self.__dict__[f"_{name}"] - raise AttributeError(f"{name} not found in {self.__class__.__name__}") - - def map_inputs(self, inputs: List[InputTypes]): - self.inputs = inputs - for input_ in inputs: - if input_.name is None: - raise ValueError("Input name cannot be None.") - self._inputs[input_.name] = input_ - - def map_outputs(self, outputs: List[Output]): + def set(self, **kwargs): """ - Maps the given list of outputs to the component. + Connects the component to other components or sets parameters and attributes. + Args: - outputs (List[Output]): The list of outputs to be mapped. + **kwargs: Keyword arguments representing the connections, parameters, and attributes. + + Returns: + None + Raises: - ValueError: If the output name is None. + KeyError: If the specified input name does not exist. + """ + for key, value in kwargs.items(): + self._process_connection_or_parameter(key, value) + + def list_inputs(self): + """ + Returns a list of input names. + """ + return [_input.name for _input in self.inputs] + + def list_outputs(self): + """ + Returns a list of output names. + """ + return [_output.name for _output in self.outputs] + + async def run(self): + """ + Executes the component's logic and returns the result. + + Returns: + The result of executing the component's logic. + """ + return await self._run() + + def set_vertex(self, vertex: "Vertex"): + """ + Sets the vertex for the component. + + Args: + vertex (Vertex): The vertex to set. + Returns: None """ - self.outputs = outputs - for output in outputs: - if output.name is None: - raise ValueError("Output name cannot be None.") - self._outputs[output.name] = output + self._vertex = vertex def get_input(self, name: str) -> Any: """ Retrieves the value of the input with the specified name. + Args: name (str): The name of the input. + Returns: Any: The value of the input. + Raises: ValueError: If the input with the specified name is not found. """ @@ -106,10 +134,13 @@ class Component(CustomComponent): def get_output(self, name: str) -> Any: """ Retrieves the output with the specified name. + Args: name (str): The name of the output to retrieve. + Returns: Any: The output value. + Raises: ValueError: If the output with the specified name is not found. """ @@ -123,7 +154,53 @@ class Component(CustomComponent): else: raise ValueError(f"Output {name} not found in {self.__class__.__name__}") + def map_outputs(self, outputs: List[Output]): + """ + Maps the given list of outputs to the component. + + Args: + outputs (List[Output]): The list of outputs to be mapped. + + Raises: + ValueError: If the output name is None. + + Returns: + None + """ + self.outputs = outputs + for output in outputs: + if output.name is None: + raise ValueError("Output name cannot be None.") + self._outputs[output.name] = output + + def map_inputs(self, inputs: List[InputTypes]): + """ + Maps the given inputs to the component. + + Args: + inputs (List[InputTypes]): A list of InputTypes objects representing the inputs. + + Raises: + ValueError: If the input name is None. + + """ + self.inputs = inputs + for input_ in inputs: + if input_.name is None: + raise ValueError("Input name cannot be None.") + self._inputs[input_.name] = input_ + def validate(self, params: dict): + """ + Validates the component parameters. + + Args: + params (dict): A dictionary containing the component parameters. + + Raises: + ValueError: If the inputs are not valid. + ValueError: If the outputs are not valid. + """ self._validate_inputs(params) self._validate_outputs() @@ -133,12 +210,6 @@ class Component(CustomComponent): output.add_types(return_types) output.set_selected() - def _get_method_return_type(self, method_name: str) -> List[str]: - method = getattr(self, method_name) - return_type = get_type_hints(method)["return"] - extracted_return_types = self._extract_return_type(return_type) - return [format_type(extracted_return_type) for extracted_return_type in extracted_return_types] - def _get_output_by_method(self, method: Callable): # method is a callable and output.method is a string # we need to find the output that has the same method @@ -148,10 +219,145 @@ class Component(CustomComponent): raise ValueError(f"Output with method {method_name} not found") return output + def _process_connection_or_parameter(self, key, value): + _input = self._get_or_create_input(key) + if callable(value): + self._connect_to_component(key, value, _input) + else: + self._set_parameter_or_attribute(key, value) + + def _get_or_create_input(self, key): + try: + return self._inputs[key] + except KeyError: + _input = self._get_fallback_input(name=key, display_name=key) + self._inputs[key] = _input + self.inputs.append(_input) + return _input + + def _connect_to_component(self, key, value, _input): + component = value.__self__ + self._components.append(component) + output = component._get_output_by_method(value) + self._add_edge(component, key, output, _input) + + def _add_edge(self, component, key, output, _input): + self._edges.append( + { + "source": component._id, + "target": self._id, + "data": { + "sourceHandle": { + "dataType": self.name, + "id": component._id, + "name": output.name, + "output_types": output.types, + }, + "targetHandle": { + "fieldName": key, + "id": self._id, + "inputTypes": _input.input_types, + "type": _input.field_type, + }, + }, + } + ) + + def _set_parameter_or_attribute(self, key, value): + self._parameters[key] = value + self._attributes[key] = value + + def __call__(self, **kwargs): + self.set(**kwargs) + + return run_until_complete(self.run()) + + async def _run(self): + # Resolve callable inputs + for key, _input in self._inputs.items(): + if callable(_input.value): + result = _input.value() + if inspect.iscoroutine(result): + result = await result + self._inputs[key].value = result + + self.set_attributes({}) + + return await self.build_results() + + def __getattr__(self, name: str) -> Any: + if "_attributes" in self.__dict__ and name in self.__dict__["_attributes"]: + return self.__dict__["_attributes"][name] + if "_inputs" in self.__dict__ and name in self.__dict__["_inputs"]: + return self.__dict__["_inputs"][name].value + if name in BACKWARDS_COMPATIBLE_ATTRIBUTES: + return self.__dict__[f"_{name}"] + raise AttributeError(f"{name} not found in {self.__class__.__name__}") + + def _set_input_value(self, name: str, value: Any): + if name in self._inputs: + input_value = self._inputs[name].value + if callable(input_value): + raise ValueError( + f"Input {name} is connected to {input_value.__self__.display_name}.{input_value.__name__}" + ) + self._inputs[name].value = value + self._attributes[name] = value + else: + raise ValueError(f"Input {name} not found in {self.__class__.__name__}") + def _validate_outputs(self): # Raise Error if some rule isn't met pass + def _map_parameters_on_frontend_node(self, frontend_node: ComponentFrontendNode): + for name, value in self._parameters.items(): + frontend_node.set_field_value_in_template(name, value) + + def _map_parameters_on_template(self, template: dict): + for name, value in self._parameters.items(): + template[name]["value"] = value + + def _get_method_return_type(self, method_name: str) -> List[str]: + method = getattr(self, method_name) + return_type = get_type_hints(method)["return"] + extracted_return_types = self._extract_return_type(return_type) + return [format_type(extracted_return_type) for extracted_return_type in extracted_return_types] + + def _update_template(self, frontend_node: dict): + return frontend_node + + def to_frontend_node(self): + #! This part here is clunky but we need it like this for + #! backwards compatibility. We can change how prompt component + #! works and then update this later + field_config = self.get_template_config(self) + frontend_node = ComponentFrontendNode.from_inputs(**field_config) + self._map_parameters_on_frontend_node(frontend_node) + + frontend_node_dict = frontend_node.to_dict(keep_name=False) + frontend_node_dict = self._update_template(frontend_node_dict) + self._map_parameters_on_template(frontend_node_dict["template"]) + + frontend_node = ComponentFrontendNode.from_dict(frontend_node_dict) + + for output in frontend_node.outputs: + if output.types: + continue + return_types = self._get_method_return_type(output.method) + output.add_types(return_types) + output.set_selected() + + frontend_node.validate_component() + frontend_node.set_base_classes_from_outputs() + data = { + "data": { + "node": frontend_node.to_dict(keep_name=False), + "type": self.__class__.__name__, + } + } + return data + def _validate_inputs(self, params: dict): # Params keys are the `name` attribute of the Input objects for key, value in params.copy().items(): @@ -159,6 +365,7 @@ class Component(CustomComponent): continue input_ = self._inputs[key] # BaseInputMixin has a `validate_assignment=True` + input_.value = value params[input_.name] = input_.value @@ -166,7 +373,7 @@ class Component(CustomComponent): self._validate_inputs(params) _attributes = {} for key, value in params.items(): - if key in self.__dict__: + if key in self.__dict__ and value != getattr(self, key): raise ValueError( f"{self.__class__.__name__} defines an input parameter named '{key}' " f"that is a reserved word and cannot be used." @@ -220,11 +427,16 @@ class Component(CustomComponent): _results = {} _artifacts = {} if hasattr(self, "outputs"): - self._set_outputs(self._vertex.outputs) + if self._vertex: + self._set_outputs(self._vertex.outputs) for output in self.outputs: # Build the output if it's connected to some other vertex # or if it's not connected to any vertex - if not self._vertex.outgoing_edges or output.name in self._vertex.edges_source_names: + if ( + not self._vertex + or not self._vertex.outgoing_edges + or output.name in self._vertex.edges_source_names + ): if output.method is None: raise ValueError(f"Output {output.name} does not have a method defined.") method: Callable = getattr(self, output.method) @@ -236,7 +448,8 @@ class Component(CustomComponent): if inspect.iscoroutinefunction(method): result = await result if ( - isinstance(result, Message) + self._vertex is not None + and isinstance(result, Message) and result.flow_id is None and self._vertex.graph.flow_id is not None ): @@ -314,3 +527,6 @@ class Component(CustomComponent): def build(self, **kwargs): self.set_attributes(kwargs) + + def _get_fallback_input(self, **kwargs): + return Input(**kwargs) diff --git a/src/backend/base/langflow/inputs/inputs.py b/src/backend/base/langflow/inputs/inputs.py index bc244ec3d..87a0314f5 100644 --- a/src/backend/base/langflow/inputs/inputs.py +++ b/src/backend/base/langflow/inputs/inputs.py @@ -6,6 +6,7 @@ from pydantic import Field, field_validator from langflow.inputs.validators import CoalesceBool from langflow.schema.data import Data from langflow.schema.message import Message +from langflow.template.field.base import Input from .input_mixin import ( BaseInputMixin, @@ -57,7 +58,7 @@ class HandleInput(BaseInputMixin, ListableInputMixin, MetadataTraceMixin): field_type: SerializableFieldTypes = FieldTypes.OTHER -class DataInput(HandleInput, InputTraceMixin): +class DataInput(HandleInput, InputTraceMixin, ListableInputMixin): """ Represents an Input that has a Handle that receives a Data object. @@ -459,7 +460,23 @@ class FileInput(BaseInputMixin, ListableInputMixin, FileMixin, MetadataTraceMixi field_type: SerializableFieldTypes = FieldTypes.FILE +DEFAULT_PROMPT_INTUT_TYPES = ["Message", "Text"] + + +class DefaultPromptField(Input): + name: str + display_name: Optional[str] = None + field_type: str = "str" + + advanced: bool = False + multiline: bool = True + input_types: list[str] = DEFAULT_PROMPT_INTUT_TYPES + value: str = "" # Set the value to empty string + + InputTypes = Union[ + Input, + DefaultPromptField, BoolInput, DataInput, DictInput, @@ -485,6 +502,9 @@ InputTypesMap: dict[str, type[InputTypes]] = {t.__name__: t for t in get_args(In def _instantiate_input(input_type: str, data: dict) -> InputTypes: input_type_class = InputTypesMap.get(input_type) + if "type" in data: + # Replate with field_type + data["field_type"] = data.pop("type") if input_type_class: return input_type_class(**data) else: diff --git a/src/backend/base/langflow/template/frontend_node/base.py b/src/backend/base/langflow/template/frontend_node/base.py index 9aeb0b97d..199233ee9 100644 --- a/src/backend/base/langflow/template/frontend_node/base.py +++ b/src/backend/base/langflow/template/frontend_node/base.py @@ -89,6 +89,12 @@ class FrontendNode(BaseModel): return {name: result} + @classmethod + def from_dict(cls, data: dict) -> "FrontendNode": + if "template" in data: + data["template"] = Template.from_dict(data["template"]) + return cls(**data) + # For backwards compatibility def to_dict(self, keep_name=True) -> dict: """Returns a dict representation of the frontend node.""" @@ -173,3 +179,9 @@ class FrontendNode(BaseModel): template = Template(type_name="Component", fields=inputs) kwargs["template"] = template return cls(**kwargs) + + def set_field_value_in_template(self, field_name, value): + for field in self.template.fields: + if field.name == field_name: + field.value = value + break diff --git a/src/backend/base/langflow/template/template/base.py b/src/backend/base/langflow/template/template/base.py index 7c0c7fa0f..d0372b0b6 100644 --- a/src/backend/base/langflow/template/template/base.py +++ b/src/backend/base/langflow/template/template/base.py @@ -2,7 +2,7 @@ from typing import Callable, Union, cast from pydantic import BaseModel, Field, model_serializer -from langflow.inputs.inputs import InputTypes +from langflow.inputs.inputs import InputTypes, _instantiate_input from langflow.template.field.base import Input from langflow.utils.constants import DIRECT_TYPES @@ -35,6 +35,28 @@ class Template(BaseModel): return result + @classmethod + def from_dict(cls, data: dict) -> "Template": + for key, value in data.copy().items(): + if key == "_type": + data["type_name"] = value + del data[key] + else: + value["name"] = key + if "fields" not in data: + data["fields"] = [] + input_type = value.pop("_input_type", None) + if input_type: + try: + _input = _instantiate_input(input_type, value) + except Exception as e: + raise ValueError(f"Error instantiating input {input_type}: {e}") + else: + _input = Input(**value) + + data["fields"].append(_input) + return cls(**data) + # For backwards compatibility def to_dict(self, format_field_func=None): self.process_fields(format_field_func) diff --git a/src/frontend/tests/end-to-end/userSettings.spec.ts b/src/frontend/tests/end-to-end/userSettings.spec.ts index 14c827732..302c2f908 100644 --- a/src/frontend/tests/end-to-end/userSettings.spec.ts +++ b/src/frontend/tests/end-to-end/userSettings.spec.ts @@ -65,7 +65,7 @@ test("should interact with global variables", async ({ page }) => { await page.keyboard.press("Escape"); await page.getByText("Save Variable", { exact: true }).click(); - await page.getByText(randomName).isVisible(); + await page.getByText(randomName).last().isVisible(); const focusElementsOnBoard = async ({ page }) => { await page.waitForSelector(