Add activated_vertices to VertexBuildResponse and update state management in Graph

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-01 23:57:01 -03:00
commit 3d4ab24858
4 changed files with 125 additions and 44 deletions

View file

@ -230,6 +230,7 @@ class ResultDataResponse(BaseModel):
class VertexBuildResponse(BaseModel):
id: Optional[str] = None
inactivated_vertices: Optional[List[str]] = None
activated_vertices: Optional[List[str]] = None
valid: bool
params: Optional[str]
"""JSON string of the params."""

View file

@ -16,9 +16,11 @@ from langflow.graph.vertex.types import (
FileToolVertex,
LLMVertex,
RoutingVertex,
StateVertex,
ToolkitVertex,
)
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.schema import Record
from langflow.utils import payload
if TYPE_CHECKING:
@ -55,6 +57,7 @@ class Graph:
self._vertices = self._graph_data["nodes"]
self._edges = self._graph_data["edges"]
self.inactivated_vertices: set = set()
self.activated_vertices: set = set()
self.edges: List[ContractEdge] = []
self.vertices: List[Vertex] = []
self._build_graph()
@ -62,6 +65,37 @@ class Graph:
self.define_vertices_lists()
self.state_manager = GraphStateManager()
def update_state(
self, name: str, record: Union[str, Record], caller: Optional[str] = None
) -> None:
"""Updates the state of the graph."""
if caller:
# If there is a caller which is a vertex_id, I want to activate
# all StateVertex in self.vertices that are not the caller
# essentially notifying all the other vertices that the state has changed
# This also has to activate their successors
caller_vertex = self.get_vertex(caller)
for vertex in self.vertices:
if vertex.id != caller and isinstance(vertex, StateVertex):
successors = self.get_all_successors(vertex)
self.activated_vertices.add(vertex.id)
for successor in successors:
self.activated_vertices.add(successor.id)
self.state_manager.update_state(name, record)
def reset_activated_vertices(self):
self.activated_vertices = set()
def append_state(
self, name: str, record: Union[str, Record], caller: Optional[str] = None
) -> None:
"""Appends the state of the graph."""
if caller:
self.state_manager.subscribe(name, caller)
self.state_manager.append_state(name, record)
def set_run_id(self, run_id: str):
for vertex in self.vertices:
self.state_manager.subscribe(run_id, vertex.update_graph_state)
@ -500,6 +534,20 @@ class Graph:
for source_id in self.predecessor_map.get(vertex.id, [])
]
def get_all_successors(self, vertex, recursive=True):
# Recursively get the successors of the current vertex
successors = vertex.successors
if not successors:
return []
successors_result = []
for successor in successors:
# Just return a list of successors
if recursive:
next_successors = self.get_all_successors(successor)
successors_result.extend(next_successors)
successors_result.append(successor)
return successors_result
def get_successors(self, vertex):
"""Returns the successors of a vertex."""
return [
@ -561,6 +609,8 @@ class Graph:
return ChatVertex
elif node_name in ["ShouldRunNext"]:
return RoutingVertex
elif node_name in ["SharedState"]:
return StateVertex
elif node_base_type in lazy_load_vertex_dict.VERTEX_TYPE_MAP:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_base_type]
elif node_name in lazy_load_vertex_dict.VERTEX_TYPE_MAP:

View file

@ -88,12 +88,9 @@ class Vertex:
def update_graph_state(self, key, new_state, append: bool):
if append:
if key in self.graph_state:
self.graph_state[key].append(new_state)
else:
self.graph_state[key] = [new_state]
self.graph.append_state(key, new_state, caller=self.id)
else:
self.graph_state[key] = new_state
self.graph.update_state(key, new_state, caller=self.id)
def set_state(self, state: str):
self.state = VertexStates[state]
@ -511,7 +508,16 @@ class Vertex:
self.params[key] = []
self.params[key].extend(built)
else:
self.params[key].append(built)
try:
if self.params[key] == built:
continue
self.params[key].append(built)
except AttributeError as e:
logger.exception(e)
raise ValueError(
f"Params {key} ({self.params[key]}) is not a list and cannot be extended with {built}"
) from e
def _handle_func(self, key, result):
"""
@ -670,11 +676,3 @@ class Vertex:
if self._built_object is not None
else "Failed to build 😵‍💫"
)
class StatefulVertex(Vertex):
pass
class StatelessVertex(Vertex):
pass

View file

@ -1,6 +1,7 @@
import ast
import json
from typing import AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
from typing import (AsyncIterator, Callable, Dict, Iterator, List, Optional,
Union)
import yaml
from langchain_core.messages import AIMessage
@ -8,14 +9,14 @@ from loguru import logger
from langflow.graph.schema import INPUT_FIELD_NAME
from langflow.graph.utils import UnbuiltObject, flatten_list, serialize_field
from langflow.graph.vertex.base import StatefulVertex, StatelessVertex
from langflow.graph.vertex.base import Vertex
from langflow.interface.utils import extract_input_variables_from_prompt
from langflow.schema import Record
from langflow.services.monitor.utils import log_vertex_build
from langflow.utils.schemas import ChatOutputResponse
class AgentVertex(StatelessVertex):
class AgentVertex(Vertex):
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="agents", params=params)
@ -58,12 +59,12 @@ class AgentVertex(StatelessVertex):
await self._build(user_id=user_id)
class ToolVertex(StatelessVertex):
class ToolVertex(Vertex):
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="tools", params=params)
class LLMVertex(StatelessVertex):
class LLMVertex(Vertex):
built_node_type = None
class_built_object = None
@ -86,7 +87,7 @@ class LLMVertex(StatelessVertex):
self.class_built_object = self._built_object
class ToolkitVertex(StatelessVertex):
class ToolkitVertex(Vertex):
def __init__(self, data: Dict, graph, params=None):
super().__init__(data, graph=graph, base_type="toolkits", params=params)
@ -100,7 +101,7 @@ class FileToolVertex(ToolVertex):
)
class WrapperVertex(StatelessVertex):
class WrapperVertex(Vertex):
def __init__(self, data: Dict, graph, params=None):
super().__init__(data, graph=graph, base_type="wrappers")
self.steps: List[Callable] = [self._custom_build]
@ -114,7 +115,7 @@ class WrapperVertex(StatelessVertex):
await self._build(user_id=user_id)
class DocumentLoaderVertex(StatefulVertex):
class DocumentLoaderVertex(Vertex):
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="documentloaders", params=params)
@ -123,21 +124,23 @@ class DocumentLoaderVertex(StatefulVertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len(
self._built_object
)
avg_length = sum(
len(doc.page_content)
for doc in self._built_object
if hasattr(doc, "page_content")
) / len(self._built_object)
return f"""{self.display_name}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
Documents: {self._built_object[:3]}..."""
return f"{self.vertex_type}()"
class EmbeddingVertex(StatefulVertex):
class EmbeddingVertex(Vertex):
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="embeddings", params=params)
class VectorStoreVertex(StatefulVertex):
class VectorStoreVertex(Vertex):
def __init__(self, data: Dict, graph, params=None):
super().__init__(data, graph=graph, base_type="vectorstores")
@ -179,17 +182,17 @@ class VectorStoreVertex(StatefulVertex):
self.remove_docs_and_texts_from_params()
class MemoryVertex(StatefulVertex):
class MemoryVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="memory")
class RetrieverVertex(StatefulVertex):
class RetrieverVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="retrievers")
class TextSplitterVertex(StatefulVertex):
class TextSplitterVertex(Vertex):
def __init__(self, data: Dict, graph, params: Optional[Dict] = None):
super().__init__(data, graph=graph, base_type="textsplitters", params=params)
@ -198,14 +201,16 @@ class TextSplitterVertex(StatefulVertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
self._built_object
)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
\nDocuments: {self._built_object[:3]}..."""
return f"{self.vertex_type}()"
class ChainVertex(StatelessVertex):
class ChainVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="chains")
self.steps = [self._custom_build]
@ -235,7 +240,7 @@ class ChainVertex(StatelessVertex):
return super()._built_object_repr()
class PromptVertex(StatelessVertex):
class PromptVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="prompts")
self.steps: List[Callable] = [self._custom_build]
@ -245,18 +250,27 @@ class PromptVertex(StatelessVertex):
user_id = kwargs.get("user_id", None)
tools = kwargs.get("tools", [])
if not self._built or force:
if "input_variables" not in self.params or self.params["input_variables"] is None:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
tools = (
[tool_node.build(user_id=user_id) for tool_node in tools]
if tools is not None
else []
)
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions"
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
@ -266,14 +280,20 @@ class PromptVertex(StatelessVertex):
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(set(self.params["input_variables"]))
self.params["input_variables"] = list(
set(self.params["input_variables"])
)
elif isinstance(self.params, dict):
self.params.pop("input_variables", None)
await self._build(user_id=user_id)
def _built_object_repr(self):
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
if (
not self.artifacts
or self._built_object is None
or not hasattr(self._built_object, "format")
):
return super()._built_object_repr()
elif isinstance(self._built_object, UnbuiltObject):
return super()._built_object_repr()
@ -285,7 +305,9 @@ class PromptVertex(StatelessVertex):
# so the prompt format doesn't break
artifacts.pop("handle_keys", None)
try:
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
if not hasattr(self._built_object, "template") and hasattr(
self._built_object, "prompt"
):
template = self._built_object.prompt.template
else:
template = self._built_object.template
@ -293,17 +315,21 @@ class PromptVertex(StatelessVertex):
if value:
replace_key = "{" + key + "}"
template = template.replace(replace_key, value)
return template if isinstance(template, str) else f"{self.vertex_type}({template})"
return (
template
if isinstance(template, str)
else f"{self.vertex_type}({template})"
)
except KeyError:
return str(self._built_object)
class OutputParserVertex(StatelessVertex):
class OutputParserVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="output_parsers")
class CustomComponentVertex(StatelessVertex):
class CustomComponentVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="custom_components")
@ -312,7 +338,7 @@ class CustomComponentVertex(StatelessVertex):
return self.artifacts["repr"] or super()._built_object_repr()
class ChatVertex(StatelessVertex):
class ChatVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="custom_components", is_task=True)
self.steps = [self._build, self._run]
@ -431,7 +457,7 @@ class ChatVertex(StatelessVertex):
pass
class RoutingVertex(StatelessVertex):
class RoutingVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="custom_components")
self.use_result = True
@ -457,6 +483,12 @@ class RoutingVertex(StatelessVertex):
self._built_result = None
class StateVertex(Vertex):
def __init__(self, data: Dict, graph):
super().__init__(data, graph=graph, base_type="custom_components")
self.steps = [self._build]
def dict_to_codeblock(d: dict) -> str:
serialized = {key: serialize_field(val) for key, val in d.items()}
json_str = json.dumps(serialized, indent=4)