diff --git a/src/backend/base/langflow/custom/custom_component/component.py b/src/backend/base/langflow/custom/custom_component/component.py index c734ad2e2..7c495e823 100644 --- a/src/backend/base/langflow/custom/custom_component/component.py +++ b/src/backend/base/langflow/custom/custom_component/component.py @@ -80,6 +80,7 @@ class Component(CustomComponent): """ for key, value in kwargs.items(): self._process_connection_or_parameter(key, value) + return self def list_inputs(self): """ @@ -219,9 +220,32 @@ class Component(CustomComponent): raise ValueError(f"Output with method {method_name} not found") return output + def _inherits_from_component(self, method: Callable): + # check if the method is a method from a class that inherits from Component + # and that it is an output of that class + inherits_from_component = hasattr(method, "__self__") and isinstance(method.__self__, Component) + return inherits_from_component + + def _method_is_valid_output(self, method: Callable): + # check if the method is a method from a class that inherits from Component + # and that it is an output of that class + method_is_output = ( + hasattr(method, "__self__") + and isinstance(method.__self__, Component) + and method.__self__._get_output_by_method(method) + ) + return method_is_output + def _process_connection_or_parameter(self, key, value): _input = self._get_or_create_input(key) - if callable(value): + # We need to check if callable AND if it is a method from a class that inherits from Component + if callable(value) and self._inherits_from_component(value): + try: + self._method_is_valid_output(value) + except ValueError: + raise ValueError( + f"Method {value.__name__} is not a valid output of {value.__self__.__class__.__name__}" + ) self._connect_to_component(key, value, _input) else: self._set_parameter_or_attribute(key, value) @@ -264,6 +288,7 @@ class Component(CustomComponent): ) def _set_parameter_or_attribute(self, key, value): + self._set_input_value(key, value) 
self._parameters[key] = value self._attributes[key] = value @@ -302,7 +327,8 @@ class Component(CustomComponent): f"Input {name} is connected to {input_value.__self__.display_name}.{input_value.__name__}" ) self._inputs[name].value = value - self._attributes[name] = value + if hasattr(self._inputs[name], "load_from_db"): + self._inputs[name].load_from_db = False else: raise ValueError(f"Input {name} not found in {self.__class__.__name__}") diff --git a/src/backend/base/langflow/graph/edge/base.py b/src/backend/base/langflow/graph/edge/base.py index a59eea256..549b3f095 100644 --- a/src/backend/base/langflow/graph/edge/base.py +++ b/src/backend/base/langflow/graph/edge/base.py @@ -227,4 +227,8 @@ class ContractEdge(Edge): return self.result def __repr__(self) -> str: + if (hasattr(self, "source_handle") and self.source_handle) and ( + hasattr(self, "target_handle") and self.target_handle + ): + return f"{self.source_id} -[{self.source_handle.name}->{self.target_handle.fieldName}]-> {self.target_id}" return f"{self.source_id} -[{self.target_param}]-> {self.target_id}" diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index 54c525137..e48519cb4 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -6,19 +6,20 @@ from functools import partial from itertools import chain from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Type, Union +import nest_asyncio from loguru import logger from langflow.exceptions.component import ComponentBuildException from langflow.graph.edge.base import ContractEdge from langflow.graph.edge.schema import EdgeData -from langflow.graph.graph.constants import lazy_load_vertex_dict +from langflow.graph.graph.constants import Finish, lazy_load_vertex_dict from langflow.graph.graph.runnable_vertices_manager import RunnableVerticesManager from langflow.graph.graph.schema import VertexBuildResult from 
langflow.graph.graph.state_manager import GraphStateManager from langflow.graph.graph.utils import find_start_component_id, process_flow, sort_up_to_vertex from langflow.graph.schema import InterfaceComponentTypes, RunOutputs from langflow.graph.vertex.base import Vertex, VertexStates -from langflow.graph.vertex.types import InterfaceVertex, StateVertex +from langflow.graph.vertex.types import ComponentVertex, InterfaceVertex, StateVertex from langflow.schema import Data from langflow.schema.schema import INPUT_FIELD_NAME, InputType from langflow.services.cache.utils import CacheMiss @@ -26,6 +27,8 @@ from langflow.services.chat.schema import GetCache, SetCache from langflow.services.deps import get_chat_service, get_tracing_service if TYPE_CHECKING: + from langflow.api.v1.schemas import InputValueRequest + from langflow.custom.custom_component.component import Component from langflow.graph.schema import ResultData from langflow.services.tracing.service import TracingService @@ -35,6 +38,8 @@ class Graph: def __init__( self, + start: Optional["Component"] = None, + end: Optional["Component"] = None, flow_id: Optional[str] = None, flow_name: Optional[str] = None, user_id: Optional[str] = None, @@ -47,6 +52,7 @@ class Graph: edges (List[Dict[str, str]]): A list of dictionaries representing the edges of the graph. flow_id (Optional[str], optional): The ID of the flow. Defaults to None. 
""" + self._prepared = False self._runs = 0 self._updates = 0 self.flow_id = flow_id @@ -69,12 +75,27 @@ class Graph: self.vertices: List[Vertex] = [] self.run_manager = RunnableVerticesManager() self.state_manager = GraphStateManager() + self._vertices: List[dict] = [] + self._edges: List[EdgeData] = [] + self.top_level_vertices: List[str] = [] + self.vertex_map: Dict[str, Vertex] = {} + self.predecessor_map: Dict[str, List[str]] = defaultdict(list) + self.successor_map: Dict[str, List[str]] = defaultdict(list) + self.in_degree_map: Dict[str, int] = defaultdict(int) + self.parent_child_map: Dict[str, List[str]] = defaultdict(list) + self._run_queue: deque[str] = deque() self._first_layer: List[str] = [] + self._lock = asyncio.Lock() try: self.tracing_service: "TracingService" | None = get_tracing_service() except Exception as exc: logger.error(f"Error getting tracing service: {exc}") self.tracing_service = None + if start is not None and end is not None: + self._set_start_and_end(start, end) + self.prepare() + if (start is not None and end is None) or (start is None and end is not None): + raise ValueError("You must provide both input and output components") def add_nodes_and_edges(self, nodes: List[Dict], edges: List[EdgeData]): self._vertices = nodes @@ -90,11 +111,111 @@ class Graph: self._edges = self._graph_data["edges"] self.initialize() + def add_component(self, _id: str, component: "Component"): + if _id in self.vertex_map: + return + frontend_node = component.to_frontend_node() + frontend_node["data"]["id"] = _id + frontend_node["id"] = _id + self._vertices.append(frontend_node) + vertex = self._create_vertex(frontend_node) + vertex.add_component_instance(component) + self.vertices.append(vertex) + self.vertex_map[_id] = vertex + + if component._edges: + for edge in component._edges: + self._add_edge(edge) + + if component._components: + for _component in component._components: + self.add_component(_component._id, _component) + + def 
_set_start_and_end(self, start: "Component", end: "Component"): + if not hasattr(start, "to_frontend_node"): + raise TypeError(f"start must be a Component. Got {type(start)}") + if not hasattr(end, "to_frontend_node"): + raise TypeError(f"end must be a Component. Got {type(end)}") + self.add_component(start._id, start) + self.add_component(end._id, end) + + def add_component_edge(self, source_id: str, output_input_tuple: Tuple[str, str], target_id: str): + source_vertex = self.get_vertex(source_id) + if not isinstance(source_vertex, ComponentVertex): + raise ValueError(f"Source vertex {source_id} is not a component vertex.") + target_vertex = self.get_vertex(target_id) + if not isinstance(target_vertex, ComponentVertex): + raise ValueError(f"Target vertex {target_id} is not a component vertex.") + output_name, input_name = output_input_tuple + edge_data: EdgeData = { + "source": source_id, + "target": target_id, + "data": { + "sourceHandle": { + "dataType": source_vertex.base_name, + "id": source_vertex.id, + "name": output_name, + "output_types": source_vertex.get_output(output_name).types, + }, + "targetHandle": { + "fieldName": input_name, + "id": target_vertex.id, + "inputTypes": target_vertex.get_input(input_name).input_types, + "type": str(target_vertex.get_input(input_name).field_type), + }, + }, + } + self._add_edge(edge_data) + + async def async_start(self, inputs: Optional[List[dict]] = None): + if not self._prepared: + raise ValueError("Graph not prepared. Call prepare() first.") + # The idea is for this to return a generator that yields the result of + # each step call and raise StopIteration when the graph is done + for _input in inputs or []: + for key, value in _input.items(): + vertex = self.get_vertex(key) + vertex.set_input_value(key, value) + while True: + result = await self.astep() + yield result + if isinstance(result, Finish): + return + + def start(self, inputs: Optional[List[dict]] = None) -> Generator: + #! 
Change this soon + nest_asyncio.apply() + loop = asyncio.get_event_loop() + async_gen = self.async_start(inputs) + async_gen_task = asyncio.ensure_future(async_gen.__anext__()) + + while True: + try: + result = loop.run_until_complete(async_gen_task) + yield result + if isinstance(result, Finish): + return + async_gen_task = asyncio.ensure_future(async_gen.__anext__()) + except StopAsyncIteration: + break + + def _add_edge(self, edge: EdgeData): + self.add_edge(edge) + source_id = edge["data"]["sourceHandle"]["id"] + target_id = edge["data"]["targetHandle"]["id"] + self.predecessor_map[target_id].append(source_id) + self.successor_map[source_id].append(target_id) + self.in_degree_map[target_id] += 1 + self.parent_child_map[source_id].append(target_id) + # TODO: Create a TypedDict to represente the node def add_node(self, node: dict): self._vertices.append(node) def add_edge(self, edge: EdgeData): + # Check if the edge already exists + if edge in self._edges: + return self._edges.append(edge) def initialize(self): @@ -303,6 +424,20 @@ class Graph: if getattr(vertex, attribute): getattr(self, f"_{attribute}_vertices").append(vertex.id) + def _set_inputs(self, input_components: list[str], inputs: Dict[str, str], input_type: InputType | None): + for vertex_id in self._is_input_vertices: + vertex = self.get_vertex(vertex_id) + # If the vertex is not in the input_components list + if input_components and (vertex_id not in input_components and vertex.display_name not in input_components): + continue + # If the input_type is not any and the input_type is not in the vertex id + # Example: input_type = "chat" and vertex.id = "OpenAI-19ddn" + elif input_type is not None and input_type != "any" and input_type not in vertex.id.lower(): + continue + if vertex is None: + raise ValueError(f"Vertex {vertex_id} not found") + vertex.update_raw_params(inputs, overwrite=True) + async def _run( self, inputs: Dict[str, str], @@ -335,20 +470,7 @@ class Graph: if not 
isinstance(inputs.get(INPUT_FIELD_NAME, ""), str): raise ValueError(f"Invalid input value: {inputs.get(INPUT_FIELD_NAME)}. Expected string") if inputs: - for vertex_id in self._is_input_vertices: - vertex = self.get_vertex(vertex_id) - # If the vertex is not in the input_components list - if input_components and ( - vertex_id not in input_components and vertex.display_name not in input_components - ): - continue - # If the input_type is not any and the input_type is not in the vertex id - # Example: input_type = "chat" and vertex.id = "OpenAI-19ddn" - elif input_type is not None and input_type != "any" and input_type not in vertex.id.lower(): - continue - if vertex is None: - raise ValueError(f"Vertex {vertex_id} not found") - vertex.update_raw_params(inputs, overwrite=True) + self._set_inputs(input_components, inputs, input_type) # Update all the vertices with the session_id for vertex_id in self._has_session_id_vertices: vertex = self.get_vertex(vertex_id) @@ -857,6 +979,50 @@ class Graph: return vertex raise ValueError(f"Vertex {vertex_id} is not a top level vertex or no root vertex found") + async def astep( + self, + inputs: Optional["InputValueRequest"] = None, + files: Optional[list[str]] = None, + user_id: Optional[str] = None, + ): + if not self._prepared: + raise ValueError("Graph not prepared. 
Call prepare() first.") + if not self._run_queue: + asyncio.create_task(self.end_all_traces()) + return Finish() + vertex_id = self._run_queue.popleft() + chat_service = get_chat_service() + vertex_build_result = await self.build_vertex( + vertex_id=vertex_id, + user_id=user_id, + inputs_dict=inputs.model_dump() if inputs else {}, + files=files, + get_cache=chat_service.get_cache, + set_cache=chat_service.set_cache, + ) + + next_runnable_vertices = await self.get_next_runnable_vertices( + self._lock, vertex=vertex_build_result.vertex, cache=False + ) + if self.stop_vertex and self.stop_vertex in next_runnable_vertices: + next_runnable_vertices = [self.stop_vertex] + self._run_queue.extend(next_runnable_vertices) + self.reset_inactivated_vertices() + self.reset_activated_vertices() + + await chat_service.set_cache(str(self.flow_id or self._run_id), self) + return vertex_build_result + + def step( + self, + inputs: Optional["InputValueRequest"] = None, + files: Optional[list[str]] = None, + user_id: Optional[str] = None, + ): + # Call astep but synchronously + loop = asyncio.get_event_loop() + return loop.run_until_complete(self.astep(inputs, files, user_id)) + async def build_vertex( self, vertex_id: str, @@ -1037,9 +1203,9 @@ class Graph: next_runnable_vertices.remove(v_id) else: self.run_manager.add_to_vertices_being_run(next_v_id) - if cache and self.flow_id: - set_cache_coro = partial(get_chat_service().set_cache, self.flow_id) - await set_cache_coro(self, lock) + if cache and self.flow_id is not None: + set_cache_coro = partial(get_chat_service().set_cache, key=self.flow_id) + await set_cache_coro(data=self, lock=lock) return next_runnable_vertices async def _execute_tasks(self, tasks: List[asyncio.Task], lock: asyncio.Lock) -> List[str]: @@ -1185,19 +1351,22 @@ class Graph: edges: set[ContractEdge] = set() for edge in self._edges: - source = self.get_vertex(edge["source"]) - target = self.get_vertex(edge["target"]) - - if source is None: - raise 
ValueError(f"Source vertex {edge['source']} not found") - if target is None: - raise ValueError(f"Target vertex {edge['target']} not found") - new_edge = ContractEdge(source, target, edge) - + new_edge = self.build_edge(edge) edges.add(new_edge) return list(edges) + def build_edge(self, edge: EdgeData) -> ContractEdge: + source = self.get_vertex(edge["source"]) + target = self.get_vertex(edge["target"]) + + if source is None: + raise ValueError(f"Source vertex {edge['source']} not found") + if target is None: + raise ValueError(f"Target vertex {edge['target']} not found") + new_edge = ContractEdge(source, target, edge) + return new_edge + def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]: """Returns the node class based on the node type.""" # First we check for the node_base_type @@ -1222,14 +1391,17 @@ class Graph: def _build_vertices(self) -> List[Vertex]: """Builds the vertices of the graph.""" vertices: List[Vertex] = [] - for vertex in self._vertices: - vertex_instance = self._create_vertex(vertex) + for frontend_data in self._vertices: + try: + vertex_instance = self.get_vertex(frontend_data["id"]) + except ValueError: + vertex_instance = self._create_vertex(frontend_data) vertices.append(vertex_instance) return vertices - def _create_vertex(self, vertex: dict): - vertex_data = vertex["data"] + def _create_vertex(self, frontend_data: dict): + vertex_data = frontend_data["data"] vertex_type: str = vertex_data["type"] # type: ignore vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore if "id" not in vertex_data: @@ -1237,7 +1409,7 @@ class Graph: VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"]) - vertex_instance = VertexClass(vertex, graph=self) + vertex_instance = VertexClass(frontend_data, graph=self) vertex_instance.set_top_level(self.top_level_vertices) return vertex_instance @@ -1456,6 +1628,7 @@ class Graph: self.vertices_to_run = {vertex_id for 
vertex_id in chain.from_iterable(vertices_layers)} self.build_run_map() # Return just the first layer + self._first_layer = first_layer return first_layer def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]: diff --git a/src/backend/base/langflow/graph/graph/constants.py b/src/backend/base/langflow/graph/graph/constants.py index ca04e81c6..740a7e443 100644 --- a/src/backend/base/langflow/graph/graph/constants.py +++ b/src/backend/base/langflow/graph/graph/constants.py @@ -1,11 +1,25 @@ from langflow.graph.schema import CHAT_COMPONENTS -from langflow.graph.vertex import types from langflow.utils.lazy_load import LazyLoadDictBase +class Finish: + def __bool__(self): + return True + + def __eq__(self, other): + return isinstance(other, Finish) + + +def _import_vertex_types(): + from langflow.graph.vertex import types + + return types + + class VertexTypesDict(LazyLoadDictBase): def __init__(self): self._all_types_dict = None + self._types = _import_vertex_types() @property def VERTEX_TYPE_MAP(self): @@ -20,13 +34,13 @@ class VertexTypesDict(LazyLoadDictBase): def get_type_dict(self): return { - **{t: types.CustomComponentVertex for t in ["CustomComponent"]}, - **{t: types.ComponentVertex for t in ["Component"]}, - **{t: types.InterfaceVertex for t in CHAT_COMPONENTS}, + **{t: self._types.CustomComponentVertex for t in ["CustomComponent"]}, + **{t: self._types.ComponentVertex for t in ["Component"]}, + **{t: self._types.InterfaceVertex for t in CHAT_COMPONENTS}, } def get_custom_component_vertex_type(self): - return types.CustomComponentVertex + return self._types.CustomComponentVertex lazy_load_vertex_dict = VertexTypesDict() diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index c12bdd88d..357e9dc2b 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ b/src/backend/base/langflow/graph/vertex/base.py @@ -96,6 +96,15 @@ class Vertex: self.build_times: 
List[float] = [] self.state = VertexStates.ACTIVE + def set_input_value(self, name: str, value: Any): + if self._custom_component is None: + raise ValueError(f"Vertex {self.id} does not have a component instance.") + self._custom_component._set_input_value(name, value) + + def add_component_instance(self, component_instance: "Component"): + component_instance.set_vertex(self) + self._custom_component = component_instance + def add_result(self, name: str, result: Any): self.results[name] = result @@ -289,12 +298,13 @@ class Vertex: # we don't know the key of the dict but we need to set the value # to the vertex that is the source of the edge param_dict = template_dict[param_key]["value"] - if param_dict: + if not param_dict or len(param_dict) != 1: + params[param_key] = self.graph.get_vertex(edge.source_id) + else: params[param_key] = { key: self.graph.get_vertex(edge.source_id) for key in param_dict.keys() } - else: - params[param_key] = self.graph.get_vertex(edge.source_id) + else: params[param_key] = self.graph.get_vertex(edge.source_id) diff --git a/src/backend/base/langflow/inputs/input_mixin.py b/src/backend/base/langflow/inputs/input_mixin.py index f772a256a..7b7f7c2b1 100644 --- a/src/backend/base/langflow/inputs/input_mixin.py +++ b/src/backend/base/langflow/inputs/input_mixin.py @@ -40,12 +40,12 @@ class BaseInputMixin(BaseModel, validate_assignment=True): # type: ignore show: bool = True """Should the field be shown. Defaults to True.""" + name: str = Field(description="Name of the field.") + """Name of the field. Default is an empty string.""" + value: Any = "" """The value of the field. Default is an empty string.""" - name: Optional[str] = None - """Name of the field. Default is an empty string.""" - display_name: Optional[str] = None """Display name of the field. 
Defaults to None.""" diff --git a/src/backend/base/langflow/inputs/inputs.py b/src/backend/base/langflow/inputs/inputs.py index 87a0314f5..366a9d183 100644 --- a/src/backend/base/langflow/inputs/inputs.py +++ b/src/backend/base/langflow/inputs/inputs.py @@ -273,7 +273,7 @@ class SecretStrInput(BaseInputMixin, DatabaseLoadMixin): elif isinstance(v, (AsyncIterator, Iterator)): value = v else: - raise ValueError(f"Invalid value type {type(v)}") + raise ValueError(f"Invalid value type `{type(v)}` for input `{_info.data['name']}`") return value diff --git a/src/backend/base/langflow/interface/initialize/loading.py b/src/backend/base/langflow/interface/initialize/loading.py index ea2858100..f0ae915ad 100644 --- a/src/backend/base/langflow/interface/initialize/loading.py +++ b/src/backend/base/langflow/interface/initialize/loading.py @@ -7,13 +7,13 @@ import orjson from loguru import logger from pydantic import PydanticDeprecatedSince20 -from langflow.custom import Component, CustomComponent from langflow.custom.eval import eval_custom_component_code from langflow.schema import Data from langflow.schema.artifact import get_artifact_type, post_process_raw from langflow.services.deps import get_tracing_service if TYPE_CHECKING: + from langflow.custom import Component, CustomComponent from langflow.graph.vertex.base import Vertex @@ -54,9 +54,9 @@ async def get_instance_results( ) with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20) - if base_type == "custom_components" and isinstance(custom_component, CustomComponent): + if base_type == "custom_components": return await build_custom_component(params=custom_params, custom_component=custom_component) - elif base_type == "component" and isinstance(custom_component, Component): + elif base_type == "component": return await build_component(params=custom_params, custom_component=custom_component) else: raise ValueError(f"Base type {base_type} not found.") diff --git 
a/src/backend/base/langflow/schema/message.py b/src/backend/base/langflow/schema/message.py index f941a943f..153906e43 100644 --- a/src/backend/base/langflow/schema/message.py +++ b/src/backend/base/langflow/schema/message.py @@ -23,8 +23,10 @@ from langflow.utils.constants import ( ) -def _timestamp_to_str(timestamp: datetime) -> str: - return timestamp.strftime("%Y-%m-%d %H:%M:%S") +def _timestamp_to_str(timestamp: datetime | str) -> str: + if isinstance(timestamp, datetime): + return timestamp.strftime("%Y-%m-%d %H:%M:%S") + return timestamp class Message(Data): diff --git a/src/backend/base/langflow/template/field/base.py b/src/backend/base/langflow/template/field/base.py index 79ca2d3d9..f3583e466 100644 --- a/src/backend/base/langflow/template/field/base.py +++ b/src/backend/base/langflow/template/field/base.py @@ -158,7 +158,7 @@ class Input(BaseModel): class Output(BaseModel): - types: Optional[list[str]] = Field(default=[]) + types: list[str] = Field(default=[]) """List of output types for the field.""" selected: Optional[str] = Field(default=None) diff --git a/src/backend/tests/unit/custom/custom_component/test_component.py b/src/backend/tests/unit/custom/custom_component/test_component.py new file mode 100644 index 000000000..f28ce861b --- /dev/null +++ b/src/backend/tests/unit/custom/custom_component/test_component.py @@ -0,0 +1,16 @@ +import pytest + +from langflow.components.inputs.ChatInput import ChatInput +from langflow.components.outputs import ChatOutput + + +@pytest.fixture +def client(): + pass + + +def test_set_invalid_output(): + chatinput = ChatInput() + chatoutput = ChatOutput() + with pytest.raises(ValueError): + chatoutput.set(input_value=chatinput.build_config) diff --git a/src/backend/tests/unit/graph/graph/test_base.py b/src/backend/tests/unit/graph/graph/test_base.py new file mode 100644 index 000000000..68ecdff7a --- /dev/null +++ b/src/backend/tests/unit/graph/graph/test_base.py @@ -0,0 +1,131 @@ +from collections import deque + 
+import pytest + +from langflow.components.inputs.ChatInput import ChatInput +from langflow.components.outputs.ChatOutput import ChatOutput +from langflow.components.outputs.TextOutput import TextOutputComponent +from langflow.graph.graph.base import Graph +from langflow.graph.graph.constants import Finish + + +@pytest.fixture +def client(): + pass + + +@pytest.mark.asyncio +async def test_graph_not_prepared(): + chat_input = ChatInput() + chat_output = ChatOutput() + graph = Graph() + graph.add_component("chat_input", chat_input) + graph.add_component("chat_output", chat_output) + graph.add_component_edge("chat_input", (chat_input.outputs[0].name, chat_input.inputs[0].name), "chat_output") + with pytest.raises(ValueError): + await graph.astep() + + +@pytest.mark.asyncio +async def test_graph(): + chat_input = ChatInput() + chat_output = ChatOutput() + graph = Graph() + graph.add_component("chat_input", chat_input) + graph.add_component("chat_output", chat_output) + graph.add_component_edge("chat_input", (chat_input.outputs[0].name, chat_input.inputs[0].name), "chat_output") + graph.prepare() + assert graph._run_queue == deque(["chat_input"]) + await graph.astep() + assert graph._run_queue == deque(["chat_output"]) + + assert graph.vertices[0].id == "chat_input" + assert graph.vertices[1].id == "chat_output" + assert graph.edges[0].source_id == "chat_input" + assert graph.edges[0].target_id == "chat_output" + + +@pytest.mark.asyncio +async def test_graph_functional(): + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output) + assert graph._run_queue == deque(["chat_input"]) + await graph.astep() + assert graph._run_queue == deque(["chat_output"]) + + assert graph.vertices[0].id == "chat_input" + assert graph.vertices[1].id == "chat_output" + assert graph.edges[0].source_id == "chat_input" + assert 
graph.edges[0].target_id == "chat_output" + + +@pytest.mark.asyncio +async def test_graph_functional_async_start(): + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output) + # Now iterate through the graph + # and check that the graph is running + # correctly + ids = ["chat_input", "chat_output"] + results = [] + async for result in graph.async_start(): + results.append(result) + + assert len(results) == 3 + assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) + assert results[-1] == Finish() + + +def test_graph_functional_start(): + chat_input = ChatInput(_id="chat_input") + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(sender_name=chat_input.message_response) + graph = Graph(chat_input, chat_output) + graph.prepare() + # Now iterate through the graph + # and check that the graph is running + # correctly + ids = ["chat_input", "chat_output"] + results = [] + for result in graph.start(): + results.append(result) + + assert len(results) == 3 + assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex")) + assert results[-1] == Finish() + + +def test_graph_functional_start_end(): + chat_input = ChatInput(_id="chat_input") + text_output = TextOutputComponent(_id="text_output") + text_output.set(input_value=chat_input.message_response) + chat_output = ChatOutput(input_value="test", _id="chat_output") + chat_output.set(input_value=text_output.text_response) + graph = Graph(chat_input, text_output) + graph.prepare() + # Now iterate through the graph + # and check that the graph is running + # correctly + ids = ["chat_input", "text_output"] + results = [] + for result in graph.start(): + results.append(result) + + assert len(results) == len(ids) + 1 + assert all(result.vertex.id in ids for result in results if hasattr(result, 
def test_memory_chatbot():
    """Build the memory-chatbot starter graph in code and step through it.

    The graph is chat_input -> prompt (with chat memory as context) -> openai -> chat_output.
    The OpenAI output is mocked so no API call is made. Vertices must build in
    dependency order, one per ``graph.step()`` call.
    """
    session_id = "test_session_id"
    template = """{context}

User: {user_message}
AI: """
    memory_component = MemoryComponent(_id="chat_memory")
    memory_component.set(session_id=session_id)
    chat_input = ChatInput(_id="chat_input")
    prompt_component = PromptComponent(_id="prompt")
    prompt_component.set(
        template=template,
        user_message=chat_input.message_response,
        context=memory_component.retrieve_messages_as_text,
    )
    openai_component = OpenAIModelComponent(_id="openai")
    openai_component.set(
        input_value=prompt_component.build_prompt, max_tokens=100, temperature=0.1, api_key="test_api_key"
    )
    # Mock the model output so the test never hits the OpenAI API.
    openai_component.get_output("text_output").value = "Mock response"

    chat_output = ChatOutput(_id="chat_output")
    chat_output.set(input_value=openai_component.text_response)

    graph = Graph(chat_input, chat_output)
    # Step through the graph and check each vertex builds in dependency order.
    # BUG FIX: the previous loop broke on Finish, so a graph that finished
    # prematurely (e.g. built zero vertices) made the test pass vacuously.
    # Now a Finish before every expected vertex has built is a failure.
    expected_order = deque(["chat_input", "chat_memory", "prompt", "openai", "chat_output"])
    while expected_order:
        result = graph.step()
        assert not isinstance(result, Finish), f"Graph finished before building: {list(expected_order)}"
        assert expected_order.popleft() == result.vertex.id


def test_vector_store_rag():
    """Build the vector-store RAG starter graph in code and run it end to end.

    External services are mocked: the Astra DB search results and the OpenAI
    text output are set directly on the components, so no network calls occur.
    Only the RAG sub-graph (start=chat_input, end=chat_output) is executed.
    """
    # Ingestion graph: file -> splitter -> embeddings -> vector store.
    file_component = FileComponent(_id="file-123")
    file_component.set(path="test.txt")
    text_splitter = SplitTextComponent(_id="text-splitter-123")
    text_splitter.set(data_inputs=file_component.load_file)
    openai_embeddings = OpenAIEmbeddingsComponent(_id="openai-embeddings-123")
    openai_embeddings.set(
        openai_api_key="sk-123", openai_api_base="https://api.openai.com/v1", openai_api_type="openai"
    )
    vector_store = AstraVectorStoreComponent(_id="vector-store-123")
    vector_store.set(
        embedding=openai_embeddings.build_embeddings,
        ingest_data=text_splitter.split_text,
        api_endpoint="https://astra.example.com",
        token="token",
    )

    # RAG graph: chat input -> retrieval -> parse -> prompt -> LLM -> chat output.
    chat_input = ChatInput(_id="chatinput-123")
    chat_input.get_output("message").value = "What is the meaning of life?"
    rag_vector_store = AstraVectorStoreComponent(_id="rag-vector-store-123")
    rag_vector_store.set(
        search_input=chat_input.message_response,
        api_endpoint="https://astra.example.com",
        token="token",
        embedding=openai_embeddings.build_embeddings,
    )
    # Mock search_documents so the test does not hit Astra DB.
    rag_vector_store.get_output("search_results").value = [
        Data(data={"text": "Hello, world!"}),
        Data(data={"text": "Goodbye, world!"}),
    ]
    parse_data = ParseDataComponent(_id="parse-data-123")
    parse_data.set(data=rag_vector_store.search_documents)
    prompt_component = PromptComponent(_id="prompt-123")
    prompt_component.set(
        template=dedent("""Given the following context, answer the question.
Context:{context}

Question: {question}
Answer:"""),
        context=parse_data.parse_data,
        question=chat_input.message_response,
    )

    openai_component = OpenAIModelComponent(_id="openai-123")
    openai_component.set(api_key="sk-123", openai_api_base="https://api.openai.com/v1")
    # Mock the model output so the test never hits the OpenAI API.
    openai_component.set_output_value("text_output", "Hello, world!")
    openai_component.set(input_value=prompt_component.build_prompt)

    chat_output = ChatOutput(_id="chatoutput-123")
    chat_output.set(input_value=openai_component.text_response)

    graph = Graph(start=chat_input, end=chat_output)
    assert graph is not None
    ids = [
        "chatinput-123",
        "chatoutput-123",
        "openai-123",
        "parse-data-123",
        "prompt-123",
        "rag-vector-store-123",
        "openai-embeddings-123",
    ]
    results = []
    for result in graph.start():
        results.append(result)

    # CONSISTENCY FIX: derive the expected count from `ids` (one build result
    # per vertex plus the terminating Finish sentinel) instead of the magic
    # number 8, matching the sibling starter-project tests.
    assert len(results) == len(ids) + 1
    vids = [result.vertex.id for result in results if hasattr(result, "vertex")]
    assert all(vid in ids for vid in vids), f"Diff: {set(vids) - set(ids)}"
    assert results[-1] == Finish()
def test_str_input_valid():
    # A plain string is accepted and stored as-is.
    inp = StrInput(name="valid_str", value="This is a string")
    assert inp.value == "This is a string"


def test_str_input_invalid():
    # A non-string value only warns; it does not raise.
    with pytest.warns(UserWarning):
        StrInput(name="invalid_str", value=1234)


def test_message_text_input_valid():
    # Accepts a raw string...
    inp = MessageTextInput(name="valid_msg", value="This is a message")
    assert inp.value == "This is a message"
    # ...and a Message object, which is unwrapped to its text.
    wrapped = Message(text="This is a message")
    inp = MessageTextInput(name="valid_msg", value=wrapped)
    assert inp.value == "This is a message"


def test_message_text_input_invalid():
    with pytest.raises(ValidationError):
        MessageTextInput(name="invalid_msg", value=1234)


def test_instantiate_input_valid():
    # The factory resolves the type name and forwards the payload.
    payload = {"name": "valid_input", "value": "This is a string"}
    instance = _instantiate_input("StrInput", payload)
    assert isinstance(instance, StrInput)
    assert instance.value == "This is a string"


def test_instantiate_input_invalid():
    # Unknown input-type names are rejected.
    with pytest.raises(ValueError):
        _instantiate_input("InvalidInput", {"name": "invalid_input", "value": "This is a string"})


def test_handle_input_valid():
    inp = HandleInput(name="valid_handle", input_types=["BaseLanguageModel"])
    assert inp.input_types == ["BaseLanguageModel"]


def test_handle_input_invalid():
    # input_types must be a list, not a bare string.
    with pytest.raises(ValidationError):
        HandleInput(name="invalid_handle", input_types="BaseLanguageModel")


def test_data_input_valid():
    inp = DataInput(name="valid_data", input_types=["Data"])
    assert inp.input_types == ["Data"]


def test_prompt_input_valid():
    inp = PromptInput(name="valid_prompt", value="Enter your name")
    assert inp.value == "Enter your name"


def test_multiline_input_valid():
    inp = MultilineInput(name="valid_multiline", value="This is a\nmultiline input")
    assert inp.value == "This is a\nmultiline input"
    # The multiline flag is forced on by the input type itself.
    assert inp.multiline is True


def test_multiline_input_invalid():
    with pytest.raises(ValidationError):
        MultilineInput(name="invalid_multiline", value=1234)


def test_multiline_secret_input_valid():
    inp = MultilineSecretInput(name="valid_multiline_secret", value="secret")
    assert inp.value == "secret"
    # Secret variants mark themselves as password fields.
    assert inp.password is True


def test_multiline_secret_input_invalid():
    with pytest.raises(ValidationError):
        MultilineSecretInput(name="invalid_multiline_secret", value=1234)


def test_secret_str_input_valid():
    inp = SecretStrInput(name="valid_secret_str", value="supersecret")
    assert inp.value == "supersecret"
    assert inp.password is True


def test_secret_str_input_invalid():
    with pytest.raises(ValidationError):
        SecretStrInput(name="invalid_secret_str", value=1234)


def test_int_input_valid():
    inp = IntInput(name="valid_int", value=10)
    assert inp.value == 10


def test_int_input_invalid():
    with pytest.raises(ValidationError):
        IntInput(name="invalid_int", value="not_an_int")


def test_float_input_valid():
    inp = FloatInput(name="valid_float", value=10.5)
    assert inp.value == 10.5


def test_float_input_invalid():
    with pytest.raises(ValidationError):
        FloatInput(name="invalid_float", value="not_a_float")


def test_bool_input_valid():
    inp = BoolInput(name="valid_bool", value=True)
    assert inp.value is True


def test_bool_input_invalid():
    with pytest.raises(ValidationError):
        BoolInput(name="invalid_bool", value="not_a_bool")


def test_nested_dict_input_valid():
    inp = NestedDictInput(name="valid_nested_dict", value={"key": "value"})
    assert inp.value == {"key": "value"}


def test_nested_dict_input_invalid():
    with pytest.raises(ValidationError):
        NestedDictInput(name="invalid_nested_dict", value="not_a_dict")


def test_dict_input_valid():
    inp = DictInput(name="valid_dict", value={"key": "value"})
    assert inp.value == {"key": "value"}


def test_dict_input_invalid():
    with pytest.raises(ValidationError):
        DictInput(name="invalid_dict", value="not_a_dict")


def test_dropdown_input_valid():
    inp = DropdownInput(name="valid_dropdown", options=["option1", "option2"])
    assert inp.options == ["option1", "option2"]


def test_dropdown_input_invalid():
    # options must be a list, not a bare string.
    with pytest.raises(ValidationError):
        DropdownInput(name="invalid_dropdown", options="option1")


def test_multiselect_input_valid():
    inp = MultiselectInput(name="valid_multiselect", value=["option1", "option2"])
    assert inp.value == ["option1", "option2"]


def test_multiselect_input_invalid():
    with pytest.raises(ValidationError):
        MultiselectInput(name="invalid_multiselect", value="option1")


def test_file_input_valid():
    inp = FileInput(name="valid_file", value=["/path/to/file"])
    assert inp.value == ["/path/to/file"]


def test_instantiate_input_comprehensive():
    # One representative valid payload per concrete input type.
    valid_data = {
        "StrInput": {"name": "str_input", "value": "A string"},
        "IntInput": {"name": "int_input", "value": 10},
        "FloatInput": {"name": "float_input", "value": 10.5},
        "BoolInput": {"name": "bool_input", "value": True},
        "DictInput": {"name": "dict_input", "value": {"key": "value"}},
        "MultiselectInput": {"name": "multiselect_input", "value": ["option1", "option2"]},
    }

    for type_name, payload in valid_data.items():
        instance = _instantiate_input(type_name, payload)
        assert isinstance(instance, InputTypesMap[type_name])

    with pytest.raises(ValueError):
        _instantiate_input("InvalidInput", {"name": "invalid_input", "value": "Invalid"})