feat: add ability to initialize the Graph with components (#3134)
* refactor: update code references to use _code instead of code
* refactor: add backwards compatible attributes to Component class
* refactor: update Component constructor to pass config params with underscore
  Refactored the `Component` class in `component.py` to handle inputs and outputs. Added a new method `map_outputs` to map a list of outputs to the component. Also updated the `__init__` method to properly initialize the inputs, outputs, and other attributes. This change improves the flexibility and extensibility of the `Component` class.
  Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
* refactor: change attribute to use underscore
* refactor: update CustomComponent initialization parameters
  Refactored the `instantiate_class` function in `loading.py` to update the initialization parameters for the `CustomComponent` class. Changed the parameter names from `user_id`, `parameters`, `vertex`, and `tracing_service` to `_user_id`, `_parameters`, `_vertex`, and `_tracing_service` respectively. This change ensures consistency and improves code readability.
  Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
* refactor: update BaseComponent to accept UUID for _user_id
  Updated the `BaseComponent` class in `base_component.py` to accept a `UUID` type for the `_user_id` attribute. This change improves type safety and ensures consistency with the usage of `_user_id` throughout the codebase.
* refactor: import nanoid with type annotation
  The `nanoid` import in `component.py` has been updated to include a `# type: ignore` annotation so that the type checker ignores errors related to the untyped `nanoid` import.
* fix(custom_component.py): convert _user_id to string before passing to functions to ensure compatibility with function signatures
* feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components
* refactor: extract method to get method return type in CustomComponent
* refactor(utils.py): refactor code to use _user_id instead of user_id for consistency and clarity
* perf(utils.py): optimize code by reusing cc_instance instead of calling get_component_instance multiple times
* refactor(utils.py, base.py): change parameter name 'add_name' to 'keep_name' for clarity and consistency in codebase
* [autofix.ci] apply automated fixes
* refactor: update schema.py to include Edge related types
  The `schema.py` file in the `src/backend/base/langflow/graph/edge` directory has been updated to include the `TargetHandle` and `SourceHandle` models. These models define the structure and attributes of the target and source handles used in the edge data. This change improves the clarity and consistency of the codebase.
* refactor: update BaseInputMixin to handle invalid field types gracefully
  The `BaseInputMixin` class in `input_mixin.py` has been updated to handle invalid field types gracefully. Instead of raising an exception, it now returns `FieldTypes.OTHER` for any invalid field type. This change improves the robustness and reliability of the codebase.
* refactor: update file_types field alias in FileMixin
  The `file_types` field in the `FileMixin` class of `input_mixin.py` has been updated to use the `alias` parameter instead of `serialization_alias`. This change ensures consistency and improves the clarity of the codebase.
* refactor(inputs): update field_type declarations in various input classes to use SerializableFieldTypes enum for better type safety and clarity
* refactor(inputs): convert dict to Message object in _validate_value method
* refactor(inputs): convert dict to Message object in _validate_value method
* refactor(inputs): update model_config in BaseInputMixin to enable populating by name
  The `model_config` attribute in the `BaseInputMixin` class of `input_mixin.py` has been updated to include the `populate_by_name=True` parameter. This change allows the model configuration to be populated by name, improving the flexibility and usability of the codebase.
* refactor: update _extract_return_type method in CustomComponent to accept Any type
  The `_extract_return_type` method in `CustomComponent` has been updated to accept the `Any` type as the `return_type` parameter. This change improves the flexibility and compatibility of the method, allowing it to handle a wider range of return types.
* refactor(component): add get_input and get_output methods for easier access to input and output values
  The `Component` class in `component.py` has been updated to include the `get_input` and `get_output` methods. These methods allow for easier retrieval of input and output values by name, improving the usability and readability of the codebase.
* refactor(vertex): add get_input and get_output methods for easier access to input and output values
* refactor(component): add set_output_value method for easier modification of output values
  The `Component` class in `component.py` has been updated to include the `set_output_value` method. This method allows for easier modification of output values by name, improving the usability and flexibility of the codebase.
* feat: add run_until_complete and run_in_thread functions for handling asyncio tasks
  The `async_helpers.py` file in the `src/backend/base/langflow/utils` directory has been added. This file includes the `run_until_complete` and `run_in_thread` functions, which provide a way to handle asyncio tasks in different scenarios. The `run_until_complete` function checks if an event loop is already running and either runs the coroutine in a separate event loop in a new thread or creates a new event loop and runs the coroutine. The `run_in_thread` function runs the coroutine in a separate thread and returns the result or raises an exception if one occurs. These functions improve the flexibility and usability of the codebase.
* refactor(component): add _edges attribute to Component class for managing edges
  The `Component` class in `component.py` has been updated to include the `_edges` attribute. This attribute is a list of `EdgeData` objects and is used for managing edges in the component. This change improves the functionality and organization of the codebase.
* fix(component.py): fix conditional statement to check if self._vertex is not None before accessing its attributes
* refactor(component): add _get_fallback_input method for handling fallback input
  The `Component` class in `component.py` has been updated to include the `_get_fallback_input` method. This method returns an `Input` object with the provided keyword arguments, which is used as a fallback input when needed. This change improves the flexibility and readability of the codebase.
* refactor(component): add TYPE_CHECKING import for Vertex in component.py
* refactor(component): add _map_parameters_on_frontend_node and _map_parameters_on_template and other methods
  The `Component` class in `component.py` has been refactored to include the `_map_parameters_on_frontend_node` and `_map_parameters_on_template` methods. These methods are responsible for mapping the parameters of the component onto the frontend node and template, respectively. This change improves the organization and maintainability of the codebase.
* refactor(component): Add map_inputs and map_outputs methods for mapping inputs and outputs
  The `Component` class in `component.py` has been updated to include the `map_inputs` and `map_outputs` methods. These methods allow for mapping the given inputs and outputs to the component, improving the functionality and organization of the codebase.
* refactor(component): Add Input, Output, and ComponentFrontendNode imports and run_until_complete function
  This commit refactors the `component.py` file in the `src/backend/base/langflow/custom/custom_component` directory. It adds the `Input`, `Output`, and `ComponentFrontendNode` imports, as well as the `run_until_complete` function from the `async_helpers.py` file. These changes improve the functionality and organization of the codebase.
* refactor(component): Add map_inputs and map_outputs methods for mapping inputs and outputs
* refactor(component): Add _process_connection_or_parameter method for handling connections and parameters
  The `Component` class in `component.py` has been updated to include the `_process_connection_or_parameter` method. This method is responsible for handling connections and parameters based on the provided key and value. It checks if the value is callable and connects it to the component, otherwise it sets the parameter or attribute. This change improves the functionality and organization of the codebase.
* refactor(frontend_node): Add set_field_value_in_template method for updating field values
  The `FrontendNode` class in `base.py` has been updated to include the `set_field_value_in_template` method. This method allows for updating the value of a specific field in the template of the frontend node. It iterates through the fields and sets the value of the field with the provided name. This change improves the flexibility and functionality of the codebase.
* refactor(inputs): Add DefaultPromptField class for default prompt inputs
  The `inputs.py` file in the `src/backend/base/langflow/inputs` directory has been refactored to include the `DefaultPromptField` class. This class represents a default prompt input with customizable properties such as name, display name, field type, advanced flag, multiline flag, input types, and value. This change improves the flexibility and functionality of the codebase.
* feat: Add Template.from_dict method for creating Template objects from dictionaries
  This commit adds the `from_dict` class method to the `Template` class in `base.py`. This method allows for creating `Template` objects from dictionaries by converting the dictionary keys and values into the appropriate `Template` attributes. This change improves the flexibility and functionality of the codebase.
* refactor(frontend_node): Add from_dict method for creating FrontendNode objects from dictionaries
* refactor: update BaseComponent to use get_template_config method
  Refactored the `BaseComponent` class in `base_component.py` to use the `get_template_config` method instead of duplicating the code. This change improves code readability and reduces redundancy.
* refactor(graph): Add EdgeData import and update add_nodes_and_edges method signature
  The `Graph` class in `base.py` has been updated to include the `EdgeData` import and modify the signature of the `add_nodes_and_edges` method. The `add_nodes_and_edges` method now accepts a list of dictionaries representing `EdgeData` objects instead of a list of dictionaries with string keys and values. This change improves the type safety and clarity of the codebase.
* refactor(graph): Add first_layer property to Graph class
  The `Graph` class in `base.py` has been updated to include the `first_layer` property. This property returns the first layer of the graph and throws a `ValueError` if the graph is not prepared. This change improves the functionality and organization of the codebase.
* refactor(graph): Update Graph class instantiation in base.py
  The `Graph` class in `base.py` has been updated to use keyword arguments when instantiating the class. This change improves the readability and maintainability of the codebase.
* refactor(graph): Add prepare method to Graph class
  The `Graph` class in `base.py` has been updated to include the `prepare` method. This method prepares the graph for execution by validating the stream, building edges, and sorting vertices. It also adds the first layer of vertices to the run manager and sets the run queue. This change improves the functionality and organization of the codebase.
* refactor(graph): Improve graph preparation in retrieve_vertices_order function
  The `retrieve_vertices_order` function in `chat.py` has been updated to improve the graph preparation process. Instead of manually sorting vertices and adding them to the run manager, the function now calls the `prepare` method of the `Graph` class. This method validates the stream, builds edges, and sets the first layer of vertices. This change improves the functionality and organization of the codebase.
* refactor: Add GetCache and SetCache protocols for caching functionality
* refactor(graph): Add VertexBuildResult class for representing vertex build results
* refactor(chat.py, base.py): update build_vertex method in chat.py and base.py
* refactor(graph): Update Edge and ContractEdge constructors to use EdgeData type
  The constructors of the `Edge` and `ContractEdge` classes in `base.py` have been updated to use the `EdgeData` type for the `edge` and `raw_edge` parameters, respectively. This change improves the type safety and clarity of the codebase.
* feat: add BaseModel class with model_config attribute
  A new `BaseModel` class has been added to the `base_model.py` file. This class extends the `PydanticBaseModel` and includes a `model_config` attribute of type `ConfigDict`. This change improves the codebase by providing a base model with a configuration dictionary for models.
  Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
* refactor: update langflow.graph.edge.schema.py
  Refactor the `langflow.graph.edge.schema.py` file to include the `TargetHandle` and `SourceHandle` models. This change improves the clarity and consistency of the codebase.
  Co-authored-by: Gabriel Luiz Freitas Almeida <gabriel@langflow.org>
* refactor(base): Update target_param assignment in Edge class
  The `target_param` assignment in the `Edge` class of `base.py` has been updated to use the `cast` function for type hinting. This change improves the type safety and clarity of the codebase.
* refactor(base): Add check for existing type in add_types method
* refactor: update build_custom_component_template to use add_name instead of keep_name
  Refactor the `build_custom_component_template` function in `utils.py` to use the `add_name` parameter instead of the deprecated `keep_name` parameter. This change ensures consistency with the updated method signature and improves code clarity.
* feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components (#3115)
* feat(component.py): add method to set output types based on method return type to improve type checking and validation in custom components
* refactor: extract method to get method return type in CustomComponent
* refactor: update _extract_return_type method in CustomComponent to accept Any type
  The `_extract_return_type` method in `CustomComponent` has been updated to accept the `Any` type as the `return_type` parameter. This change improves the flexibility and compatibility of the method, allowing it to handle a wider range of return types.
* refactor: add _template_config property to BaseComponent
  Add a new `_template_config` property to the `BaseComponent` class in `base_component.py`. This property is used to store the template configuration for the custom component. If the `_template_config` property is empty, it is populated by calling the `build_template_config` method. This change improves the efficiency of accessing the template configuration and ensures that it is only built when needed.
* refactor: add type checking for Output types in add_types method
  Improve type checking in the `add_types` method of the `Output` class in `base.py`. Check if the `type_` already exists in the `types` list before adding it. This change ensures that duplicate types are not added to the list.
* update starter projects
* refactor: optimize imports in base.py
  Optimize imports in the `base.py` file by removing unused imports and organizing the remaining imports. This change improves code readability and reduces unnecessary clutter.
* fix(base.py): fix condition to check if self.types is not None before checking if type_ is in self.types
* refactor: update build_custom_component_template to use add_name instead of keep_name
* refactor(prompts): Update PromptComponent to support custom fields and template updates
  The `PromptComponent` class in `Prompt.py` has been updated to support custom fields and template updates. The `_update_template` method has been added to update the prompt template with custom fields. The `post_code_processing` method has been modified to update the template and improve backwards compatibility. The `_get_fallback_input` method has been added to provide a default prompt field. These changes improve the functionality and flexibility of the codebase.
* refactor(base): Add DefaultPromptField to langflow.io
  The `DefaultPromptField` class has been added to the `langflow.io` module. This class provides a default prompt field for the `TableInput` class. This change improves the functionality and flexibility of the codebase.
* refactor(prompts): Update PromptComponent to support custom fields and template updates
* refactor(base): Update langflow.template.field.prompt.py for backwards compatibility
* refactor(base): Update langflow.components.__init__.py to import the prompts module
  This change adds the prompts module to the list of imports in the __init__.py file of the langflow.components package. This ensures that the prompts module is available for use in the codebase.
* refactor(base): Update langflow.template.field.prompt.py for backwards compatibility
* refactor(graph): Update VertexTypesDict to import vertex types lazily
  The VertexTypesDict class in constants.py has been updated to import vertex types lazily. This change improves the performance of the codebase by deferring the import until it is actually needed.
* refactor(graph): Add missing attributes and lock to Graph class
  The Graph class in base.py has been updated to add missing attributes and a lock. This change ensures that the necessary attributes are initialized and provides thread safety with the addition of a lock. It improves the functionality and reliability of the codebase.
* refactor(graph): Add method to set inputs in Graph class
  The `_set_inputs` method has been added to the Graph class in base.py. This method allows for setting inputs for specific vertices based on input components, inputs, and input type. It improves the functionality and flexibility of the codebase.
* refactor(graph): Set inputs for specific vertices in Graph class
  The `_set_inputs` method has been added to the Graph class in base.py. This method allows for setting inputs for specific vertices based on input components, inputs, and input type. It improves the functionality and flexibility of the codebase.
* refactor(graph): Update Graph class to set cache using flow_id
  The `Graph` class in `base.py` has been updated to set the cache using the `flow_id` attribute. This change ensures that the cache is properly set when `cache` is enabled and `flow_id` is not None. It improves the functionality and reliability of the codebase.
* refactor(graph): Refactor Graph class to improve edge building
  The `Graph` class in `base.py` has been refactored to improve the process of building edges. The `build_edge` method has been added to encapsulate the logic of creating a `ContractEdge` object from the given `EdgeData`. This change enhances the readability and maintainability of the codebase.
* refactor(graph): Update _create_vertex method parameter name for clarity
  The `_create_vertex` method in the `Graph` class of `base.py` has been updated to change the parameter name from `vertex` to `frontend_data` for improved clarity. This change enhances the readability and maintainability of the codebase.
* refactor(graph): Update Graph class to return first layer in sort_interface_components_first
  The `sort_interface_components_first` method in the `Graph` class of `base.py` has been updated to return just the first layer of vertices. This change improves the functionality of the codebase by providing a more focused and efficient sorting mechanism.
* refactor(graph): Update Graph class to use get_vertex method for building vertices
  The `_build_vertices` method in the Graph class of base.py has been updated to use the `get_vertex` method instead of creating a new vertex instance. This change improves the efficiency and maintainability of the codebase.
* refactor(graph): Update Graph class to use astep method for asynchronous execution
  The `Graph` class in `base.py` has been updated to use the `astep` method for asynchronous execution of vertices. This change improves the efficiency and maintainability of the codebase by leveraging asyncio and allowing for concurrent execution of vertices.
* feat(base.py): implement methods to add components and component edges in the Graph class
* refactor(graph): Import nest_asyncio for asynchronous execution in Graph class
* refactor(base.py): Update Vertex class to handle parameter dictionaries in build_params
  The `build_params` method in the `Vertex` class of `base.py` has been updated to handle parameter dictionaries correctly. If the `param_dict` is empty or has more than one key, the method now sets the parameter value to the vertex that is the source of the edge. Otherwise, it sets the parameter value to a dictionary with keys corresponding to the keys in the `param_dict` and values as the vertex that is the source of the edge. This change improves the functionality and maintainability of the codebase.
* refactor(base.py): Add methods to set input values and add component instances in Vertex class
  The `Vertex` class in `base.py` has been refactored to include two new methods: `set_input_value` and `add_component_instance`. The `set_input_value` method allows setting input values for a vertex by name, while the `add_component_instance` method adds a component instance to the vertex. These changes enhance the functionality and maintainability of the codebase.
* refactor(message.py): Update _timestamp_to_str to handle datetime or str input
  The `_timestamp_to_str` function in `message.py` has been updated to handle both `datetime` and `str` input. If the input is a `datetime` object, it will be formatted as a string using the "%Y-%m-%d %H:%M:%S" format. If the input is already a string, it will be returned as is. This change improves the flexibility and usability of the function.
* refactor(test_base.py): Add unit tests for Graph class
  Unit tests have been added to the `test_base.py` file to test the functionality of the `Graph` class. These tests ensure that the graph is prepared correctly, components are added and connected properly, and the graph executes as expected. This change improves the reliability and maintainability of the codebase.
* refactor(initialize/loading.py): Refactor get_instance_results function
  The `get_instance_results` function in `initialize/loading.py` has been refactored to simplify the logic for building custom components and components. The previous implementation had separate checks for `CustomComponent` and `Component` types, but the refactored version combines these checks into a single condition based on the `base_type` parameter. This change improves the readability and maintainability of the codebase.
* refactor(component.py): Add set_input_value method to Component class
  The `set_input_value` method has been added to the `Component` class in `component.py`. This method allows setting the value of an input by name, and also updates the `load_from_db` attribute if applicable. This change enhances the functionality and maintainability of the codebase.
* refactor(component.py): Set input value in _set_parameter_or_attribute method
  The `_set_parameter_or_attribute` method in the `Component` class now sets the input value using the `set_input_value` method. This change improves the clarity and consistency of the codebase.
* refactor(inputs.py): Improve error message for invalid value type
  The `SecretStrInput` class in `inputs.py` has been updated to improve the error message when an invalid value type is encountered. Instead of a generic error message, the new message includes the specific value type and the name of the input. This change enhances the clarity and usability of the codebase.
* feat: Add unit test for memory chatbot functionality
* refactor(base.py): Update __repr__ method in ContractEdge class
  The `__repr__` method in the `ContractEdge` class of `base.py` has been updated to include the source handle and target handle information when available. This change improves the readability and clarity of the representation of the edge in the codebase.
* refactor(component.py): Update set method to return self
  The `set` method in the `Component` class of `component.py` has been updated to return `self` after processing the connection or parameter. This change improves the chaining capability of the method and enhances the readability and consistency of the codebase.
* refactor(starter_projects): Add unit test for vector store RAG
  A unit test has been added to the `test_vector_store_rag.py` file in the `starter_projects` directory. This test ensures that the vector store RAG graph is set up correctly and produces the expected results. This change improves the reliability and maintainability of the codebase.
* refactor: remove unused prepare method in Graph class
* refactor: update Output class to use list[str] for types field
* refactor: add name validation to BaseInputMixin
* refactor: update ContractEdge __repr__ method for improved readability and consistency
* refactor: update BaseInputMixin to ensure name field is required with appropriate description
* refactor: remove name validation from BaseInputMixin
* refactor: update input tests to include 'name' field in all input types for better validation and clarity
* refactor: enhance Component class with methods to validate callable outputs and inheritance checks for better robustness
* refactor: disable load_from_db for inputs in Component class to improve input handling logic and prevent unwanted database loading
* refactor: add test for setting invalid output in test_component.py

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent dd3a92b196
commit 86ca23397e
15 changed files with 601 additions and 103 deletions
@@ -80,6 +80,7 @@ class Component(CustomComponent):
         """
         for key, value in kwargs.items():
             self._process_connection_or_parameter(key, value)
+        return self
 
     def list_inputs(self):
         """
@@ -219,9 +220,32 @@ class Component(CustomComponent):
             raise ValueError(f"Output with method {method_name} not found")
         return output
 
+    def _inherits_from_component(self, method: Callable):
+        # check if the method is a method from a class that inherits from Component
+        # and that it is an output of that class
+        inherits_from_component = hasattr(method, "__self__") and isinstance(method.__self__, Component)
+        return inherits_from_component
+
+    def _method_is_valid_output(self, method: Callable):
+        # check if the method is a method from a class that inherits from Component
+        # and that it is an output of that class
+        method_is_output = (
+            hasattr(method, "__self__")
+            and isinstance(method.__self__, Component)
+            and method.__self__._get_output_by_method(method)
+        )
+        return method_is_output
+
     def _process_connection_or_parameter(self, key, value):
         _input = self._get_or_create_input(key)
-        if callable(value):
+        # We need to check if callable AND if it is a method from a class that inherits from Component
+        if callable(value) and self._inherits_from_component(value):
+            try:
+                self._method_is_valid_output(value)
+            except ValueError:
+                raise ValueError(
+                    f"Method {value.__name__} is not a valid output of {value.__self__.__class__.__name__}"
+                )
             self._connect_to_component(key, value, _input)
         else:
             self._set_parameter_or_attribute(key, value)
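The connect-vs-set dispatch in `_process_connection_or_parameter` above can be modeled in isolation. The following is a self-contained toy sketch (simplified stand-in class, not the real langflow `Component`) of treating a bound method of another component as a connection and everything else as a plain parameter, with `set()` returning `self` for chaining:

```python
class Component:
    """Toy model of the connect-vs-set dispatch shown in the diff above."""

    def __init__(self, _id: str):
        self._id = _id
        self._parameters: dict = {}
        # Each edge records (source component id, output method name, input name).
        self._edges: list = []

    def set(self, **kwargs) -> "Component":
        for key, value in kwargs.items():
            if callable(value) and isinstance(getattr(value, "__self__", None), Component):
                # A bound method of another Component: record an edge, not a value.
                self._edges.append((value.__self__._id, value.__name__, key))
            else:
                self._parameters[key] = value
        return self  # returning self enables fluent chaining

    def build_output(self) -> str:
        return f"output of {self._id}"


source = Component("source")
sink = Component("sink").set(text=source.build_output, temperature=0.5)
print(sink._edges)       # [('source', 'build_output', 'text')]
print(sink._parameters)  # {'temperature': 0.5}
```

The `getattr(value, "__self__", None)` check mirrors the `hasattr(method, "__self__")` test in `_inherits_from_component`: only bound methods of component instances become graph edges.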
@@ -264,6 +288,7 @@ class Component(CustomComponent):
         )
 
     def _set_parameter_or_attribute(self, key, value):
+        self._set_input_value(key, value)
         self._parameters[key] = value
         self._attributes[key] = value
 
@@ -302,7 +327,8 @@ class Component(CustomComponent):
                 f"Input {name} is connected to {input_value.__self__.display_name}.{input_value.__name__}"
             )
             self._inputs[name].value = value
+            self._attributes[name] = value
             if hasattr(self._inputs[name], "load_from_db"):
                 self._inputs[name].load_from_db = False
         else:
             raise ValueError(f"Input {name} not found in {self.__class__.__name__}")
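The input-setting behavior shown above (an explicitly provided value also disables `load_from_db`, so it is not overwritten by a stored value) can be sketched with toy stand-ins for the real input objects; class names here are hypothetical:

```python
class FakeInput:
    """Stand-in for a langflow input object."""

    def __init__(self, name: str):
        self.name = name
        self.value = None
        self.load_from_db = True  # would normally trigger a database lookup


class MiniComponent:
    """Toy model of set_input_value: explicit values win over stored ones."""

    def __init__(self, inputs):
        self._inputs = {i.name: i for i in inputs}
        self._attributes: dict = {}

    def set_input_value(self, name: str, value):
        if name not in self._inputs:
            raise ValueError(f"Input {name} not found in {self.__class__.__name__}")
        self._inputs[name].value = value
        self._attributes[name] = value
        if hasattr(self._inputs[name], "load_from_db"):
            # An explicitly set value should not be replaced from the DB.
            self._inputs[name].load_from_db = False


comp = MiniComponent([FakeInput("api_key")])
comp.set_input_value("api_key", "sk-test")
print(comp._inputs["api_key"].load_from_db)  # False
```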
@@ -227,4 +227,8 @@ class ContractEdge(Edge):
         return self.result
 
     def __repr__(self) -> str:
+        if (hasattr(self, "source_handle") and self.source_handle) and (
+            hasattr(self, "target_handle") and self.target_handle
+        ):
+            return f"{self.source_id} -[{self.source_handle.name}->{self.target_handle.fieldName}]-> {self.target_id}"
         return f"{self.source_id} -[{self.target_param}]-> {self.target_id}"
@@ -6,19 +6,20 @@ from functools import partial
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Type, Union
 
 import nest_asyncio
 from loguru import logger
 
 from langflow.exceptions.component import ComponentBuildException
 from langflow.graph.edge.base import ContractEdge
+from langflow.graph.edge.schema import EdgeData
-from langflow.graph.graph.constants import lazy_load_vertex_dict
+from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict
 from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager
+from langflow.graph.graph.schema import VertexBuildResult
 from langflow.graph.graph.state_manager import GraphStateManager
 from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex
 from langflow.graph.schema import InterfaceComponentTypes, RunOutputs
 from langflow.graph.vertex.base import Vertex, VertexStates
-from langflow.graph.vertex.types import InterfaceVertex, StateVertex
+from langflow.graph.vertex.types import ComponentVertex, InterfaceVertex, StateVertex
 from langflow.schema import Data
 from langflow.schema.schema import INPUT_FIELD_NAME, InputType
 from langflow.services.cache.utils import CacheMiss
@@ -26,6 +27,8 @@ from langflow.services.chat.schema import GetCache, SetCache
 from langflow.services.deps import get_chat_service, get_tracing_service
 
 if TYPE_CHECKING:
+    from langflow.api.v1.schemas import InputValueRequest
+    from langflow.custom.custom_component.component import Component
     from langflow.graph.schema import ResultData
     from langflow.services.tracing.service import TracingService
 
@@ -35,6 +38,8 @@ class Graph:
 
     def __init__(
         self,
+        start: Optional["Component"] = None,
+        end: Optional["Component"] = None,
         flow_id: Optional[str] = None,
         flow_name: Optional[str] = None,
         user_id: Optional[str] = None,
@@ -47,6 +52,7 @@ class Graph:
             edges (List[Dict[str, str]]): A list of dictionaries representing the edges of the graph.
             flow_id (Optional[str], optional): The ID of the flow. Defaults to None.
         """
+        self._prepared = False
         self._runs = 0
         self._updates = 0
         self.flow_id = flow_id
@@ -69,12 +75,27 @@ class Graph:
         self.vertices: List[Vertex] = []
         self.run_manager = RunnableVerticesManager()
         self.state_manager = GraphStateManager()
         self._vertices: List[dict] = []
         self._edges: List[EdgeData] = []
         self.top_level_vertices: List[str] = []
         self.vertex_map: Dict[str, Vertex] = {}
         self.predecessor_map: Dict[str, List[str]] = defaultdict(list)
         self.successor_map: Dict[str, List[str]] = defaultdict(list)
         self.in_degree_map: Dict[str, int] = defaultdict(int)
         self.parent_child_map: Dict[str, List[str]] = defaultdict(list)
         self._run_queue: deque[str] = deque()
         self._first_layer: List[str] = []
         self._lock = asyncio.Lock()
         try:
             self.tracing_service: "TracingService" | None = get_tracing_service()
         except Exception as exc:
             logger.error(f"Error getting tracing service: {exc}")
             self.tracing_service = None
+        if start is not None and end is not None:
+            self._set_start_and_end(start, end)
+            self.prepare()
+        if (start is not None and end is None) or (start is None and end is not None):
+            raise ValueError("You must provide both input and output components")
 
     def add_nodes_and_edges(self, nodes: List[Dict], edges: List[EdgeData]):
         self._vertices = nodes
@@ -90,11 +111,111 @@ class Graph:
        self._edges = self._graph_data["edges"]
        self.initialize()

    def add_component(self, _id: str, component: "Component"):
        if _id in self.vertex_map:
            return
        frontend_node = component.to_frontend_node()
        frontend_node["data"]["id"] = _id
        frontend_node["id"] = _id
        self._vertices.append(frontend_node)
        vertex = self._create_vertex(frontend_node)
        vertex.add_component_instance(component)
        self.vertices.append(vertex)
        self.vertex_map[_id] = vertex

        if component._edges:
            for edge in component._edges:
                self._add_edge(edge)

        if component._components:
            for _component in component._components:
                self.add_component(_component._id, _component)

    def _set_start_and_end(self, start: "Component", end: "Component"):
        if not hasattr(start, "to_frontend_node"):
            raise TypeError(f"start must be a Component. Got {type(start)}")
        if not hasattr(end, "to_frontend_node"):
            raise TypeError(f"end must be a Component. Got {type(end)}")
        self.add_component(start._id, start)
        self.add_component(end._id, end)
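The registration logic in `add_component` above is idempotent (a component already in `vertex_map` is skipped) and recursive (sub-components are registered through the same path). A minimal self-contained sketch of that pattern, using stand-in classes rather than langflow's real `Component`/`Graph` types:

```python
class MockComponent:
    # Stand-in for a Component: an id plus optional nested sub-components.
    def __init__(self, _id, components=()):
        self._id = _id
        self._components = list(components)


class MockGraph:
    def __init__(self):
        self.vertex_map = {}

    def add_component(self, _id, component):
        if _id in self.vertex_map:
            return  # idempotent: registering the same id twice is a no-op
        self.vertex_map[_id] = component
        for sub in component._components:
            self.add_component(sub._id, sub)  # recurse into sub-components


leaf = MockComponent("leaf")
root = MockComponent("root", components=[leaf])
graph = MockGraph()
graph.add_component("root", root)
graph.add_component("root", root)  # second call changes nothing
print(sorted(graph.vertex_map))  # ['leaf', 'root']
```

The `vertex_map` check is what lets `_set_start_and_end` register `start` and `end` without worrying about shared sub-components being added twice.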
    def add_component_edge(self, source_id: str, output_input_tuple: Tuple[str, str], target_id: str):
        source_vertex = self.get_vertex(source_id)
        if not isinstance(source_vertex, ComponentVertex):
            raise ValueError(f"Source vertex {source_id} is not a component vertex.")
        target_vertex = self.get_vertex(target_id)
        if not isinstance(target_vertex, ComponentVertex):
            raise ValueError(f"Target vertex {target_id} is not a component vertex.")
        output_name, input_name = output_input_tuple
        edge_data: EdgeData = {
            "source": source_id,
            "target": target_id,
            "data": {
                "sourceHandle": {
                    "dataType": source_vertex.base_name,
                    "id": source_vertex.id,
                    "name": output_name,
                    "output_types": source_vertex.get_output(output_name).types,
                },
                "targetHandle": {
                    "fieldName": input_name,
                    "id": target_vertex.id,
                    "inputTypes": target_vertex.get_input(input_name).input_types,
                    "type": str(target_vertex.get_input(input_name).field_type),
                },
            },
        }
        self._add_edge(edge_data)
    async def async_start(self, inputs: Optional[List[dict]] = None):
        if not self._prepared:
            raise ValueError("Graph not prepared. Call prepare() first.")
        # The idea is for this to return a generator that yields the result of
        # each step call and raise StopIteration when the graph is done
        for _input in inputs or []:
            for key, value in _input.items():
                vertex = self.get_vertex(key)
                vertex.set_input_value(key, value)
        while True:
            result = await self.astep()
            yield result
            if isinstance(result, Finish):
                return

    def start(self, inputs: Optional[List[dict]] = None) -> Generator:
        #! Change this soon
        nest_asyncio.apply()
        loop = asyncio.get_event_loop()
        async_gen = self.async_start(inputs)
        async_gen_task = asyncio.ensure_future(async_gen.__anext__())

        while True:
            try:
                result = loop.run_until_complete(async_gen_task)
                yield result
                if isinstance(result, Finish):
                    return
                async_gen_task = asyncio.ensure_future(async_gen.__anext__())
            except StopAsyncIteration:
                break
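`start()` above is a synchronous wrapper that pumps the `async_start` generator one item at a time through the event loop. A minimal self-contained sketch of the same pattern (no langflow or `nest_asyncio` required; the names here are illustrative):

```python
import asyncio


async def agen():
    # A toy async generator standing in for Graph.async_start():
    # it yields step results and then finishes.
    for step in ["a", "b", "done"]:
        yield step


def drive(gen):
    # Pull items from an async generator from synchronous code by running
    # each __anext__() coroutine to completion on a private event loop.
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(gen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.close()


results = list(drive(agen()))
print(results)  # ['a', 'b', 'done']
```

The real implementation additionally applies `nest_asyncio` so this works inside an already-running loop (e.g. a notebook), which a plain `run_until_complete` would reject.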
    def _add_edge(self, edge: EdgeData):
        self.add_edge(edge)
        source_id = edge["data"]["sourceHandle"]["id"]
        target_id = edge["data"]["targetHandle"]["id"]
        self.predecessor_map[target_id].append(source_id)
        self.successor_map[source_id].append(target_id)
        self.in_degree_map[target_id] += 1
        self.parent_child_map[source_id].append(target_id)

    # TODO: Create a TypedDict to represent the node
    def add_node(self, node: dict):
        self._vertices.append(node)

    def add_edge(self, edge: EdgeData):
        # Check if the edge already exists
        if edge in self._edges:
            return
        self._edges.append(edge)
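The predecessor/successor/in-degree bookkeeping that `_add_edge` maintains is exactly what a Kahn-style topological traversal needs: seed a run queue with zero-in-degree vertices, and release each successor once its last predecessor has run. A self-contained sketch with hypothetical vertex ids:

```python
from collections import defaultdict, deque

# Hypothetical edges, mirroring the maps built up in Graph._add_edge.
edges = [("chat_input", "prompt"), ("memory", "prompt"), ("prompt", "chat_output")]

successor_map = defaultdict(list)
in_degree_map = defaultdict(int)
nodes = set()
for source, target in edges:
    successor_map[source].append(target)
    in_degree_map[target] += 1
    nodes.update((source, target))

# Kahn's algorithm: start from vertices with no unmet dependencies.
run_queue = deque(sorted(n for n in nodes if in_degree_map[n] == 0))
order = []
while run_queue:
    vertex = run_queue.popleft()
    order.append(vertex)
    for succ in successor_map[vertex]:
        in_degree_map[succ] -= 1
        if in_degree_map[succ] == 0:
            run_queue.append(succ)

print(order)  # ['chat_input', 'memory', 'prompt', 'chat_output']
```

In the real graph the queue is consumed incrementally by `astep()` rather than drained in one loop, but the dependency accounting is the same.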
    def initialize(self):

@@ -303,6 +424,20 @@ class Graph:
            if getattr(vertex, attribute):
                getattr(self, f"_{attribute}_vertices").append(vertex.id)

    def _set_inputs(self, input_components: list[str], inputs: Dict[str, str], input_type: InputType | None):
        for vertex_id in self._is_input_vertices:
            vertex = self.get_vertex(vertex_id)
            # If the vertex is not in the input_components list
            if input_components and (vertex_id not in input_components and vertex.display_name not in input_components):
                continue
            # If the input_type is not any and the input_type is not in the vertex id
            # Example: input_type = "chat" and vertex.id = "OpenAI-19ddn"
            elif input_type is not None and input_type != "any" and input_type not in vertex.id.lower():
                continue
            if vertex is None:
                raise ValueError(f"Vertex {vertex_id} not found")
            vertex.update_raw_params(inputs, overwrite=True)
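`_set_inputs` applies two filters before writing inputs into a vertex: an explicit allow-list matched against the id or display name, then a case-insensitive substring match of `input_type` against the vertex id. The selection logic in isolation, with stand-in vertex data (ids and names here are illustrative):

```python
# Stand-ins: (vertex_id, display_name) pairs for the input vertices.
vertices = [
    ("ChatInput-abc12", "Chat Input"),
    ("TextInput-def34", "Text Input"),
    ("ChatInput-ghi56", "Other Chat"),
]


def select_input_vertices(vertices, input_components, input_type):
    # Mirror of the two filters in Graph._set_inputs.
    selected = []
    for vertex_id, display_name in vertices:
        # Filter 1: explicit allow-list of ids or display names.
        if input_components and vertex_id not in input_components and display_name not in input_components:
            continue
        # Filter 2: input_type must appear in the lowercased vertex id.
        if input_type is not None and input_type != "any" and input_type not in vertex_id.lower():
            continue
        selected.append(vertex_id)
    return selected


print(select_input_vertices(vertices, [], "chat"))
# ['ChatInput-abc12', 'ChatInput-ghi56']
```

An empty `input_components` list means "no allow-list", and `input_type="any"` (or `None`) disables the substring filter, so both filters are opt-in.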
    async def _run(
        self,
        inputs: Dict[str, str],

@@ -335,20 +470,7 @@ class Graph:
        if not isinstance(inputs.get(INPUT_FIELD_NAME, ""), str):
            raise ValueError(f"Invalid input value: {inputs.get(INPUT_FIELD_NAME)}. Expected string")
        if inputs:
-            for vertex_id in self._is_input_vertices:
-                vertex = self.get_vertex(vertex_id)
-                # If the vertex is not in the input_components list
-                if input_components and (
-                    vertex_id not in input_components and vertex.display_name not in input_components
-                ):
-                    continue
-                # If the input_type is not any and the input_type is not in the vertex id
-                # Example: input_type = "chat" and vertex.id = "OpenAI-19ddn"
-                elif input_type is not None and input_type != "any" and input_type not in vertex.id.lower():
-                    continue
-                if vertex is None:
-                    raise ValueError(f"Vertex {vertex_id} not found")
-                vertex.update_raw_params(inputs, overwrite=True)
+            self._set_inputs(input_components, inputs, input_type)
        # Update all the vertices with the session_id
        for vertex_id in self._has_session_id_vertices:
            vertex = self.get_vertex(vertex_id)
@@ -857,6 +979,50 @@ class Graph:
                return vertex
        raise ValueError(f"Vertex {vertex_id} is not a top level vertex or no root vertex found")

    async def astep(
        self,
        inputs: Optional["InputValueRequest"] = None,
        files: Optional[list[str]] = None,
        user_id: Optional[str] = None,
    ):
        if not self._prepared:
            raise ValueError("Graph not prepared. Call prepare() first.")
        if not self._run_queue:
            asyncio.create_task(self.end_all_traces())
            return Finish()
        vertex_id = self._run_queue.popleft()
        chat_service = get_chat_service()
        vertex_build_result = await self.build_vertex(
            vertex_id=vertex_id,
            user_id=user_id,
            inputs_dict=inputs.model_dump() if inputs else {},
            files=files,
            get_cache=chat_service.get_cache,
            set_cache=chat_service.set_cache,
        )

        next_runnable_vertices = await self.get_next_runnable_vertices(
            self._lock, vertex=vertex_build_result.vertex, cache=False
        )
        if self.stop_vertex and self.stop_vertex in next_runnable_vertices:
            next_runnable_vertices = [self.stop_vertex]
        self._run_queue.extend(next_runnable_vertices)
        self.reset_inactivated_vertices()
        self.reset_activated_vertices()

        await chat_service.set_cache(str(self.flow_id or self._run_id), self)
        return vertex_build_result

    def step(
        self,
        inputs: Optional["InputValueRequest"] = None,
        files: Optional[list[str]] = None,
        user_id: Optional[str] = None,
    ):
        # Call astep but synchronously
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self.astep(inputs, files, user_id))

    async def build_vertex(
        self,
        vertex_id: str,
@@ -1037,9 +1203,9 @@ class Graph:
                    next_runnable_vertices.remove(v_id)
                else:
                    self.run_manager.add_to_vertices_being_run(next_v_id)
-        if cache and self.flow_id:
-            set_cache_coro = partial(get_chat_service().set_cache, self.flow_id)
-            await set_cache_coro(self, lock)
+        if cache and self.flow_id is not None:
+            set_cache_coro = partial(get_chat_service().set_cache, key=self.flow_id)
+            await set_cache_coro(data=self, lock=lock)
        return next_runnable_vertices

    async def _execute_tasks(self, tasks: List[asyncio.Task], lock: asyncio.Lock) -> List[str]:
@@ -1185,19 +1351,22 @@ class Graph:

        edges: set[ContractEdge] = set()
        for edge in self._edges:
-            source = self.get_vertex(edge["source"])
-            target = self.get_vertex(edge["target"])
-
-            if source is None:
-                raise ValueError(f"Source vertex {edge['source']} not found")
-            if target is None:
-                raise ValueError(f"Target vertex {edge['target']} not found")
-            new_edge = ContractEdge(source, target, edge)
+            new_edge = self.build_edge(edge)
            edges.add(new_edge)

        return list(edges)

    def build_edge(self, edge: EdgeData) -> ContractEdge:
        source = self.get_vertex(edge["source"])
        target = self.get_vertex(edge["target"])

        if source is None:
            raise ValueError(f"Source vertex {edge['source']} not found")
        if target is None:
            raise ValueError(f"Target vertex {edge['target']} not found")
        new_edge = ContractEdge(source, target, edge)
        return new_edge

    def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
        """Returns the node class based on the node type."""
        # First we check for the node_base_type
@@ -1222,14 +1391,17 @@ class Graph:
    def _build_vertices(self) -> List[Vertex]:
        """Builds the vertices of the graph."""
        vertices: List[Vertex] = []
-        for vertex in self._vertices:
-            vertex_instance = self._create_vertex(vertex)
+        for frontend_data in self._vertices:
+            try:
+                vertex_instance = self.get_vertex(frontend_data["id"])
+            except ValueError:
+                vertex_instance = self._create_vertex(frontend_data)
            vertices.append(vertex_instance)

        return vertices

-    def _create_vertex(self, vertex: dict):
-        vertex_data = vertex["data"]
+    def _create_vertex(self, frontend_data: dict):
+        vertex_data = frontend_data["data"]
        vertex_type: str = vertex_data["type"]  # type: ignore
        vertex_base_type: str = vertex_data["node"]["template"]["_type"]  # type: ignore
        if "id" not in vertex_data:
@@ -1237,7 +1409,7 @@ class Graph:

        VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])

-        vertex_instance = VertexClass(vertex, graph=self)
+        vertex_instance = VertexClass(frontend_data, graph=self)
        vertex_instance.set_top_level(self.top_level_vertices)
        return vertex_instance
@@ -1456,6 +1628,7 @@ class Graph:
        self.vertices_to_run = {vertex_id for vertex_id in chain.from_iterable(vertices_layers)}
        self.build_run_map()
        # Return just the first layer
        self._first_layer = first_layer
        return first_layer

    def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
@@ -1,11 +1,25 @@
from langflow.graph.schema import CHAT_COMPONENTS
-from langflow.graph.vertex import types
from langflow.utils.lazy_load import LazyLoadDictBase


class Finish:
    def __bool__(self):
        return True

    def __eq__(self, other):
        return isinstance(other, Finish)

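Because `__eq__` matches any other instance, `Finish` works as a value-style sentinel: callers compare a step result against a fresh `Finish()` instead of holding a shared singleton. A self-contained sketch of how consumers use it (the `Finish` body is copied from the diff; `run_steps` is illustrative):

```python
class Finish:
    def __bool__(self):
        return True

    def __eq__(self, other):
        return isinstance(other, Finish)


def run_steps(steps):
    # Consume step results until the Finish sentinel appears,
    # the way the graph tests iterate graph.start().
    results = []
    for result in steps:
        results.append(result)
        if isinstance(result, Finish):
            break
    return results


out = run_steps(["vertex_1", "vertex_2", Finish()])
print(out[-1] == Finish())  # True
```

Defining `__bool__` to return `True` also lets `if result:`-style checks treat the sentinel as a real value rather than falling back to default truthiness.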
|
||||
def _import_vertex_types():
    from langflow.graph.vertex import types

    return types


class VertexTypesDict(LazyLoadDictBase):
    def __init__(self):
        self._all_types_dict = None
        self._types = _import_vertex_types()

    @property
    def VERTEX_TYPE_MAP(self):
@@ -20,13 +34,13 @@ class VertexTypesDict(LazyLoadDictBase):

    def get_type_dict(self):
        return {
-            **{t: types.CustomComponentVertex for t in ["CustomComponent"]},
-            **{t: types.ComponentVertex for t in ["Component"]},
-            **{t: types.InterfaceVertex for t in CHAT_COMPONENTS},
+            **{t: self._types.CustomComponentVertex for t in ["CustomComponent"]},
+            **{t: self._types.ComponentVertex for t in ["Component"]},
+            **{t: self._types.InterfaceVertex for t in CHAT_COMPONENTS},
        }

    def get_custom_component_vertex_type(self):
-        return types.CustomComponentVertex
+        return self._types.CustomComponentVertex


lazy_load_vertex_dict = VertexTypesDict()
@@ -96,6 +96,15 @@ class Vertex:
        self.build_times: List[float] = []
        self.state = VertexStates.ACTIVE

    def set_input_value(self, name: str, value: Any):
        if self._custom_component is None:
            raise ValueError(f"Vertex {self.id} does not have a component instance.")
        self._custom_component._set_input_value(name, value)

    def add_component_instance(self, component_instance: "Component"):
        component_instance.set_vertex(self)
        self._custom_component = component_instance

    def add_result(self, name: str, result: Any):
        self.results[name] = result
@@ -289,12 +298,13 @@ class Vertex:
                # we don't know the key of the dict but we need to set the value
                # to the vertex that is the source of the edge
                param_dict = template_dict[param_key]["value"]
-                if param_dict:
+                if not param_dict or len(param_dict) != 1:
+                    params[param_key] = self.graph.get_vertex(edge.source_id)
+                else:
                    params[param_key] = {
                        key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys()
                    }
-                else:
-                    params[param_key] = self.graph.get_vertex(edge.source_id)

            else:
                params[param_key] = self.graph.get_vertex(edge.source_id)
@@ -40,12 +40,12 @@ class BaseInputMixin(BaseModel, validate_assignment=True):  # type: ignore
    show: bool = True
    """Should the field be shown. Defaults to True."""

-    name: str = Field(description="Name of the field.")
-    """Name of the field. Default is an empty string."""
+    name: Optional[str] = None
+    """Name of the field. Default is an empty string."""

    value: Any = ""
    """The value of the field. Default is an empty string."""

    display_name: Optional[str] = None
    """Display name of the field. Defaults to None."""
@@ -273,7 +273,7 @@ class SecretStrInput(BaseInputMixin, DatabaseLoadMixin):
        elif isinstance(v, (AsyncIterator, Iterator)):
            value = v
        else:
-            raise ValueError(f"Invalid value type {type(v)}")
+            raise ValueError(f"Invalid value type `{type(v)}` for input `{_info.data['name']}`")
        return value
@@ -7,13 +7,13 @@ import orjson
from loguru import logger
from pydantic import PydanticDeprecatedSince20

-from langflow.custom import Component, CustomComponent
from langflow.custom.eval import eval_custom_component_code
from langflow.schema import Data
from langflow.schema.artifact import get_artifact_type, post_process_raw
from langflow.services.deps import get_tracing_service

if TYPE_CHECKING:
+    from langflow.custom import Component, CustomComponent
    from langflow.graph.vertex.base import Vertex

@@ -54,9 +54,9 @@ async def get_instance_results(
    )
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)
-        if base_type == "custom_components" and isinstance(custom_component, CustomComponent):
+        if base_type == "custom_components":
            return await build_custom_component(params=custom_params, custom_component=custom_component)
-        elif base_type == "component" and isinstance(custom_component, Component):
+        elif base_type == "component":
            return await build_component(params=custom_params, custom_component=custom_component)
        else:
            raise ValueError(f"Base type {base_type} not found.")
@@ -23,8 +23,10 @@ from langflow.utils.constants import (
)


-def _timestamp_to_str(timestamp: datetime) -> str:
-    return timestamp.strftime("%Y-%m-%d %H:%M:%S")
+def _timestamp_to_str(timestamp: datetime | str) -> str:
+    if isinstance(timestamp, datetime):
+        return timestamp.strftime("%Y-%m-%d %H:%M:%S")
+    return timestamp
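The widened signature above accepts either a `datetime` or a string that is already formatted, and only formats in the first case. A runnable sketch of the same behavior (standalone copy; the real helper lives in langflow's message schema module):

```python
from datetime import datetime


def timestamp_to_str(timestamp: "datetime | str") -> str:
    # Format datetimes; pass already-formatted strings through unchanged.
    if isinstance(timestamp, datetime):
        return timestamp.strftime("%Y-%m-%d %H:%M:%S")
    return timestamp


a = timestamp_to_str(datetime(2024, 7, 1, 12, 30, 0))
b = timestamp_to_str("2024-07-01 12:30:00")
print(a)  # 2024-07-01 12:30:00
```

This makes the helper idempotent, so a `Message` round-tripped through serialization (where the timestamp is already a string) no longer crashes on `strftime`.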


class Message(Data):

@@ -158,7 +158,7 @@ class Input(BaseModel):


class Output(BaseModel):
-    types: Optional[list[str]] = Field(default=[])
+    types: list[str] = Field(default=[])
    """List of output types for the field."""

    selected: Optional[str] = Field(default=None)
@@ -0,0 +1,16 @@
import pytest

from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.outputs import ChatOutput


@pytest.fixture
def client():
    pass


def test_set_invalid_output():
    chatinput = ChatInput()
    chatoutput = ChatOutput()
    with pytest.raises(ValueError):
        chatoutput.set(input_value=chatinput.build_config)
src/backend/tests/unit/graph/graph/test_base.py (new file, 131 lines)

@@ -0,0 +1,131 @@
from collections import deque

import pytest

from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.components.outputs.TextOutput import TextOutputComponent
from langflow.graph.graph.base import Graph
from langflow.graph.graph.constants import Finish


@pytest.fixture
def client():
    pass


@pytest.mark.asyncio
async def test_graph_not_prepared():
    chat_input = ChatInput()
    chat_output = ChatOutput()
    graph = Graph()
    graph.add_component("chat_input", chat_input)
    graph.add_component("chat_output", chat_output)
    graph.add_component_edge("chat_input", (chat_input.outputs[0].name, chat_input.inputs[0].name), "chat_output")
    with pytest.raises(ValueError):
        await graph.astep()


@pytest.mark.asyncio
async def test_graph():
    chat_input = ChatInput()
    chat_output = ChatOutput()
    graph = Graph()
    graph.add_component("chat_input", chat_input)
    graph.add_component("chat_output", chat_output)
    graph.add_component_edge("chat_input", (chat_input.outputs[0].name, chat_input.inputs[0].name), "chat_output")
    graph.prepare()
    assert graph._run_queue == deque(["chat_input"])
    await graph.astep()
    assert graph._run_queue == deque(["chat_output"])

    assert graph.vertices[0].id == "chat_input"
    assert graph.vertices[1].id == "chat_output"
    assert graph.edges[0].source_id == "chat_input"
    assert graph.edges[0].target_id == "chat_output"


@pytest.mark.asyncio
async def test_graph_functional():
    chat_input = ChatInput(_id="chat_input")
    chat_output = ChatOutput(input_value="test", _id="chat_output")
    chat_output.set(sender_name=chat_input.message_response)
    graph = Graph(chat_input, chat_output)
    assert graph._run_queue == deque(["chat_input"])
    await graph.astep()
    assert graph._run_queue == deque(["chat_output"])

    assert graph.vertices[0].id == "chat_input"
    assert graph.vertices[1].id == "chat_output"
    assert graph.edges[0].source_id == "chat_input"
    assert graph.edges[0].target_id == "chat_output"


@pytest.mark.asyncio
async def test_graph_functional_async_start():
    chat_input = ChatInput(_id="chat_input")
    chat_output = ChatOutput(input_value="test", _id="chat_output")
    chat_output.set(sender_name=chat_input.message_response)
    graph = Graph(chat_input, chat_output)
    # Now iterate through the graph
    # and check that the graph is running
    # correctly
    ids = ["chat_input", "chat_output"]
    results = []
    async for result in graph.async_start():
        results.append(result)

    assert len(results) == 3
    assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
    assert results[-1] == Finish()


def test_graph_functional_start():
    chat_input = ChatInput(_id="chat_input")
    chat_output = ChatOutput(input_value="test", _id="chat_output")
    chat_output.set(sender_name=chat_input.message_response)
    graph = Graph(chat_input, chat_output)
    graph.prepare()
    # Now iterate through the graph
    # and check that the graph is running
    # correctly
    ids = ["chat_input", "chat_output"]
    results = []
    for result in graph.start():
        results.append(result)

    assert len(results) == 3
    assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
    assert results[-1] == Finish()


def test_graph_functional_start_end():
    chat_input = ChatInput(_id="chat_input")
    text_output = TextOutputComponent(_id="text_output")
    text_output.set(input_value=chat_input.message_response)
    chat_output = ChatOutput(input_value="test", _id="chat_output")
    chat_output.set(input_value=text_output.text_response)
    graph = Graph(chat_input, text_output)
    graph.prepare()
    # Now iterate through the graph
    # and check that the graph is running
    # correctly
    ids = ["chat_input", "text_output"]
    results = []
    for result in graph.start():
        results.append(result)

    assert len(results) == len(ids) + 1
    assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
    assert results[-1] == Finish()
    # Now, using the same components but different start and end components
    graph = Graph(chat_input, chat_output)
    graph.prepare()
    ids = ["chat_input", "chat_output", "text_output"]
    results = []
    for result in graph.start():
        results.append(result)

    assert len(results) == len(ids) + 1
    assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
    assert results[-1] == Finish()
@@ -0,0 +1,41 @@
from collections import deque

from langflow.components.helpers.Memory import MemoryComponent
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.models.OpenAIModel import OpenAIModelComponent
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.components.prompts.Prompt import PromptComponent
from langflow.graph import Graph
from langflow.graph.graph.constants import Finish


def test_memory_chatbot():
    session_id = "test_session_id"
    template = """{context}

User: {user_message}
AI: """
    memory_component = MemoryComponent(_id="chat_memory")
    memory_component.set(session_id=session_id)
    chat_input = ChatInput(_id="chat_input")
    prompt_component = PromptComponent(_id="prompt")
    prompt_component.set(
        template=template, user_message=chat_input.message_response, context=memory_component.retrieve_messages_as_text
    )
    openai_component = OpenAIModelComponent(_id="openai")
    openai_component.set(
        input_value=prompt_component.build_prompt, max_tokens=100, temperature=0.1, api_key="test_api_key"
    )
    openai_component.get_output("text_output").value = "Mock response"

    chat_output = ChatOutput(_id="chat_output")
    chat_output.set(input_value=openai_component.text_response)

    graph = Graph(chat_input, chat_output)
    # Now we run step by step
    expected_order = deque(["chat_input", "chat_memory", "prompt", "openai", "chat_output"])
    for step in expected_order:
        result = graph.step()
        if isinstance(result, Finish):
            break
        assert step == result.vertex.id
@@ -0,0 +1,89 @@
from textwrap import dedent

from langflow.components.data.File import FileComponent
from langflow.components.embeddings.OpenAIEmbeddings import OpenAIEmbeddingsComponent
from langflow.components.helpers.ParseData import ParseDataComponent
from langflow.components.helpers.SplitText import SplitTextComponent
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.models.OpenAIModel import OpenAIModelComponent
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.components.prompts.Prompt import PromptComponent
from langflow.components.vectorstores.AstraDB import AstraVectorStoreComponent
from langflow.graph.graph.base import Graph
from langflow.graph.graph.constants import Finish
from langflow.schema.data import Data


def test_vector_store_rag():
    # Ingestion Graph
    file_component = FileComponent(_id="file-123")
    file_component.set(path="test.txt")
    text_splitter = SplitTextComponent(_id="text-splitter-123")
    text_splitter.set(data_inputs=file_component.load_file)
    openai_embeddings = OpenAIEmbeddingsComponent(_id="openai-embeddings-123")
    openai_embeddings.set(
        openai_api_key="sk-123", openai_api_base="https://api.openai.com/v1", openai_api_type="openai"
    )
    vector_store = AstraVectorStoreComponent(_id="vector-store-123")
    vector_store.set(
        embedding=openai_embeddings.build_embeddings,
        ingest_data=text_splitter.split_text,
        api_endpoint="https://astra.example.com",
        token="token",
    )

    # RAG Graph
    chat_input = ChatInput(_id="chatinput-123")
    chat_input.get_output("message").value = "What is the meaning of life?"
    rag_vector_store = AstraVectorStoreComponent(_id="rag-vector-store-123")
    rag_vector_store.set(
        search_input=chat_input.message_response,
        api_endpoint="https://astra.example.com",
        token="token",
        embedding=openai_embeddings.build_embeddings,
    )
    # Mock search_documents
    rag_vector_store.get_output("search_results").value = [
        Data(data={"text": "Hello, world!"}),
        Data(data={"text": "Goodbye, world!"}),
    ]
    parse_data = ParseDataComponent(_id="parse-data-123")
    parse_data.set(data=rag_vector_store.search_documents)
    prompt_component = PromptComponent(_id="prompt-123")
    prompt_component.set(
        template=dedent("""Given the following context, answer the question.
            Context:{context}

            Question: {question}
            Answer:"""),
        context=parse_data.parse_data,
        question=chat_input.message_response,
    )

    openai_component = OpenAIModelComponent(_id="openai-123")
    openai_component.set(api_key="sk-123", openai_api_base="https://api.openai.com/v1")
    openai_component.set_output_value("text_output", "Hello, world!")
    openai_component.set(input_value=prompt_component.build_prompt)

    chat_output = ChatOutput(_id="chatoutput-123")
    chat_output.set(input_value=openai_component.text_response)

    graph = Graph(start=chat_input, end=chat_output)
    assert graph is not None
    ids = [
        "chatinput-123",
        "chatoutput-123",
        "openai-123",
        "parse-data-123",
        "prompt-123",
        "rag-vector-store-123",
        "openai-embeddings-123",
    ]
    results = []
    for result in graph.start():
        results.append(result)

    assert len(results) == 8
    vids = [result.vertex.id for result in results if hasattr(result, "vertex")]
    assert all(vid in ids for vid in vids), f"Diff: {set(vids) - set(ids)}"
    assert results[-1] == Finish()
@@ -31,51 +31,44 @@ def client():


 def test_table_input_valid():
     # Test with a valid list of dictionaries
-    data = TableInput(value=[{"key": "value"}, {"key2": "value2"}])
+    data = TableInput(name="valid_table", value=[{"key": "value"}, {"key2": "value2"}])
     assert data.value == [{"key": "value"}, {"key2": "value2"}]


 def test_table_input_invalid():
     with pytest.raises(ValidationError):
         # Test with an invalid value
-        TableInput(value="invalid")
+        TableInput(name="invalid_table", value="invalid")

     with pytest.raises(ValidationError):
         # Test with a list containing invalid item
-        TableInput(value=[{"key": "value"}, "invalid"])
+        TableInput(name="invalid_table", value=[{"key": "value"}, "invalid"])


 def test_str_input_valid():
-    data = StrInput(value="This is a string")
+    data = StrInput(name="valid_str", value="This is a string")
     assert data.value == "This is a string"


 def test_str_input_invalid():
     with pytest.warns(UserWarning):
         # Test with an invalid value
-        StrInput(value=1234)
+        StrInput(name="invalid_str", value=1234)


 def test_message_text_input_valid():
     # Test with a valid string
-    data = MessageTextInput(value="This is a message")
+    data = MessageTextInput(name="valid_msg", value="This is a message")
     assert data.value == "This is a message"

     # Test with a valid Message object
     msg = Message(text="This is a message")
-    data = MessageTextInput(value=msg)
+    data = MessageTextInput(name="valid_msg", value=msg)
     assert data.value == "This is a message"


 def test_message_text_input_invalid():
     with pytest.raises(ValidationError):
         # Test with an invalid value
-        MessageTextInput(value=1234)
+        MessageTextInput(name="invalid_msg", value=1234)


 def test_instantiate_input_valid():
-    data = {"value": "This is a string"}
+    data = {"name": "valid_input", "value": "This is a string"}
     input_instance = _instantiate_input("StrInput", data)
     assert isinstance(input_instance, StrInput)
     assert input_instance.value == "This is a string"
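Every change in the hunk above follows one pattern: inputs must now be constructed with an explicit `name` keyword, and the type of `value` is still validated. As a rough illustration of that contract — a hypothetical plain-Python stand-in, not langflow's actual Pydantic-based `StrInput`:

```python
class StrInput:
    """Hypothetical stand-in for langflow's StrInput; illustrates the contract only."""

    def __init__(self, name: str, value: str = ""):
        # The diff's theme: a non-empty `name` is now required on every input.
        if not isinstance(name, str) or not name:
            raise ValueError("inputs require a non-empty `name`")
        # Type checks on `value` stand in for the ValidationError cases in the tests.
        if not isinstance(value, str):
            raise TypeError("StrInput only accepts string values")
        self.name = name
        self.value = value


data = StrInput(name="valid_str", value="This is a string")
assert data.value == "This is a string"
```

In the real tests a type mismatch surfaces as Pydantic's `ValidationError` (or a `UserWarning` in the `StrInput` case); the sketch collapses that to built-in exceptions for brevity.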
@@ -83,146 +76,145 @@ def test_instantiate_input_valid():


 def test_instantiate_input_invalid():
     with pytest.raises(ValueError):
         # Test with an invalid input type
-        _instantiate_input("InvalidInput", {"value": "This is a string"})
+        _instantiate_input("InvalidInput", {"name": "invalid_input", "value": "This is a string"})


 def test_handle_input_valid():
-    data = HandleInput(input_types=["BaseLanguageModel"])
+    data = HandleInput(name="valid_handle", input_types=["BaseLanguageModel"])
     assert data.input_types == ["BaseLanguageModel"]


 def test_handle_input_invalid():
     with pytest.raises(ValidationError):
-        HandleInput(input_types="BaseLanguageModel")  # should be a list, not a string
+        HandleInput(name="invalid_handle", input_types="BaseLanguageModel")


 def test_data_input_valid():
-    data_input = DataInput(input_types=["Data"])
+    data_input = DataInput(name="valid_data", input_types=["Data"])
     assert data_input.input_types == ["Data"]


 def test_prompt_input_valid():
-    prompt_input = PromptInput(value="Enter your name")
+    prompt_input = PromptInput(name="valid_prompt", value="Enter your name")
     assert prompt_input.value == "Enter your name"


 def test_multiline_input_valid():
-    multiline_input = MultilineInput(value="This is a\nmultiline input")
+    multiline_input = MultilineInput(name="valid_multiline", value="This is a\nmultiline input")
     assert multiline_input.value == "This is a\nmultiline input"
     assert multiline_input.multiline is True


 def test_multiline_input_invalid():
     with pytest.raises(ValidationError):
-        MultilineInput(value=1234)  # should be a string, not an integer
+        MultilineInput(name="invalid_multiline", value=1234)


 def test_multiline_secret_input_valid():
-    multiline_secret_input = MultilineSecretInput(value="secret")
+    multiline_secret_input = MultilineSecretInput(name="valid_multiline_secret", value="secret")
     assert multiline_secret_input.value == "secret"
     assert multiline_secret_input.password is True


 def test_multiline_secret_input_invalid():
     with pytest.raises(ValidationError):
-        MultilineSecretInput(value=1234)  # should be a string, not an integer
+        MultilineSecretInput(name="invalid_multiline_secret", value=1234)


 def test_secret_str_input_valid():
-    secret_str_input = SecretStrInput(value="supersecret")
+    secret_str_input = SecretStrInput(name="valid_secret_str", value="supersecret")
     assert secret_str_input.value == "supersecret"
     assert secret_str_input.password is True


 def test_secret_str_input_invalid():
     with pytest.raises(ValidationError):
-        SecretStrInput(value=1234)  # should be a string, not an integer
+        SecretStrInput(name="invalid_secret_str", value=1234)


 def test_int_input_valid():
-    int_input = IntInput(value=10)
+    int_input = IntInput(name="valid_int", value=10)
     assert int_input.value == 10


 def test_int_input_invalid():
     with pytest.raises(ValidationError):
-        IntInput(value="not_an_int")  # should be an integer, not a string
+        IntInput(name="invalid_int", value="not_an_int")


 def test_float_input_valid():
-    float_input = FloatInput(value=10.5)
+    float_input = FloatInput(name="valid_float", value=10.5)
     assert float_input.value == 10.5


 def test_float_input_invalid():
     with pytest.raises(ValidationError):
-        FloatInput(value="not_a_float")  # should be a float, not a string
+        FloatInput(name="invalid_float", value="not_a_float")


 def test_bool_input_valid():
-    bool_input = BoolInput(value=True)
+    bool_input = BoolInput(name="valid_bool", value=True)
     assert bool_input.value is True


 def test_bool_input_invalid():
     with pytest.raises(ValidationError):
-        BoolInput(value="not_a_bool")  # should be a bool, not a string
+        BoolInput(name="invalid_bool", value="not_a_bool")


 def test_nested_dict_input_valid():
-    nested_dict_input = NestedDictInput(value={"key": "value"})
+    nested_dict_input = NestedDictInput(name="valid_nested_dict", value={"key": "value"})
     assert nested_dict_input.value == {"key": "value"}


 def test_nested_dict_input_invalid():
     with pytest.raises(ValidationError):
-        NestedDictInput(value="not_a_dict")  # should be a dict, not a string
+        NestedDictInput(name="invalid_nested_dict", value="not_a_dict")


 def test_dict_input_valid():
-    dict_input = DictInput(value={"key": "value"})
+    dict_input = DictInput(name="valid_dict", value={"key": "value"})
     assert dict_input.value == {"key": "value"}


 def test_dict_input_invalid():
     with pytest.raises(ValidationError):
-        DictInput(value="not_a_dict")  # should be a dict, not a string
+        DictInput(name="invalid_dict", value="not_a_dict")


 def test_dropdown_input_valid():
-    dropdown_input = DropdownInput(options=["option1", "option2"])
+    dropdown_input = DropdownInput(name="valid_dropdown", options=["option1", "option2"])
     assert dropdown_input.options == ["option1", "option2"]


 def test_dropdown_input_invalid():
     with pytest.raises(ValidationError):
-        DropdownInput(options="option1")  # should be a list, not a string
+        DropdownInput(name="invalid_dropdown", options="option1")


 def test_multiselect_input_valid():
-    multiselect_input = MultiselectInput(value=["option1", "option2"])
+    multiselect_input = MultiselectInput(name="valid_multiselect", value=["option1", "option2"])
     assert multiselect_input.value == ["option1", "option2"]


 def test_multiselect_input_invalid():
     with pytest.raises(ValidationError):
-        MultiselectInput(value="option1")  # should be a list, not a string
+        MultiselectInput(name="invalid_multiselect", value="option1")


 def test_file_input_valid():
-    file_input = FileInput(value=["/path/to/file"])
+    file_input = FileInput(name="valid_file", value=["/path/to/file"])
     assert file_input.value == ["/path/to/file"]


 def test_instantiate_input_comprehensive():
     valid_data = {
-        "StrInput": {"value": "A string"},
-        "IntInput": {"value": 10},
-        "FloatInput": {"value": 10.5},
-        "BoolInput": {"value": True},
-        "DictInput": {"value": {"key": "value"}},
-        "MultiselectInput": {"value": ["option1", "option2"]},
+        "StrInput": {"name": "str_input", "value": "A string"},
+        "IntInput": {"name": "int_input", "value": 10},
+        "FloatInput": {"name": "float_input", "value": 10.5},
+        "BoolInput": {"name": "bool_input", "value": True},
+        "DictInput": {"name": "dict_input", "value": {"key": "value"}},
+        "MultiselectInput": {"name": "multiselect_input", "value": ["option1", "option2"]},
     }

     for input_type, data in valid_data.items():
@@ -230,4 +222,4 @@ def test_instantiate_input_comprehensive():
     assert isinstance(input_instance, InputTypesMap[input_type])

     with pytest.raises(ValueError):
-        _instantiate_input("InvalidInput", {"value": "Invalid"})  # Invalid input type
+        _instantiate_input("InvalidInput", {"name": "invalid_input", "value": "Invalid"})
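The `_instantiate_input` tests above exercise a name-to-class registry (`InputTypesMap`) that resolves a type name string to an input class and raises `ValueError` for unknown names. A minimal sketch of that lookup pattern, using a hypothetical stand-in class and a simplified registry rather than langflow's real ones:

```python
class StrInput:
    """Hypothetical stand-in; langflow's real StrInput is a Pydantic model."""

    def __init__(self, name: str, value: str = ""):
        self.name = name
        self.value = value


# Registry mapping input type names to their classes.
INPUT_TYPES_MAP = {"StrInput": StrInput}


def instantiate_input(input_type: str, data: dict):
    """Resolve the class by name and construct it from keyword data.

    Unknown type names raise ValueError, mirroring what the tests expect
    from _instantiate_input("InvalidInput", ...).
    """
    if input_type not in INPUT_TYPES_MAP:
        raise ValueError(f"Invalid input type: {input_type}")
    return INPUT_TYPES_MAP[input_type](**data)


inp = instantiate_input("StrInput", {"name": "valid_input", "value": "This is a string"})
assert isinstance(inp, StrInput)
```

Passing the data dict through `**data` is why the tests had to add `"name"` to every payload: once `name` became a required constructor argument, any payload without it fails at instantiation.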