diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 8782efb29..09ac2aac8 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -22,10 +22,10 @@ from langflow.services.auth.utils import get_current_active_user from langflow.services.chat.service import ChatService from langflow.services.deps import get_chat_service, get_session, get_session_service from langflow.services.monitor.utils import log_vertex_build -from langflow.services.session.service import SessionService if TYPE_CHECKING: from langflow.graph.vertex.types import ChatVertex + from langflow.services.session.service import SessionService router = APIRouter(tags=["Chat"]) @@ -49,7 +49,8 @@ async def try_running_celery_task(vertex, user_id): @router.get("/build/{flow_id}/vertices", response_model=VerticesOrderResponse) async def get_vertices( flow_id: str, - component_id: Optional[str] = None, + stop_component_id: Optional[str] = None, + start_component_id: Optional[str] = None, chat_service: "ChatService" = Depends(get_chat_service), session=Depends(get_session), ): @@ -60,9 +61,9 @@ async def get_vertices( if cache := chat_service.get_cache(flow_id): graph = cache.get("result") graph = build_and_cache_graph(flow_id, session, chat_service, graph) - if component_id: + if stop_component_id or start_component_id: try: - vertices = graph.sort_vertices(component_id) + vertices = graph.sort_vertices(stop_component_id, start_component_id) except Exception as exc: logger.error(exc) vertices = graph.sort_vertices() @@ -94,6 +95,7 @@ async def build_vertex( """Build a vertex instead of the entire graph.""" {"inputs": {"input_value": "some value"}} start_time = time.perf_counter() + next_vertices_ids = [] try: start_time = time.perf_counter() cache = chat_service.get_cache(flow_id) @@ -123,6 +125,10 @@ async def build_vertex( artifacts = vertex.artifacts else: raise ValueError(f"No result found for vertex {vertex_id}") + next_vertices_ids = vertex.successors_ids + next_vertices_ids = [ + v for v in next_vertices_ids if graph.should_run_vertex(v) + ] result_data_response = ResultDataResponse(**result_dict.model_dump()) @@ -160,9 +166,16 @@ async def build_vertex( graph.reset_activated_vertices() chat_service.set_cache(flow_id, graph) + # graph.stop_vertex tells us if the user asked + # to stop the build of the graph at a certain vertex + # if it is in next_vertices_ids, we need to remove other + # vertices from next_vertices_ids + if graph.stop_vertex and graph.stop_vertex in next_vertices_ids: + next_vertices_ids = [graph.stop_vertex] + build_response = VertexBuildResponse( inactivated_vertices=inactivated_vertices, - activated_layers=activated_layers, + next_vertices_ids=next_vertices_ids, valid=valid, params=params, id=vertex.id, diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index f570923e1..9968e839c 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -216,7 +216,7 @@ class ApiKeyCreateRequest(BaseModel): class VerticesOrderResponse(BaseModel): - ids: List[List[str]] + ids: List[str] run_id: UUID @@ -230,7 +230,7 @@ class ResultDataResponse(BaseModel): class VertexBuildResponse(BaseModel): id: Optional[str] = None inactivated_vertices: Optional[List[str]] = None - activated_layers: Optional[List[List[str]]] = None + next_vertices_ids: Optional[List[str]] = None valid: bool params: Optional[str] """JSON string of the params.""" diff --git a/src/backend/langflow/components/chains/ConversationChain.py b/src/backend/langflow/components/chains/ConversationChain.py index fbfb9dfd1..6e1e319d6 100644 --- a/src/backend/langflow/components/chains/ConversationChain.py +++ b/src/backend/langflow/components/chains/ConversationChain.py @@ -31,16 +31,12 @@ class ConversationChainComponent(CustomComponent): chain = ConversationChain(llm=llm) else: chain = ConversationChain(llm=llm, memory=memory) - result = chain.invoke({chain.input_key: input_value}) - # result is an AIMessage which is a subclass of BaseMessage - # We need to check if it is a string or a BaseMessage - result_str: Text = "" + result = chain.invoke(inputs) if hasattr(result, "content") and isinstance(result.content, str): - result_str = result.content + result = result.content elif isinstance(result, str): - result_str = result + result = result else: - # is dict - result_str = Text(result.get("response")) - self.status = result_str - return result_str + result = result.get("response") + self.status = result + return result diff --git a/src/backend/langflow/components/routing/ShouldRunNext.py b/src/backend/langflow/components/routing/ShouldRunNext.py new file mode 100644 index 000000000..3cfde3c22 --- /dev/null +++ b/src/backend/langflow/components/routing/ShouldRunNext.py @@ -0,0 +1,48 @@ +# Implement ShouldRunNext component +from langchain_core.prompts import PromptTemplate + +from langflow import CustomComponent +from langflow.field_typing import BaseLanguageModel, Prompt + + +class ShouldRunNext(CustomComponent): + display_name = "Should Run Next" + description = "Decides whether to run the next component." + + def build_config(self): + return { + "prompt": { + "display_name": "Prompt", + "info": "The prompt to use for the decision. It should generate a boolean response (True or False).", + }, + "llm": { + "display_name": "LLM", + "info": "The language model to use for the decision.", + }, + } + + def build(self, template: Prompt, llm: BaseLanguageModel, **kwargs) -> bool: + # This is a simple component that always returns True + prompt_template = PromptTemplate.from_template(template) + + attributes_to_check = ["text", "page_content"] + for key, value in kwargs.items(): + for attribute in attributes_to_check: + if hasattr(value, attribute): + kwargs[key] = getattr(value, attribute) + + chain = prompt_template | llm + result = chain.invoke(kwargs) + if hasattr(result, "content") and isinstance(result.content, str): + result = result.content + elif isinstance(result, str): + result = result + else: + result = result.get("response") + + if result.lower() not in ["true", "false"]: + raise ValueError("The prompt should generate a boolean response (True or False).") + # The string should be the words true or false + # if not raise an error + bool_result = result.lower() == "true" + return bool_result diff --git a/src/backend/langflow/components/routing/__init__.py b/src/backend/langflow/components/routing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index df270a201..ece989651 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -1,5 +1,6 @@ import asyncio from collections import defaultdict, deque +from itertools import chain from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union from langchain.chains.base import Chain @@ -59,6 +60,11 @@ class Graph: self._edges = self._graph_data["edges"] self.inactivated_vertices: set = set() self.activated_layers: List[List[str]] = [] + self.vertices_layers = [] + self.vertices_to_run = set() + self.stop_vertex = None + + self.inactive_vertices: set = set() self.edges: List[ContractEdge] = [] self.vertices: List[Vertex] = [] self._build_graph() @@ -191,6 +197,14 @@ class Graph: outputs.extend(run_outputs) return outputs + # vertices_layers is a list of lists ordered by the order the vertices + # should be built. + # We need to create a new method that will take the vertices_layers + # and return the next vertex to be built. + def next_vertex_to_build(self): + """Returns the next vertex to be built.""" + yield from chain.from_iterable(self.vertices_layers) + @property def metadata(self): return { @@ -336,9 +350,14 @@ class Graph: # Add new vertices for vertex_id in new_vertex_ids: new_vertex = other.get_vertex(vertex_id) - new_vertex.graph = self self._add_vertex(new_vertex) + # Now update the edges + for vertex_id in new_vertex_ids: + new_vertex = other.get_vertex(vertex_id) + self._update_edges(new_vertex) + new_vertex.graph = self + # Update existing vertices that have changed for vertex_id in existing_vertex_ids.intersection(other_vertex_ids): self_vertex = self.get_vertex(vertex_id) @@ -385,12 +404,24 @@ class Graph: _vertex._build_params() def _add_vertex(self, vertex: Vertex) -> None: - """Adds a new vertex to the graph.""" + """Adds a vertex to the graph.""" self.vertices.append(vertex) self.vertex_map[vertex.id] = vertex + + def add_vertex(self, vertex: Vertex) -> None: + """Adds a new vertex to the graph.""" + self._add_vertex(vertex) + self._update_edges(vertex) + + def _update_edges(self, vertex: Vertex) -> None: + """Updates the edges of a vertex.""" # Vertex has edges, so we need to update the edges for edge in vertex.edges: - if edge.source_id in self.vertex_map and edge.target_id in self.vertex_map: + if ( + edge not in self.edges + and edge.source_id in self.vertex_map + and edge.target_id in self.vertex_map + ): self.edges.append(edge) def _build_graph(self) -> None: @@ -722,7 +753,7 @@ class Graph: ) return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}" - def sort_up_to_vertex(self, vertex_id: str) -> List[Vertex]: + def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]: """Cuts the graph up to a given vertex and sorts the resulting subgraph.""" # Initial setup visited = set() # To keep track of visited vertices @@ -756,11 +787,19 @@ class Graph: if current_id == vertex_id: # We should add to visited all the vertices that are successors of the current vertex # and their successors and so on + # if the vertex is a start, it means we are starting from the beginning + # and getting successors for successor in current_vertex.successors: - excluded.add(successor.id) + if is_start: + stack.append(successor.id) + else: + excluded.add(successor.id) all_successors = get_successors(successor) for successor in all_successors: - excluded.add(successor.id) + if is_start: + stack.append(successor.id) + else: + excluded.add(successor.id) # Filter the original graph's vertices and edges to keep only those in `visited` vertices_to_keep = [self.get_vertex(vid) for vid in visited] @@ -871,12 +910,19 @@ class Graph: return vertices_layers - def sort_vertices(self, component_id: Optional[str] = None) -> List[List[str]]: + def sort_vertices( + self, + stop_component_id: Optional[str] = None, + start_component_id: Optional[str] = None, + ) -> List[str]: """Sorts the vertices in the graph.""" self.mark_all_vertices("ACTIVE") - if component_id: - vertices = self.sort_up_to_vertex(component_id) - vertices_layers = self.layered_topological_sort(vertices) + if stop_component_id: + self.stop_vertex = stop_component_id + vertices = self.sort_up_to_vertex(stop_component_id) + elif start_component_id: + vertices = self.sort_up_to_vertex(start_component_id, is_start=True) + else: vertices = self.vertices # without component_id we are probably running in the chat @@ -886,10 +932,23 @@ class Graph: vertices, filter_graphs=True ) vertices_layers = self.sort_by_avg_build_time(vertices_layers) - vertices_layers = self.sort_chat_inputs_first(vertices_layers) + # vertices_layers = self.sort_chat_inputs_first(vertices_layers) self.increment_run_count() - self._sorted_vertices_layers = vertices_layers - return vertices_layers + first_layer = vertices_layers[0] + # save the only the rest + self.vertices_layers = vertices_layers[1:] + self.vertices_to_run = { + vertex for vertex in chain.from_iterable(vertices_layers) + } + # Return just the first layer + return first_layer + + def should_run_vertex(self, vertex_id: str) -> bool: + """Returns whether a component should be run.""" + should_run = vertex_id in self.vertices_to_run + if should_run: + self.vertices_to_run.remove(vertex_id) + return should_run def sort_interface_components_first( self, vertices_layers: List[List[str]] diff --git a/src/backend/langflow/graph/graph/constants.py b/src/backend/langflow/graph/graph/constants.py index 0d0e69c77..2badbf0eb 100644 --- a/src/backend/langflow/graph/graph/constants.py +++ b/src/backend/langflow/graph/graph/constants.py @@ -47,7 +47,10 @@ class VertexTypesDict(LazyLoadDictBase): **{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()}, **{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()}, **{t: types.OutputParserVertex for t in output_parser_creator.to_list()}, - **{t: types.CustomComponentVertex for t in custom_component_creator.to_list()}, + **{ + t: types.CustomComponentVertex + for t in custom_component_creator.to_list() + }, **{t: types.RetrieverVertex for t in retriever_creator.to_list()}, **{t: types.ChatVertex for t in CHAT_COMPONENTS}, **{t: types.RoutingVertex for t in ROUTING_COMPONENTS}, diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index 8f754325e..e40eabdaa 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -2,13 +2,26 @@ import ast import inspect import types from enum import Enum -from typing import (TYPE_CHECKING, Any, AsyncIterator, Callable, Coroutine, - Dict, Iterator, List, Optional) +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Callable, + Coroutine, + Dict, + Iterator, + List, + Optional, +) from loguru import logger -from langflow.graph.schema import (INPUT_COMPONENTS, OUTPUT_COMPONENTS, - InterfaceComponentTypes, ResultData) +from langflow.graph.schema import ( + INPUT_COMPONENTS, + OUTPUT_COMPONENTS, + InterfaceComponentTypes, + ResultData, +) from langflow.graph.utils import UnbuiltObject, UnbuiltResult from langflow.graph.vertex.utils import generate_result from langflow.interface.initialize import loading @@ -152,6 +165,10 @@ class Vertex: def successors(self) -> List["Vertex"]: return self.graph.get_successors(self) + @property + def successors_ids(self) -> List[str]: + return self.graph.successor_map.get(self.id, []) + def __getstate__(self): return { "_data": self._data, @@ -594,9 +611,6 @@ class Vertex: raise ValueError( f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead." ) - raise ValueError( - f"{self.display_name}: You are trying to stream to a non-streamable component." - ) def _reset(self, params_update: Optional[Dict[str, Any]] = None): self._built = False @@ -606,6 +620,9 @@ class Vertex: self.steps_ran = [] self._build_params() + def _is_chat_input(self): + return False + def build_inactive(self): # Just set the results to None self._built = True @@ -632,7 +649,7 @@ class Vertex: return await self.get_requester_result(requester) self._reset() - if self.is_input and inputs is not None: + if self._is_chat_input() and inputs is not None: self.update_raw_params(inputs) # Run steps diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index 516b06dee..a759be958 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -6,7 +6,7 @@ import yaml from langchain_core.messages import AIMessage from loguru import logger -from langflow.graph.schema import INPUT_FIELD_NAME +from langflow.graph.schema import INPUT_FIELD_NAME, InterfaceComponentTypes from langflow.graph.utils import UnbuiltObject, flatten_list, serialize_field from langflow.graph.vertex.base import Vertex from langflow.interface.utils import extract_input_variables_from_prompt @@ -455,18 +455,30 @@ class ChatVertex(Vertex): async for _ in self.stream(): pass + def _is_chat_input(self): + return self.vertex_type == InterfaceComponentTypes.ChatInput and self.is_input + class RoutingVertex(Vertex): def __init__(self, data: Dict, graph): super().__init__(data, graph=graph, base_type="custom_components") self.use_result = True - self.steps = [self._build, self._run] + self.steps = [self._build] def _built_object_repr(self): if self.artifacts and "repr" in self.artifacts: return self.artifacts["repr"] or super()._built_object_repr() return super()._built_object_repr() + @property + def successors_ids(self): + if isinstance(self._built_object, bool): + ids = super().successors_ids + if self._built_object: + return ids + return [] + raise ValueError("RoutingVertex should return a boolean value.") + def _run(self, *args, **kwargs): if self._built_object: condition = self._built_object.get("condition") diff --git a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx index d26505d5e..385d441d3 100644 --- a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx @@ -17,10 +17,10 @@ import TextAreaComponent from "../../../../components/textAreaComponent"; import ToggleShadComponent from "../../../../components/toggleShadComponent"; import { Button } from "../../../../components/ui/button"; import { + INPUT_HANDLER_HOVER, LANGFLOW_SUPPORTED_TYPES, + OUTPUT_HANDLER_HOVER, TOOLTIP_EMPTY, - inputHandleHover, - outputHandleHover, } from "../../../../constants/constants"; import { postCustomComponentUpdate } from "../../../../controllers/API"; import useAlertStore from "../../../../stores/alertStore"; @@ -182,7 +182,7 @@ export default function ParameterComponent({ return (
{index === 0 && ( - {left ? inputHandleHover : outputHandleHover} + {left ? INPUT_HANDLER_HOVER : OUTPUT_HANDLER_HOVER} )}
) : ( -
- +
+
{ if (nameEditable) { @@ -359,8 +364,8 @@ export default function GenericNode({ }} >
)} @@ -390,13 +395,34 @@ export default function GenericNode({ })} data={data} color={ - nodeColors[ - types[data.node?.template[templateField].type!] - ] ?? - nodeColors[ - data.node?.template[templateField].type! - ] ?? - nodeColors.unknown + data.node?.template[templateField].input_types && + data.node?.template[templateField].input_types! + .length > 0 + ? nodeColors[ + data.node?.template[templateField] + .input_types![ + data.node?.template[templateField] + .input_types!.length - 1 + ] + ] ?? + nodeColors[ + types[ + data.node?.template[templateField] + .input_types![ + data.node?.template[templateField] + .input_types!.length - 1 + ] + ] + ] + : nodeColors[ + data.node?.template[templateField].type! + ] ?? + nodeColors[ + types[ + data.node?.template[templateField].type! + ] + ] ?? + nodeColors.unknown } title={getFieldTitle( data.node?.template!, @@ -458,23 +484,14 @@ export default function GenericNode({ )}
{showNode && ( -