Add activated_vertices to VertexBuildResponse and update state management in Graph
This commit is contained in:
parent
94a00c0c5e
commit
3d4ab24858
4 changed files with 125 additions and 44 deletions
|
|
@ -230,6 +230,7 @@ class ResultDataResponse(BaseModel):
|
|||
class VertexBuildResponse(BaseModel):
|
||||
id: Optional[str] = None
|
||||
inactivated_vertices: Optional[List[str]] = None
|
||||
activated_vertices: Optional[List[str]] = None
|
||||
valid: bool
|
||||
params: Optional[str]
|
||||
"""JSON string of the params."""
|
||||
|
|
|
|||
|
|
@ -16,9 +16,11 @@ from langflow.graph.vertex.types import (
|
|||
FileToolVertex,
|
||||
LLMVertex,
|
||||
RoutingVertex,
|
||||
StateVertex,
|
||||
ToolkitVertex,
|
||||
)
|
||||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.schema import Record
|
||||
from langflow.utils import payload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -55,6 +57,7 @@ class Graph:
|
|||
self._vertices = self._graph_data["nodes"]
|
||||
self._edges = self._graph_data["edges"]
|
||||
self.inactivated_vertices: set = set()
|
||||
self.activated_vertices: set = set()
|
||||
self.edges: List[ContractEdge] = []
|
||||
self.vertices: List[Vertex] = []
|
||||
self._build_graph()
|
||||
|
|
@ -62,6 +65,37 @@ class Graph:
|
|||
self.define_vertices_lists()
|
||||
self.state_manager = GraphStateManager()
|
||||
|
||||
def update_state(
|
||||
self, name: str, record: Union[str, Record], caller: Optional[str] = None
|
||||
) -> None:
|
||||
"""Updates the state of the graph."""
|
||||
if caller:
|
||||
# If there is a caller which is a vertex_id, I want to activate
|
||||
# all StateVertex in self.vertices that are not the caller
|
||||
# essentially notifying all the other vertices that the state has changed
|
||||
# This also has to activate their successors
|
||||
caller_vertex = self.get_vertex(caller)
|
||||
for vertex in self.vertices:
|
||||
if vertex.id != caller and isinstance(vertex, StateVertex):
|
||||
successors = self.get_all_successors(vertex)
|
||||
self.activated_vertices.add(vertex.id)
|
||||
for successor in successors:
|
||||
self.activated_vertices.add(successor.id)
|
||||
|
||||
self.state_manager.update_state(name, record)
|
||||
|
||||
def reset_activated_vertices(self):
|
||||
self.activated_vertices = set()
|
||||
|
||||
def append_state(
|
||||
self, name: str, record: Union[str, Record], caller: Optional[str] = None
|
||||
) -> None:
|
||||
"""Appends the state of the graph."""
|
||||
if caller:
|
||||
self.state_manager.subscribe(name, caller)
|
||||
|
||||
self.state_manager.append_state(name, record)
|
||||
|
||||
def set_run_id(self, run_id: str):
|
||||
for vertex in self.vertices:
|
||||
self.state_manager.subscribe(run_id, vertex.update_graph_state)
|
||||
|
|
@ -500,6 +534,20 @@ class Graph:
|
|||
for source_id in self.predecessor_map.get(vertex.id, [])
|
||||
]
|
||||
|
||||
def get_all_successors(self, vertex, recursive=True):
|
||||
# Recursively get the successors of the current vertex
|
||||
successors = vertex.successors
|
||||
if not successors:
|
||||
return []
|
||||
successors_result = []
|
||||
for successor in successors:
|
||||
# Just return a list of successors
|
||||
if recursive:
|
||||
next_successors = self.get_all_successors(successor)
|
||||
successors_result.extend(next_successors)
|
||||
successors_result.append(successor)
|
||||
return successors_result
|
||||
|
||||
def get_successors(self, vertex):
|
||||
"""Returns the successors of a vertex."""
|
||||
return [
|
||||
|
|
@ -561,6 +609,8 @@ class Graph:
|
|||
return ChatVertex
|
||||
elif node_name in ["ShouldRunNext"]:
|
||||
return RoutingVertex
|
||||
elif node_name in ["SharedState"]:
|
||||
return StateVertex
|
||||
elif node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
|
||||
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_base_type]
|
||||
elif node_name in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
|
||||
|
|
|
|||
|
|
@ -88,12 +88,9 @@ class Vertex:
|
|||
|
||||
def update_graph_state(self, key, new_state, append: bool):
|
||||
if append:
|
||||
if key in self.graph_state:
|
||||
self.graph_state[key].append(new_state)
|
||||
else:
|
||||
self.graph_state[key] = [new_state]
|
||||
self.graph.append_state(key, new_state, caller=self.id)
|
||||
else:
|
||||
self.graph_state[key] = new_state
|
||||
self.graph.update_state(key, new_state, caller=self.id)
|
||||
|
||||
def set_state(self, state: str):
|
||||
self.state = VertexStates[state]
|
||||
|
|
@ -511,7 +508,16 @@ class Vertex:
|
|||
self.params[key] = []
|
||||
self.params[key].extend(built)
|
||||
else:
|
||||
self.params[key].append(built)
|
||||
try:
|
||||
if self.params[key] == built:
|
||||
continue
|
||||
|
||||
self.params[key].append(built)
|
||||
except AttributeError as e:
|
||||
logger.exception(e)
|
||||
raise ValueError(
|
||||
f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {built}"
|
||||
) from e
|
||||
|
||||
def _handle_func(self, key, result):
|
||||
"""
|
||||
|
|
@ -670,11 +676,3 @@ class Vertex:
|
|||
if self._built_object is not None
|
||||
else "Failed to build 😵💫"
|
||||
)
|
||||
|
||||
|
||||
class StatefulVertex(Vertex):
|
||||
pass
|
||||
|
||||
|
||||
class StatelessVertex(Vertex):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import ast
|
||||
import json
|
||||
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
|
||||
from typing import (AsyncIterator, Callable, Dict, Iterator, List, Optional,
|
||||
Union)
|
||||
|
||||
import yaml
|
||||
from langchain_core.messages import AIMessage
|
||||
|
|
@ -8,14 +9,14 @@ from loguru import logger
|
|||
|
||||
from langflow.graph.schema import INPUT_FIELD_NAME
|
||||
from langflow.graph.utils import UnbuiltObject, flatten_list, serialize_field
|
||||
from langflow.graph.vertex.base import StatefulVertex, StatelessVertex
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langflow.schema import Record
|
||||
from langflow.services.monitor.utils import log_vertex_build
|
||||
from langflow.utils.schemas import ChatOutputResponse
|
||||
|
||||
|
||||
class AgentVertex(StatelessVertex):
|
||||
class AgentVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="agents", params=params)
|
||||
|
||||
|
|
@ -58,12 +59,12 @@ class AgentVertex(StatelessVertex):
|
|||
await self._build(user_id=user_id)
|
||||
|
||||
|
||||
class ToolVertex(StatelessVertex):
|
||||
class ToolVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="tools", params=params)
|
||||
|
||||
|
||||
class LLMVertex(StatelessVertex):
|
||||
class LLMVertex(Vertex):
|
||||
built_node_type = None
|
||||
class_built_object = None
|
||||
|
||||
|
|
@ -86,7 +87,7 @@ class LLMVertex(StatelessVertex):
|
|||
self.class_built_object = self._built_object
|
||||
|
||||
|
||||
class ToolkitVertex(StatelessVertex):
|
||||
class ToolkitVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph, params=None):
|
||||
super().__init__(data, graph=graph, base_type="toolkits", params=params)
|
||||
|
||||
|
|
@ -100,7 +101,7 @@ class FileToolVertex(ToolVertex):
|
|||
)
|
||||
|
||||
|
||||
class WrapperVertex(StatelessVertex):
|
||||
class WrapperVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph, params=None):
|
||||
super().__init__(data, graph=graph, base_type="wrappers")
|
||||
self.steps: List[Callable] = [self._custom_build]
|
||||
|
|
@ -114,7 +115,7 @@ class WrapperVertex(StatelessVertex):
|
|||
await self._build(user_id=user_id)
|
||||
|
||||
|
||||
class DocumentLoaderVertex(StatefulVertex):
|
||||
class DocumentLoaderVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="documentloaders", params=params)
|
||||
|
||||
|
|
@ -123,21 +124,23 @@ class DocumentLoaderVertex(StatefulVertex):
|
|||
# show how many documents are in the list?
|
||||
|
||||
if not isinstance(self._built_object, UnbuiltObject):
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len(
|
||||
self._built_object
|
||||
)
|
||||
avg_length = sum(
|
||||
len(doc.page_content)
|
||||
for doc in self._built_object
|
||||
if hasattr(doc, "page_content")
|
||||
) / len(self._built_object)
|
||||
return f"""{self.display_name}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
Documents: {self._built_object[:3]}..."""
|
||||
return f"{self.vertex_type}()"
|
||||
|
||||
|
||||
class EmbeddingVertex(StatefulVertex):
|
||||
class EmbeddingVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="embeddings", params=params)
|
||||
|
||||
|
||||
class VectorStoreVertex(StatefulVertex):
|
||||
class VectorStoreVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph, params=None):
|
||||
super().__init__(data, graph=graph, base_type="vectorstores")
|
||||
|
||||
|
|
@ -179,17 +182,17 @@ class VectorStoreVertex(StatefulVertex):
|
|||
self.remove_docs_and_texts_from_params()
|
||||
|
||||
|
||||
class MemoryVertex(StatefulVertex):
|
||||
class MemoryVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="memory")
|
||||
|
||||
|
||||
class RetrieverVertex(StatefulVertex):
|
||||
class RetrieverVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="retrievers")
|
||||
|
||||
|
||||
class TextSplitterVertex(StatefulVertex):
|
||||
class TextSplitterVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
|
||||
super().__init__(data, graph=graph, base_type="textsplitters", params=params)
|
||||
|
||||
|
|
@ -198,14 +201,16 @@ class TextSplitterVertex(StatefulVertex):
|
|||
# show how many documents are in the list?
|
||||
|
||||
if not isinstance(self._built_object, UnbuiltObject):
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
|
||||
self._built_object
|
||||
)
|
||||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
\nDocuments: {self._built_object[:3]}..."""
|
||||
return f"{self.vertex_type}()"
|
||||
|
||||
|
||||
class ChainVertex(StatelessVertex):
|
||||
class ChainVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="chains")
|
||||
self.steps = [self._custom_build]
|
||||
|
|
@ -235,7 +240,7 @@ class ChainVertex(StatelessVertex):
|
|||
return super()._built_object_repr()
|
||||
|
||||
|
||||
class PromptVertex(StatelessVertex):
|
||||
class PromptVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="prompts")
|
||||
self.steps: List[Callable] = [self._custom_build]
|
||||
|
|
@ -245,18 +250,27 @@ class PromptVertex(StatelessVertex):
|
|||
user_id = kwargs.get("user_id", None)
|
||||
tools = kwargs.get("tools", [])
|
||||
if not self._built or force:
|
||||
if "input_variables" not in self.params or self.params["input_variables"] is None:
|
||||
if (
|
||||
"input_variables" not in self.params
|
||||
or self.params["input_variables"] is None
|
||||
):
|
||||
self.params["input_variables"] = []
|
||||
# Check if it is a ZeroShotPrompt and needs a tool
|
||||
if "ShotPrompt" in self.vertex_type:
|
||||
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
|
||||
tools = (
|
||||
[tool_node.build(user_id=user_id) for tool_node in tools]
|
||||
if tools is not None
|
||||
else []
|
||||
)
|
||||
# flatten the list of tools if it is a list of lists
|
||||
# first check if it is a list
|
||||
if tools and isinstance(tools, list) and isinstance(tools[0], list):
|
||||
tools = flatten_list(tools)
|
||||
self.params["tools"] = tools
|
||||
prompt_params = [
|
||||
key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions"
|
||||
key
|
||||
for key, value in self.params.items()
|
||||
if isinstance(value, str) and key != "format_instructions"
|
||||
]
|
||||
else:
|
||||
prompt_params = ["template"]
|
||||
|
|
@ -266,14 +280,20 @@ class PromptVertex(StatelessVertex):
|
|||
prompt_text = self.params[param]
|
||||
variables = extract_input_variables_from_prompt(prompt_text)
|
||||
self.params["input_variables"].extend(variables)
|
||||
self.params["input_variables"] = list(set(self.params["input_variables"]))
|
||||
self.params["input_variables"] = list(
|
||||
set(self.params["input_variables"])
|
||||
)
|
||||
elif isinstance(self.params, dict):
|
||||
self.params.pop("input_variables", None)
|
||||
|
||||
await self._build(user_id=user_id)
|
||||
|
||||
def _built_object_repr(self):
|
||||
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
|
||||
if (
|
||||
not self.artifacts
|
||||
or self._built_object is None
|
||||
or not hasattr(self._built_object, "format")
|
||||
):
|
||||
return super()._built_object_repr()
|
||||
elif isinstance(self._built_object, UnbuiltObject):
|
||||
return super()._built_object_repr()
|
||||
|
|
@ -285,7 +305,9 @@ class PromptVertex(StatelessVertex):
|
|||
# so the prompt format doesn't break
|
||||
artifacts.pop("handle_keys", None)
|
||||
try:
|
||||
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
|
||||
if not hasattr(self._built_object, "template") and hasattr(
|
||||
self._built_object, "prompt"
|
||||
):
|
||||
template = self._built_object.prompt.template
|
||||
else:
|
||||
template = self._built_object.template
|
||||
|
|
@ -293,17 +315,21 @@ class PromptVertex(StatelessVertex):
|
|||
if value:
|
||||
replace_key = "{" + key + "}"
|
||||
template = template.replace(replace_key, value)
|
||||
return template if isinstance(template, str) else f"{self.vertex_type}({template})"
|
||||
return (
|
||||
template
|
||||
if isinstance(template, str)
|
||||
else f"{self.vertex_type}({template})"
|
||||
)
|
||||
except KeyError:
|
||||
return str(self._built_object)
|
||||
|
||||
|
||||
class OutputParserVertex(StatelessVertex):
|
||||
class OutputParserVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="output_parsers")
|
||||
|
||||
|
||||
class CustomComponentVertex(StatelessVertex):
|
||||
class CustomComponentVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="custom_components")
|
||||
|
||||
|
|
@ -312,7 +338,7 @@ class CustomComponentVertex(StatelessVertex):
|
|||
return self.artifacts["repr"] or super()._built_object_repr()
|
||||
|
||||
|
||||
class ChatVertex(StatelessVertex):
|
||||
class ChatVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="custom_components", is_task=True)
|
||||
self.steps = [self._build, self._run]
|
||||
|
|
@ -431,7 +457,7 @@ class ChatVertex(StatelessVertex):
|
|||
pass
|
||||
|
||||
|
||||
class RoutingVertex(StatelessVertex):
|
||||
class RoutingVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="custom_components")
|
||||
self.use_result = True
|
||||
|
|
@ -457,6 +483,12 @@ class RoutingVertex(StatelessVertex):
|
|||
self._built_result = None
|
||||
|
||||
|
||||
class StateVertex(Vertex):
|
||||
def __init__(self, data: Dict, graph):
|
||||
super().__init__(data, graph=graph, base_type="custom_components")
|
||||
self.steps = [self._build]
|
||||
|
||||
|
||||
def dict_to_codeblock(d: dict) -> str:
|
||||
serialized = {key: serialize_field(val) for key, val in d.items()}
|
||||
json_str = json.dumps(serialized, indent=4)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue