chore: upgrade type annotations in graph and custom modules (#3591)

* refactor(tests): update import statements in conftest.py to use collections.abc module for better compatibility and maintainability

* run pyupgrade on graph module

* [autofix.ci] apply automated fixes

* refactor(attributes.py): change import statement from 'typing.Callable' to 'collections.abc.Callable' for better compatibility
refactor(code_parser.py): update type annotations to use '|' for Union types for better readability
refactor(base_component.py): update type annotations to use '|' for Union types for better readability
refactor(component.py): change import statement from 'typing.Callable' to 'collections.abc.Callable' for better compatibility
refactor(component.py): update type annotations to use '|' for Union types for better readability
refactor(component.py): update type annotations to use 'list' instead of 'List' for consistency

refactor(custom_component.py): update typing imports and annotations for better readability and consistency

refactor(utils.py): change type hint 'List' to 'list' for consistency and compatibility with Python 3.9

* run make format

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-08-28 15:48:24 -03:00 committed by GitHub
commit 9c8dd8e1f0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 296 additions and 292 deletions

View file

@ -1,5 +1,5 @@
import warnings
from typing import Callable
from collections.abc import Callable
import emoji

View file

@ -1,7 +1,7 @@
import ast
import inspect
import traceback
from typing import Any, Dict, List, Type, Union
from typing import Any
from cachetools import TTLCache, keys
from fastapi import HTTPException
@ -29,7 +29,7 @@ def find_class_ast_node(class_obj):
return None, []
# Read the source code from the file
with open(source_file, "r") as file:
with open(source_file) as file:
source_code = file.read()
# Parse the source code into an AST
@ -59,7 +59,7 @@ class CodeParser:
A parser for Python source code, extracting code details.
"""
def __init__(self, code: Union[str, Type]) -> None:
def __init__(self, code: str | type) -> None:
"""
Initializes the parser with the provided code.
"""
@ -70,7 +70,7 @@ class CodeParser:
# If the code is a class, get its source code
code = inspect.getsource(code)
self.code = code
self.data: Dict[str, Any] = {
self.data: dict[str, Any] = {
"imports": [],
"functions": [],
"classes": [],
@ -99,7 +99,7 @@ class CodeParser:
return tree
def parse_node(self, node: Union[ast.stmt, ast.AST]) -> None:
def parse_node(self, node: ast.stmt | ast.AST) -> None:
"""
Parses an AST node and updates the data
dictionary with the relevant information.
@ -107,7 +107,7 @@ class CodeParser:
if handler := self.handlers.get(type(node)): # type: ignore
handler(node) # type: ignore
def parse_imports(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
def parse_imports(self, node: ast.Import | ast.ImportFrom) -> None:
"""
Extracts "imports" from the code, including aliases.
"""
@ -161,7 +161,7 @@ class CodeParser:
exec(f"import {module} as {alias if alias else module}", eval_env)
return eval_env
def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]:
def parse_callable_details(self, node: ast.FunctionDef) -> dict[str, Any]:
"""
Extracts details from a single function or method node.
"""
@ -187,7 +187,7 @@ class CodeParser:
return func.model_dump()
def parse_function_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
def parse_function_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""
Parses the arguments of a function or method node.
"""
@ -202,7 +202,7 @@ class CodeParser:
return args
def parse_positional_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
def parse_positional_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""
Parses the positional arguments of a function or method node.
"""
@ -220,7 +220,7 @@ class CodeParser:
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.args, defaults)]
return args
def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
def parse_varargs(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""
Parses the *args argument of a function or method node.
"""
@ -231,7 +231,7 @@ class CodeParser:
return args
def parse_keyword_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
def parse_keyword_args(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""
Parses the keyword-only arguments of a function or method node.
"""
@ -242,7 +242,7 @@ class CodeParser:
args = [self.parse_arg(arg, default) for arg, default in zip(node.args.kwonlyargs, kw_defaults)]
return args
def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
def parse_kwargs(self, node: ast.FunctionDef) -> list[dict[str, Any]]:
"""
Parses the **kwargs argument of a function or method node.
"""
@ -253,7 +253,7 @@ class CodeParser:
return args
def parse_function_body(self, node: ast.FunctionDef) -> List[str]:
def parse_function_body(self, node: ast.FunctionDef) -> list[str]:
"""
Parses the body of a function or method node.
"""
@ -394,7 +394,7 @@ class CodeParser:
bases.append(bases_base)
return bases
def parse_code(self) -> Dict[str, Any]:
def parse_code(self) -> dict[str, Any]:
"""
Runs all parsing operations and returns the resulting data.
"""

View file

@ -1,5 +1,5 @@
import operator
from typing import Any, ClassVar, Optional
from typing import Any, ClassVar
from uuid import UUID
import warnings
@ -24,11 +24,11 @@ class BaseComponent:
ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided."
_code: Optional[str] = None
_code: str | None = None
"""The code of the component. Defaults to None."""
_function_entrypoint_name: str = "build"
field_config: dict = {}
_user_id: Optional[str | UUID] = None
_user_id: str | UUID | None = None
_template_config: dict = {}
def __init__(self, **data):

View file

@ -1,6 +1,7 @@
import inspect
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Union, get_type_hints
from typing import TYPE_CHECKING, Any, ClassVar, get_type_hints
from collections.abc import Callable
from uuid import UUID
import nanoid # type: ignore
@ -30,8 +31,8 @@ CONFIG_ATTRIBUTES = ["_display_name", "_description", "_icon", "_name"]
class Component(CustomComponent):
inputs: List["InputTypes"] = []
outputs: List[Output] = []
inputs: list["InputTypes"] = []
outputs: list[Output] = []
code_class_base_inheritance: ClassVar[str] = "Component"
_output_logs: dict[str, Log] = {}
@ -228,7 +229,7 @@ class Component(CustomComponent):
else:
raise ValueError(f"Output {name} not found in {self.__class__.__name__}")
def map_outputs(self, outputs: List[Output]):
def map_outputs(self, outputs: list[Output]):
"""
Maps the given list of outputs to the component.
@ -247,7 +248,7 @@ class Component(CustomComponent):
raise ValueError("Output name cannot be None.")
self._outputs[output.name] = output
def map_inputs(self, inputs: List["InputTypes"]):
def map_inputs(self, inputs: list["InputTypes"]):
"""
Maps the given inputs to the component.
@ -449,7 +450,7 @@ class Component(CustomComponent):
)
raise ValueError(f"Parameter {name} not found in {self.__class__.__name__}. ")
def _get_method_return_type(self, method_name: str) -> List[str]:
def _get_method_return_type(self, method_name: str) -> list[str]:
method = getattr(self, method_name)
return_type = get_type_hints(method)["return"]
extracted_return_types = self._extract_return_type(return_type)
@ -530,7 +531,7 @@ class Component(CustomComponent):
_attributes[key] = input_obj.value or None
self._attributes = _attributes
def _set_outputs(self, outputs: List[dict]):
def _set_outputs(self, outputs: list[dict]):
self.outputs = [Output(**output) for output in outputs]
for output in self.outputs:
setattr(self, output.name, output)
@ -646,7 +647,7 @@ class Component(CustomComponent):
return str(self.repr_value)
return self.repr_value
def build_inputs(self, user_id: Optional[Union[str, UUID]] = None):
def build_inputs(self, user_id: str | UUID | None = None):
"""
Builds the inputs for the custom component.

View file

@ -1,5 +1,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, ClassVar, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, ClassVar, Optional
from collections.abc import Callable, Sequence
import yaml
from cachetools import TTLCache
@ -48,42 +49,42 @@ class CustomComponent(BaseComponent):
_tree (Optional[dict]): The code tree of the custom component.
"""
name: Optional[str] = None
name: str | None = None
"""The name of the component used to styles. Defaults to None."""
display_name: Optional[str] = None
display_name: str | None = None
"""The display name of the component. Defaults to None."""
description: Optional[str] = None
description: str | None = None
"""The description of the component. Defaults to None."""
icon: Optional[str] = None
icon: str | None = None
"""The icon of the component. It should be an emoji. Defaults to None."""
is_input: Optional[bool] = None
is_input: bool | None = None
"""The input state of the component. Defaults to None.
If True, the component must have a field named 'input_value'."""
is_output: Optional[bool] = None
is_output: bool | None = None
"""The output state of the component. Defaults to None.
If True, the component must have a field named 'input_value'."""
field_config: dict = {}
"""The field configuration of the component. Defaults to an empty dictionary."""
field_order: Optional[List[str]] = None
field_order: list[str] | None = None
"""The field order of the component. Defaults to an empty list."""
frozen: Optional[bool] = False
frozen: bool | None = False
"""The default frozen state of the component. Defaults to False."""
build_parameters: Optional[dict] = None
build_parameters: dict | None = None
"""The build parameters of the component. Defaults to None."""
_vertex: Optional["Vertex"] = None
"""The edge target parameter of the component. Defaults to None."""
_code_class_base_inheritance: ClassVar[str] = "CustomComponent"
function_entrypoint_name: ClassVar[str] = "build"
function: Optional[Callable] = None
repr_value: Optional[Any] = ""
status: Optional[Any] = None
function: Callable | None = None
repr_value: Any | None = ""
status: Any | None = None
"""The status of the component. This is displayed on the frontend. Defaults to None."""
_flows_data: Optional[List[Data]] = None
_outputs: List[OutputValue] = []
_logs: List[Log] = []
_flows_data: list[Data] | None = None
_outputs: list[OutputValue] = []
_logs: list[Log] = []
_output_logs: dict[str, Log] = {}
_tracing_service: Optional["TracingService"] = None
_tree: Optional[dict] = None
_tree: dict | None = None
def __init__(self, **data):
"""
@ -215,7 +216,7 @@ class CustomComponent(BaseComponent):
self,
build_config: dotdict,
field_value: Any,
field_name: Optional[str] = None,
field_name: str | None = None,
):
build_config[field_name] = field_value
return build_config
@ -230,7 +231,7 @@ class CustomComponent(BaseComponent):
"""
return self.get_code_tree(self._code or "")
def to_data(self, data: Any, keys: Optional[List[str]] = None, silent_errors: bool = False) -> List[Data]:
def to_data(self, data: Any, keys: list[str] | None = None, silent_errors: bool = False) -> list[Data]:
"""
Converts input data into a list of Data objects.
@ -289,7 +290,7 @@ class CustomComponent(BaseComponent):
return self._extract_return_type(return_type)
def create_references_from_data(self, data: List[Data], include_data: bool = False) -> str:
def create_references_from_data(self, data: list[Data], include_data: bool = False) -> str:
"""
Create references from a list of data.
@ -352,7 +353,7 @@ class CustomComponent(BaseComponent):
return build_methods[0] if build_methods else {}
@property
def get_function_entrypoint_return_type(self) -> List[Any]:
def get_function_entrypoint_return_type(self) -> list[Any]:
"""
Gets the return type of the function entrypoint for the custom component.
@ -361,7 +362,7 @@ class CustomComponent(BaseComponent):
"""
return self.get_method_return_type(self._function_entrypoint_name)
def _extract_return_type(self, return_type: Any) -> List[Any]:
def _extract_return_type(self, return_type: Any) -> list[Any]:
return post_process_type(return_type)
@property
@ -451,7 +452,7 @@ class CustomComponent(BaseComponent):
Callable: A function that returns the value at the given index.
"""
def get_index(iterable: List[Any]):
def get_index(iterable: list[Any]):
return iterable[value] if iterable else iterable
return get_index
@ -465,18 +466,18 @@ class CustomComponent(BaseComponent):
"""
return validate.create_function(self._code, self._function_entrypoint_name)
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
async def load_flow(self, flow_id: str, tweaks: dict | None = None) -> "Graph":
if not self.user_id:
raise ValueError("Session is invalid")
return await load_flow(user_id=str(self._user_id), flow_id=flow_id, tweaks=tweaks)
async def run_flow(
self,
inputs: Optional[Union[dict, List[dict]]] = None,
flow_id: Optional[str] = None,
flow_name: Optional[str] = None,
output_type: Optional[str] = "chat",
tweaks: Optional[dict] = None,
inputs: dict | list[dict] | None = None,
flow_id: str | None = None,
flow_name: str | None = None,
output_type: str | None = "chat",
tweaks: dict | None = None,
) -> Any:
return await run_flow(
inputs=inputs,
@ -487,7 +488,7 @@ class CustomComponent(BaseComponent):
user_id=str(self._user_id),
)
def list_flows(self) -> List[Data]:
def list_flows(self) -> list[Data]:
if not self.user_id:
raise ValueError("Session is invalid")
try:
@ -508,7 +509,7 @@ class CustomComponent(BaseComponent):
"""
raise NotImplementedError
def log(self, message: LoggableType | list[LoggableType], name: Optional[str] = None):
def log(self, message: LoggableType | list[LoggableType], name: str | None = None):
"""
Logs a message.
@ -531,7 +532,7 @@ class CustomComponent(BaseComponent):
)
return frontend_node
def get_langchain_callbacks(self) -> List["BaseCallbackHandler"]:
def get_langchain_callbacks(self) -> list["BaseCallbackHandler"]:
if self._tracing_service:
return self._tracing_service.get_langchain_callbacks()
return []

View file

@ -109,7 +109,7 @@ class DirectoryReader:
"""
if not os.path.isfile(file_path):
return None
with open(file_path, "r", encoding="utf-8") as file:
with open(file_path, encoding="utf-8") as file:
# UnicodeDecodeError: 'charmap' codec can't decode byte 0x9d in position 3069: character maps to <undefined>
try:
return file.read()

View file

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Type
from typing import TYPE_CHECKING
from langflow.utils import validate
@ -6,7 +6,7 @@ if TYPE_CHECKING:
from langflow.custom import CustomComponent
def eval_custom_component_code(code: str) -> Type["CustomComponent"]:
def eval_custom_component_code(code: str) -> type["CustomComponent"]:
"""Evaluate custom component code"""
class_name = validate.extract_class_name(code)
return validate.create_class(code, class_name)

View file

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
from pydantic import BaseModel, Field
@ -9,11 +9,11 @@ class ClassCodeDetails(BaseModel):
"""
name: str
doc: Optional[str] = None
doc: str | None = None
bases: list
attributes: list
methods: list
init: Optional[dict] = Field(default_factory=dict)
init: dict | None = Field(default_factory=dict)
class CallableCodeDetails(BaseModel):
@ -22,10 +22,10 @@ class CallableCodeDetails(BaseModel):
"""
name: str
doc: Optional[str] = None
doc: str | None = None
args: list
body: list
return_type: Optional[Any] = None
return_type: Any | None = None
has_return: bool = False

View file

@ -3,7 +3,7 @@ import contextlib
import re
import traceback
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any
from uuid import UUID
from fastapi import HTTPException
@ -33,7 +33,7 @@ class UpdateBuildConfigError(Exception):
pass
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: list[str]):
"""Add output types to the frontend node"""
for return_type in return_types:
if return_type is None:
@ -56,7 +56,7 @@ def add_output_types(frontend_node: CustomComponentFrontendNode, return_types: L
frontend_node.add_output_type(return_type)
def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: List[str]):
def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: list[str]):
"""Reorder fields in the frontend node based on the specified field_order."""
if not field_order:
return
@ -72,7 +72,7 @@ def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: List
frontend_node.field_order = field_order
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: list[str]):
"""Add base classes to the frontend node"""
for return_type_instance in return_types:
if return_type_instance is None:
@ -243,7 +243,7 @@ def add_extra_fields(frontend_node, field_config, function_args):
)
def get_field_dict(field: Union[Input, dict]):
def get_field_dict(field: Input | dict):
"""Get the field dictionary from a Input or a dict"""
if isinstance(field, Input):
return dotdict(field.model_dump(by_alias=True, exclude_none=True))
@ -252,7 +252,7 @@ def get_field_dict(field: Union[Input, dict]):
def run_build_inputs(
custom_component: Component,
user_id: Optional[Union[str, UUID]] = None,
user_id: str | UUID | None = None,
):
"""Run the build inputs of a custom component."""
try:
@ -264,7 +264,7 @@ def run_build_inputs(
raise HTTPException(status_code=500, detail=str(exc)) from exc
def get_component_instance(custom_component: CustomComponent, user_id: Optional[Union[str, UUID]] = None):
def get_component_instance(custom_component: CustomComponent, user_id: str | UUID | None = None):
try:
if custom_component._code is None:
raise ValueError("Code is None")
@ -295,8 +295,8 @@ def get_component_instance(custom_component: CustomComponent, user_id: Optional[
def run_build_config(
custom_component: CustomComponent,
user_id: Optional[Union[str, UUID]] = None,
) -> Tuple[dict, CustomComponent]:
user_id: str | UUID | None = None,
) -> tuple[dict, CustomComponent]:
"""Build the field configuration for a custom component"""
try:
@ -318,7 +318,7 @@ def run_build_config(
try:
custom_instance = custom_class(_user_id=user_id)
build_config: Dict = custom_instance.build_config()
build_config: dict = custom_instance.build_config()
for field_name, field in build_config.copy().items():
# Allow user to build Input as well
@ -358,7 +358,7 @@ def add_code_field(frontend_node: CustomComponentFrontendNode, raw_code):
def build_custom_component_template_from_inputs(
custom_component: Union[Component, CustomComponent], user_id: Optional[Union[str, UUID]] = None
custom_component: Component | CustomComponent, user_id: str | UUID | None = None
):
# The List of Inputs fills the role of the build_config and the entrypoint_args
cc_instance = get_component_instance(custom_component, user_id=user_id)
@ -384,8 +384,8 @@ def build_custom_component_template_from_inputs(
def build_custom_component_template(
custom_component: CustomComponent,
user_id: Optional[Union[str, UUID]] = None,
) -> Tuple[Dict[str, Any], CustomComponent | Component]:
user_id: str | UUID | None = None,
) -> tuple[dict[str, Any], CustomComponent | Component]:
"""Build a custom component template"""
try:
if not hasattr(custom_component, "template_config"):
@ -442,7 +442,7 @@ def create_component_template(component):
return component_template, component_instance
def build_custom_components(components_paths: List[str]):
def build_custom_components(components_paths: list[str]):
"""Build custom components from the specified paths."""
if not components_paths:
return {}
@ -467,7 +467,7 @@ def build_custom_components(components_paths: List[str]):
return custom_components_from_file
async def abuild_custom_components(components_paths: List[str]):
async def abuild_custom_components(components_paths: list[str]):
"""Build custom components from the specified paths."""
if not components_paths:
return {}
@ -494,10 +494,10 @@ async def abuild_custom_components(components_paths: List[str]):
def update_field_dict(
custom_component_instance: "CustomComponent",
field_dict: Dict,
build_config: Dict,
update_field: Optional[str] = None,
update_field_value: Optional[Any] = None,
field_dict: dict,
build_config: dict,
update_field: str | None = None,
update_field_value: Any | None = None,
call: bool = False,
):
"""Update the field dictionary by calling options() or value() if they are callable"""
@ -523,7 +523,7 @@ def update_field_dict(
return build_config
def sanitize_field_config(field_config: Union[Dict, Input]):
def sanitize_field_config(field_config: dict | Input):
# If any of the already existing keys are in field_config, remove them
if isinstance(field_config, Input):
field_dict = field_config.to_dict()

View file

@ -1,4 +1,4 @@
from typing import Any, List, Optional
from typing import Any
from pydantic import ConfigDict, Field, field_validator
from typing_extensions import TypedDict
@ -12,12 +12,12 @@ class ResultPair(BaseModel):
class Payload(BaseModel):
result_pairs: List[ResultPair] = []
result_pairs: list[ResultPair] = []
def __iter__(self):
return iter(self.result_pairs)
def add_result_pair(self, result: Any, extra: Optional[Any] = None) -> None:
def add_result_pair(self, result: Any, extra: Any | None = None) -> None:
self.result_pairs.append(ResultPair(result=result, extra=extra))
def get_last_result_pair(self) -> ResultPair:
@ -42,7 +42,7 @@ class TargetHandle(BaseModel):
model_config = ConfigDict(populate_by_name=True)
field_name: str = Field(..., alias="fieldName", description="Field name for the target handle.")
id: str = Field(..., description="Unique identifier for the target handle.")
input_types: List[str] = Field(
input_types: list[str] = Field(
default_factory=list, alias="inputTypes", description="List of input types for the target handle."
)
type: str = Field(..., description="Type of the target handle.")
@ -55,8 +55,8 @@ class SourceHandle(BaseModel):
)
data_type: str = Field(..., alias="dataType", description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
name: Optional[str] = Field(None, description="Name of the source handle.")
output_types: List[str] = Field(default_factory=list, description="List of output types for the source handle.")
name: str | None = Field(None, description="Name of the source handle.")
output_types: list[str] = Field(default_factory=list, description="List of output types for the source handle.")
@field_validator("name", mode="before")
@classmethod
@ -74,14 +74,14 @@ class SourceHandleDict(TypedDict, total=False):
baseClasses: list[str]
dataType: str
id: str
name: Optional[str]
output_types: List[str]
name: str | None
output_types: list[str]
class TargetHandleDict(TypedDict):
fieldName: str
id: str
inputTypes: Optional[List[str]]
inputTypes: list[str] | None
type: str

View file

@ -7,7 +7,8 @@ from collections import defaultdict, deque
from datetime import datetime, timezone
from functools import partial
from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Optional
from collections.abc import Generator
import nest_asyncio
from loguru import logger
@ -46,11 +47,11 @@ class Graph:
self,
start: Optional["Component"] = None,
end: Optional["Component"] = None,
flow_id: Optional[str] = None,
flow_name: Optional[str] = None,
description: Optional[str] = None,
user_id: Optional[str] = None,
log_config: Optional[LogConfig] = None,
flow_id: str | None = None,
flow_name: str | None = None,
description: str | None = None,
user_id: str | None = None,
log_config: LogConfig | None = None,
) -> None:
"""
Initializes a new instance of the Graph class.
@ -72,39 +73,39 @@ class Graph:
self.flow_name = flow_name
self.description = description
self.user_id = user_id
self._is_input_vertices: List[str] = []
self._is_output_vertices: List[str] = []
self._is_state_vertices: List[str] = []
self._has_session_id_vertices: List[str] = []
self._sorted_vertices_layers: List[List[str]] = []
self._is_input_vertices: list[str] = []
self._is_output_vertices: list[str] = []
self._is_state_vertices: list[str] = []
self._has_session_id_vertices: list[str] = []
self._sorted_vertices_layers: list[list[str]] = []
self._run_id = ""
self._start_time = datetime.now(timezone.utc)
self.inactivated_vertices: set = set()
self.activated_vertices: List[str] = []
self.vertices_layers: List[List[str]] = []
self.activated_vertices: list[str] = []
self.vertices_layers: list[list[str]] = []
self.vertices_to_run: set[str] = set()
self.stop_vertex: Optional[str] = None
self.stop_vertex: str | None = None
self.inactive_vertices: set = set()
self.edges: List[CycleEdge] = []
self.vertices: List[Vertex] = []
self.edges: list[CycleEdge] = []
self.vertices: list[Vertex] = []
self.run_manager = RunnableVerticesManager()
self.state_manager = GraphStateManager()
self._vertices: List[NodeData] = []
self._edges: List[EdgeData] = []
self.top_level_vertices: List[str] = []
self.vertex_map: Dict[str, Vertex] = {}
self.predecessor_map: Dict[str, List[str]] = defaultdict(list)
self.successor_map: Dict[str, List[str]] = defaultdict(list)
self.in_degree_map: Dict[str, int] = defaultdict(int)
self.parent_child_map: Dict[str, List[str]] = defaultdict(list)
self._vertices: list[NodeData] = []
self._edges: list[EdgeData] = []
self.top_level_vertices: list[str] = []
self.vertex_map: dict[str, Vertex] = {}
self.predecessor_map: dict[str, list[str]] = defaultdict(list)
self.successor_map: dict[str, list[str]] = defaultdict(list)
self.in_degree_map: dict[str, int] = defaultdict(int)
self.parent_child_map: dict[str, list[str]] = defaultdict(list)
self._run_queue: deque[str] = deque()
self._first_layer: List[str] = []
self._first_layer: list[str] = []
self._lock = asyncio.Lock()
self.raw_graph_data: GraphData = {"nodes": [], "edges": []}
self._is_cyclic: Optional[bool] = None
self._cycles: Optional[List[tuple[str, str]]] = None
self._call_order: List[str] = []
self._snapshots: List[Dict[str, Any]] = []
self._is_cyclic: bool | None = None
self._cycles: list[tuple[str, str]] | None = None
self._call_order: list[str] = []
self._snapshots: list[dict[str, Any]] = []
try:
self.tracing_service: "TracingService" | None = get_tracing_service()
except Exception as exc:
@ -147,15 +148,15 @@ class Graph:
def dumps(
self,
name: Optional[str] = None,
description: Optional[str] = None,
endpoint_name: Optional[str] = None,
name: str | None = None,
description: str | None = None,
endpoint_name: str | None = None,
) -> str:
graph_dict = self.dump(name, description, endpoint_name)
return json.dumps(graph_dict, indent=4, sort_keys=True)
def dump(
self, name: Optional[str] = None, description: Optional[str] = None, endpoint_name: Optional[str] = None
self, name: str | None = None, description: str | None = None, endpoint_name: str | None = None
) -> GraphDump:
if self.raw_graph_data != {"nodes": [], "edges": []}:
data_dict = self.raw_graph_data
@ -180,7 +181,7 @@ class Graph:
graph_dict["endpoint_name"] = str(endpoint_name)
return graph_dict
def add_nodes_and_edges(self, nodes: List[NodeData], edges: List[EdgeData]):
def add_nodes_and_edges(self, nodes: list[NodeData], edges: list[EdgeData]):
self._vertices = nodes
self._edges = edges
self.raw_graph_data = {"nodes": nodes, "edges": edges}
@ -222,7 +223,7 @@ class Graph:
self.add_component(start._id, start)
self.add_component(end._id, end)
def add_component_edge(self, source_id: str, output_input_tuple: Tuple[str, str], target_id: str):
def add_component_edge(self, source_id: str, output_input_tuple: tuple[str, str], target_id: str):
source_vertex = self.get_vertex(source_id)
if not isinstance(source_vertex, ComponentVertex):
raise ValueError(f"Source vertex {source_id} is not a component vertex.")
@ -255,7 +256,7 @@ class Graph:
}
self._add_edge(edge_data)
async def async_start(self, inputs: Optional[List[dict]] = None, max_iterations: Optional[int] = None):
async def async_start(self, inputs: list[dict] | None = None, max_iterations: int | None = None):
if not self._prepared:
raise ValueError("Graph not prepared. Call prepare() first.")
# The idea is for this to return a generator that yields the result of
@ -288,9 +289,9 @@ class Graph:
def start(
self,
inputs: Optional[List[dict]] = None,
max_iterations: Optional[int] = None,
config: Optional[StartConfigDict] = None,
inputs: list[dict] | None = None,
max_iterations: int | None = None,
config: StartConfigDict | None = None,
) -> Generator:
if config is not None:
self.__apply_config(config)
@ -333,7 +334,7 @@ class Graph:
self.build_graph_maps(self.edges)
self.define_vertices_lists()
def get_state(self, name: str) -> Optional[Data]:
def get_state(self, name: str) -> Data | None:
"""
Returns the state of the graph with the given name.
@ -345,7 +346,7 @@ class Graph:
"""
return self.state_manager.get_state(name, run_id=self._run_id)
def update_state(self, name: str, record: Union[str, Data], caller: Optional[str] = None) -> None:
def update_state(self, name: str, record: str | Data, caller: str | None = None) -> None:
"""
Updates the state of the graph with the given name.
@ -419,7 +420,7 @@ class Graph:
"""
self.activated_vertices = []
def append_state(self, name: str, record: Union[str, Data], caller: Optional[str] = None) -> None:
def append_state(self, name: str, record: str | Data, caller: str | None = None) -> None:
"""
Appends the state of the graph with the given name.
@ -513,7 +514,7 @@ class Graph:
await self.tracing_service.end(outputs, error)
@property
def sorted_vertices_layers(self) -> List[List[str]]:
def sorted_vertices_layers(self) -> list[list[str]]:
"""
The sorted layers of vertices in the graph.
@ -534,7 +535,7 @@ class Graph:
if getattr(vertex, attribute):
getattr(self, f"_{attribute}_vertices").append(vertex.id)
def _set_inputs(self, input_components: list[str], inputs: Dict[str, str], input_type: InputType | None):
def _set_inputs(self, input_components: list[str], inputs: dict[str, str], input_type: InputType | None):
for vertex_id in self._is_input_vertices:
vertex = self.get_vertex(vertex_id)
# If the vertex is not in the input_components list
@ -550,14 +551,14 @@ class Graph:
async def _run(
self,
inputs: Dict[str, str],
inputs: dict[str, str],
input_components: list[str],
input_type: InputType | None,
outputs: list[str],
stream: bool,
session_id: str,
fallback_to_env_vars: bool,
) -> List[Optional["ResultData"]]:
) -> list[Optional["ResultData"]]:
"""
Runs the graph with the given inputs.
@ -622,14 +623,14 @@ class Graph:
def run(
self,
inputs: list[Dict[str, str]],
input_components: Optional[list[list[str]]] = None,
types: Optional[list[InputType | None]] = None,
outputs: Optional[list[str]] = None,
session_id: Optional[str] = None,
inputs: list[dict[str, str]],
input_components: list[list[str]] | None = None,
types: list[InputType | None] | None = None,
outputs: list[str] | None = None,
session_id: str | None = None,
stream: bool = False,
fallback_to_env_vars: bool = False,
) -> List[RunOutputs]:
) -> list[RunOutputs]:
"""
Run the graph with the given inputs and return the outputs.
@ -671,14 +672,14 @@ class Graph:
async def arun(
self,
inputs: list[Dict[str, str]],
inputs_components: Optional[list[list[str]]] = None,
types: Optional[list[InputType | None]] = None,
outputs: Optional[list[str]] = None,
session_id: Optional[str] = None,
inputs: list[dict[str, str]],
inputs_components: list[list[str]] | None = None,
types: list[InputType | None] | None = None,
outputs: list[str] | None = None,
session_id: str | None = None,
stream: bool = False,
fallback_to_env_vars: bool = False,
) -> List[RunOutputs]:
) -> list[RunOutputs]:
"""
Runs the graph with the given inputs.
@ -752,7 +753,7 @@ class Graph:
"flow_name": self.flow_name,
}
def build_graph_maps(self, edges: Optional[List[CycleEdge]] = None, vertices: Optional[List["Vertex"]] = None):
def build_graph_maps(self, edges: list[CycleEdge] | None = None, vertices: list["Vertex"] | None = None):
"""
Builds the adjacency maps for the graph.
"""
@ -788,9 +789,7 @@ class Graph:
if state == VertexStates.INACTIVE:
self.run_manager.remove_from_predecessors(vertex_id)
def _mark_branch(
self, vertex_id: str, state: str, visited: Optional[set] = None, output_name: Optional[str] = None
):
def _mark_branch(self, vertex_id: str, state: str, visited: set | None = None, output_name: str | None = None):
"""Marks a branch of the graph."""
if visited is None:
visited = set()
@ -809,7 +808,7 @@ class Graph:
continue
self._mark_branch(child_id, state, visited)
def mark_branch(self, vertex_id: str, state: str, output_name: Optional[str] = None):
def mark_branch(self, vertex_id: str, state: str, output_name: str | None = None):
self._mark_branch(vertex_id=vertex_id, state=state, output_name=output_name)
new_predecessor_map, _ = self.build_adjacency_maps(self.edges)
self.run_manager.update_run_state(
@ -817,14 +816,14 @@ class Graph:
vertices_to_run=self.vertices_to_run,
)
def get_edge(self, source_id: str, target_id: str) -> Optional[CycleEdge]:
def get_edge(self, source_id: str, target_id: str) -> CycleEdge | None:
"""Returns the edge between two vertices."""
for edge in self.edges:
if edge.source_id == source_id and edge.target_id == target_id:
return edge
return None
def build_parent_child_map(self, vertices: List["Vertex"]):
def build_parent_child_map(self, vertices: list["Vertex"]):
parent_child_map = defaultdict(list)
for vertex in vertices:
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
@ -919,10 +918,10 @@ class Graph:
@classmethod
def from_payload(
cls,
payload: Dict,
flow_id: Optional[str] = None,
flow_name: Optional[str] = None,
user_id: Optional[str] = None,
payload: dict,
flow_id: str | None = None,
flow_name: str | None = None,
user_id: str | None = None,
) -> "Graph":
"""
Creates a graph from a payload.
@ -989,7 +988,7 @@ class Graph:
def update(self, other: "Graph") -> "Graph":
# Existing vertices in self graph
existing_vertex_ids = set(vertex.id for vertex in self.vertices)
existing_vertex_ids = {vertex.id for vertex in self.vertices}
# Vertex IDs in the other graph
other_vertex_ids = set(other.vertex_map.keys())
@ -1146,14 +1145,14 @@ class Graph:
return None
return self._run_queue.popleft()
def extend_run_queue(self, vertices: List[str]):
def extend_run_queue(self, vertices: list[str]):
self._run_queue.extend(vertices)
async def astep(
self,
inputs: Optional["InputValueRequest"] = None,
files: Optional[list[str]] = None,
user_id: Optional[str] = None,
files: list[str] | None = None,
user_id: str | None = None,
):
if not self._prepared:
raise ValueError("Graph not prepared. Call prepare() first.")
@ -1204,8 +1203,8 @@ class Graph:
def step(
self,
inputs: Optional["InputValueRequest"] = None,
files: Optional[list[str]] = None,
user_id: Optional[str] = None,
files: list[str] | None = None,
user_id: str | None = None,
):
# Call astep but synchronously
loop = asyncio.get_event_loop()
@ -1216,9 +1215,9 @@ class Graph:
vertex_id: str,
get_cache: GetCache | None = None,
set_cache: SetCache | None = None,
inputs_dict: Optional[Dict[str, str]] = None,
files: Optional[list[str]] = None,
user_id: Optional[str] = None,
inputs_dict: dict[str, str] | None = None,
files: list[str] | None = None,
user_id: str | None = None,
fallback_to_env_vars: bool = False,
) -> VertexBuildResult:
"""
@ -1308,9 +1307,9 @@ class Graph:
def get_vertex_edges(
self,
vertex_id: str,
is_target: Optional[bool] = None,
is_source: Optional[bool] = None,
) -> List[CycleEdge]:
is_target: bool | None = None,
is_source: bool | None = None,
) -> list[CycleEdge]:
"""Returns a list of edges for a given vertex."""
# The idea here is to return the edges that have the vertex_id as source or target
# or both
@ -1321,9 +1320,9 @@ class Graph:
or (edge.target_id == vertex_id and is_target is not False)
]
def get_vertices_with_target(self, vertex_id: str) -> List["Vertex"]:
def get_vertices_with_target(self, vertex_id: str) -> list["Vertex"]:
"""Returns the vertices connected to a vertex."""
vertices: List["Vertex"] = []
vertices: list["Vertex"] = []
for edge in self.edges:
if edge.target_id == vertex_id:
vertex = self.get_vertex(edge.source_id)
@ -1332,11 +1331,11 @@ class Graph:
vertices.append(vertex)
return vertices
async def process(self, fallback_to_env_vars: bool, start_component_id: Optional[str] = None) -> "Graph":
async def process(self, fallback_to_env_vars: bool, start_component_id: str | None = None) -> "Graph":
"""Processes the graph with vertices in each layer run in parallel."""
first_layer = self.sort_vertices(start_component_id=start_component_id)
vertex_task_run_count: Dict[str, int] = {}
vertex_task_run_count: dict[str, int] = {}
to_process = deque(first_layer)
layer_index = 0
chat_service = get_chat_service()
@ -1379,7 +1378,7 @@ class Graph:
logger.debug("Graph processing complete")
return self
def find_next_runnable_vertices(self, vertex_id: str, vertex_successors_ids: List[str]) -> List[str]:
def find_next_runnable_vertices(self, vertex_id: str, vertex_successors_ids: list[str]) -> list[str]:
next_runnable_vertices = set()
for v_id in vertex_successors_ids:
if not self.is_vertex_runnable(v_id):
@ -1389,7 +1388,7 @@ class Graph:
return list(next_runnable_vertices)
async def get_next_runnable_vertices(self, lock: asyncio.Lock, vertex: "Vertex", cache: bool = True) -> List[str]:
async def get_next_runnable_vertices(self, lock: asyncio.Lock, vertex: "Vertex", cache: bool = True) -> list[str]:
v_id = vertex.id
v_successors_ids = vertex.successors_ids
async with lock:
@ -1406,11 +1405,11 @@ class Graph:
await set_cache_coro(data=self, lock=lock)
return next_runnable_vertices
async def _execute_tasks(self, tasks: List[asyncio.Task], lock: asyncio.Lock) -> List[str]:
async def _execute_tasks(self, tasks: list[asyncio.Task], lock: asyncio.Lock) -> list[str]:
"""Executes tasks in parallel, handling exceptions for each task."""
results = []
completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)
vertices: List["Vertex"] = []
vertices: list["Vertex"] = []
for i, result in enumerate(completed_tasks):
task_name = tasks[i].get_name()
@ -1437,7 +1436,7 @@ class Graph:
no_duplicate_results = list(set(results))
return no_duplicate_results
def topological_sort(self) -> List["Vertex"]:
def topological_sort(self) -> list["Vertex"]:
"""
Performs a topological sort of the vertices in the graph.
@ -1513,13 +1512,13 @@ class Graph:
return successors_result
def get_successors(self, vertex: "Vertex") -> List["Vertex"]:
def get_successors(self, vertex: "Vertex") -> list["Vertex"]:
"""Returns the successors of a vertex."""
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
def get_vertex_neighbors(self, vertex: "Vertex") -> Dict["Vertex", int]:
def get_vertex_neighbors(self, vertex: "Vertex") -> dict["Vertex", int]:
"""Returns the neighbors of a vertex."""
neighbors: Dict["Vertex", int] = {}
neighbors: dict["Vertex", int] = {}
for edge in self.edges:
if edge.source_id == vertex.id:
neighbor = self.get_vertex(edge.target_id)
@ -1537,7 +1536,7 @@ class Graph:
neighbors[neighbor] += 1
return neighbors
def _build_edges(self) -> List[CycleEdge]:
def _build_edges(self) -> list[CycleEdge]:
"""Builds the edges of the graph."""
# Edge takes two vertices as arguments, so we need to build the vertices first
# and then build the edges
@ -1562,7 +1561,7 @@ class Graph:
new_edge = CycleEdge(source, target, edge)
return new_edge
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type["Vertex"]:
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> type["Vertex"]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -1579,9 +1578,9 @@ class Graph:
return lazy_load_vertex_dict.VERTEX_TYPE_MAP[node_type]
return Vertex
def _build_vertices(self) -> List["Vertex"]:
def _build_vertices(self) -> list["Vertex"]:
"""Builds the vertices of the graph."""
vertices: List["Vertex"] = []
vertices: list["Vertex"] = []
for frontend_data in self._vertices:
try:
vertex_instance = self.get_vertex(frontend_data["id"])
@ -1604,7 +1603,7 @@ class Graph:
vertex_instance.set_top_level(self.top_level_vertices)
return vertex_instance
def prepare(self, stop_component_id: Optional[str] = None, start_component_id: Optional[str] = None):
def prepare(self, stop_component_id: str | None = None, start_component_id: str | None = None):
self.initialize()
if stop_component_id and start_component_id:
raise ValueError("You can only provide one of stop_component_id or start_component_id")
@ -1628,7 +1627,7 @@ class Graph:
self._record_snapshot()
return self
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> list[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -1653,9 +1652,9 @@ class Graph:
def layered_topological_sort(
self,
vertices: List["Vertex"],
vertices: list["Vertex"],
filter_graphs: bool = False,
) -> List[List[str]]:
) -> list[list[str]]:
"""Performs a layered topological sort of the vertices in the graph."""
vertices_ids = {vertex.id for vertex in vertices}
# Queue for vertices with no incoming edges
@ -1665,7 +1664,7 @@ class Graph:
# if filter_graphs then only vertex.is_input will be considered
if self.in_degree_map[vertex.id] == 0 and (not filter_graphs or vertex.is_input)
)
layers: List[List[str]] = []
layers: list[list[str]] = []
visited = set(queue)
current_layer = 0
@ -1738,7 +1737,7 @@ class Graph:
return refined_layers
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_chat_inputs_first(self, vertices_layers: list[list[str]]) -> list[list[str]]:
chat_inputs_first = []
for layer in vertices_layers:
for vertex_id in layer:
@ -1753,7 +1752,7 @@ class Graph:
return vertices_layers
def sort_layer_by_dependency(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_layer_by_dependency(self, vertices_layers: list[list[str]]) -> list[list[str]]:
"""Sorts the vertices in each layer by dependency, ensuring no vertex depends on a subsequent vertex."""
sorted_layers = []
@ -1763,7 +1762,7 @@ class Graph:
return sorted_layers
def _sort_single_layer_by_dependency(self, layer: List[str]) -> List[str]:
def _sort_single_layer_by_dependency(self, layer: list[str]) -> list[str]:
"""Sorts a single layer by dependency using a stable sorting method."""
# Build a map of each vertex to its index in the layer for quick lookup.
index_map = {vertex: index for index, vertex in enumerate(layer)}
@ -1772,7 +1771,7 @@ class Graph:
return sorted_layer
def _max_dependency_index(self, vertex_id: str, index_map: Dict[str, int]) -> int:
def _max_dependency_index(self, vertex_id: str, index_map: dict[str, int]) -> int:
"""Finds the highest index a given vertex's dependencies occupy in the same layer."""
vertex = self.get_vertex(vertex_id)
max_index = -1
@ -1781,9 +1780,9 @@ class Graph:
max_index = max(max_index, index_map[successor.id])
return max_index
def __to_dict(self) -> Dict[str, Dict[str, List[str]]]:
def __to_dict(self) -> dict[str, dict[str, list[str]]]:
"""Converts the graph to a dictionary."""
result: Dict = dict()
result: dict = dict()
for vertex in self.vertices:
vertex_id = vertex.id
sucessors = [i.id for i in self.get_all_successors(vertex)]
@ -1798,9 +1797,9 @@ class Graph:
def sort_vertices(
self,
stop_component_id: Optional[str] = None,
start_component_id: Optional[str] = None,
) -> List[str]:
stop_component_id: str | None = None,
start_component_id: str | None = None,
) -> list[str]:
"""Sorts the vertices in the graph."""
self.mark_all_vertices("ACTIVE")
if stop_component_id is not None:
@ -1833,7 +1832,7 @@ class Graph:
self._first_layer = first_layer
return first_layer
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_interface_components_first(self, vertices_layers: list[list[str]]) -> list[list[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
@ -1849,10 +1848,10 @@ class Graph:
]
return sorted_vertices
def sort_by_avg_build_time(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_by_avg_build_time(self, vertices_layers: list[list[str]]) -> list[list[str]]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
def sort_layer_by_avg_build_time(vertices_ids: List[str]) -> List[str]:
def sort_layer_by_avg_build_time(vertices_ids: list[str]) -> list[str]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
if len(vertices_ids) == 1:
return vertices_ids
@ -1880,7 +1879,7 @@ class Graph:
"""
self.run_manager.build_run_map(predecessor_map=self.predecessor_map, vertices_to_run=self.vertices_to_run)
def find_runnable_predecessors_for_successors(self, vertex_id: str) -> List[str]:
def find_runnable_predecessors_for_successors(self, vertex_id: str) -> list[str]:
"""
For each successor of the current vertex, find runnable predecessors if any.
This checks the direct predecessors of each successor to identify any that are
@ -1892,7 +1891,7 @@ class Graph:
return runnable_vertices
def find_runnable_predecessors_for_successor(self, vertex_id: str) -> List[str]:
def find_runnable_predecessors_for_successor(self, vertex_id: str) -> list[str]:
runnable_vertices = []
visited = set()
@ -1938,8 +1937,8 @@ class Graph:
top_level_vertices.append(vertex_id)
return top_level_vertices
def build_in_degree(self, edges: List[CycleEdge]) -> Dict[str, int]:
in_degree: Dict[str, int] = defaultdict(int)
def build_in_degree(self, edges: list[CycleEdge]) -> dict[str, int]:
in_degree: dict[str, int] = defaultdict(int)
for edge in edges:
in_degree[edge.target_id] += 1
for vertex in self.vertices:
@ -1947,7 +1946,7 @@ class Graph:
in_degree[vertex.id] = 0
return in_degree
def build_adjacency_maps(self, edges: List[CycleEdge]) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
def build_adjacency_maps(self, edges: list[CycleEdge]) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
"""Returns the adjacency maps for the graph."""
predecessor_map: dict[str, list[str]] = defaultdict(list)
successor_map: dict[str, list[str]] = defaultdict(list)

View file

@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING
from collections.abc import Callable
from loguru import logger

View file

@ -1,6 +1,5 @@
import copy
from collections import defaultdict, deque
from typing import Dict, List
PRIORITY_LIST_OF_INPUTS = ["webhook", "chat"]
@ -238,7 +237,7 @@ def get_updated_edges(base_flow, g_nodes, g_edges, group_node_id):
return updated_edges
def get_successors(graph: Dict[str, Dict[str, List[str]]], vertex_id: str) -> List[str]:
def get_successors(graph: dict[str, dict[str, list[str]]], vertex_id: str) -> list[str]:
successors_result = []
stack = [vertex_id]
visited = set()
@ -252,7 +251,7 @@ def get_successors(graph: Dict[str, Dict[str, List[str]]], vertex_id: str) -> Li
return successors_result
def sort_up_to_vertex(graph: Dict[str, Dict[str, List[str]]], vertex_id: str, is_start: bool = False) -> List[str]:
def sort_up_to_vertex(graph: dict[str, dict[str, list[str]]], vertex_id: str, is_start: bool = False) -> list[str]:
"""Cuts the graph up to a given vertex and sorts the resulting subgraph."""
try:
stop_or_start_vertex = graph[vertex_id]

View file

@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, List, Optional
from typing import Any
from pydantic import BaseModel, Field, field_serializer, model_validator
@ -9,16 +9,16 @@ from langflow.utils.schemas import ChatOutputResponse, ContainsEnumMeta
class ResultData(BaseModel):
results: Optional[Any] = Field(default_factory=dict)
artifacts: Optional[Any] = Field(default_factory=dict)
outputs: Optional[dict] = Field(default_factory=dict)
logs: Optional[dict] = Field(default_factory=dict)
messages: Optional[list[ChatOutputResponse]] = Field(default_factory=list)
timedelta: Optional[float] = None
duration: Optional[str] = None
component_display_name: Optional[str] = None
component_id: Optional[str] = None
used_frozen_result: Optional[bool] = False
results: Any | None = Field(default_factory=dict)
artifacts: Any | None = Field(default_factory=dict)
outputs: dict | None = Field(default_factory=dict)
logs: dict | None = Field(default_factory=dict)
messages: list[ChatOutputResponse] | None = Field(default_factory=list)
timedelta: float | None = None
duration: str | None = None
component_display_name: str | None = None
component_id: str | None = None
used_frozen_result: bool | None = False
@field_serializer("results")
def serialize_results(self, value):
@ -82,4 +82,4 @@ OUTPUT_COMPONENTS = [
class RunOutputs(BaseModel):
inputs: dict = Field(default_factory=dict)
outputs: List[Optional[ResultData]] = Field(default_factory=list)
outputs: list[ResultData | None] = Field(default_factory=list)

View file

@ -1,4 +1,5 @@
from typing import Any, Callable, get_type_hints
from typing import Any, get_type_hints
from collections.abc import Callable
from pydantic import ConfigDict, computed_field, create_model
from pydantic.fields import FieldInfo

View file

@ -1,6 +1,7 @@
import json
from enum import Enum
from typing import TYPE_CHECKING, Any, Generator, Optional, Union
from typing import TYPE_CHECKING, Any, Optional
from collections.abc import Generator
from uuid import UUID
from langchain_core.documents import Document
@ -54,7 +55,7 @@ def fix_prompt(prompt: str):
return prompt + " {input}"
def flatten_list(list_of_lists: list[Union[list, Any]]) -> list:
def flatten_list(list_of_lists: list[list | Any]) -> list:
"""Flatten list of lists."""
new_list = []
for item in list_of_lists:
@ -135,7 +136,7 @@ def _vertex_to_primitive_dict(target: "Vertex") -> dict:
async def log_transaction(
flow_id: Union[str, UUID], source: "Vertex", status, target: Optional["Vertex"] = None, error=None
flow_id: str | UUID, source: "Vertex", status, target: Optional["Vertex"] = None, error=None
) -> None:
try:
if not get_settings_service().settings.transactions_storage_enabled:
@ -164,7 +165,7 @@ def log_vertex_build(
valid: bool,
params: Any,
data: "ResultDataResponse",
artifacts: Optional[dict] = None,
artifacts: dict | None = None,
):
try:
if not get_settings_service().settings.vertex_builds_storage_enabled:

View file

@ -6,7 +6,8 @@ import traceback
import types
import json
from enum import Enum
from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional, Set
from typing import TYPE_CHECKING, Any, Optional
from collections.abc import AsyncIterator, Callable, Iterator, Mapping
import pandas as pd
from loguru import logger
@ -46,9 +47,9 @@ class Vertex:
self,
data: NodeData,
graph: "Graph",
base_type: Optional[str] = None,
base_type: str | None = None,
is_task: bool = False,
params: Optional[Dict] = None,
params: dict | None = None,
) -> None:
# is_external means that the Vertex send or receives data from
# an external source (e.g the chat)
@ -66,36 +67,36 @@ class Vertex:
self.has_external_output = False
self.graph = graph
self._data = data.copy()
self.base_type: Optional[str] = base_type
self.outputs: List[Dict] = []
self.base_type: str | None = base_type
self.outputs: list[dict] = []
self._parse_data()
self._built_object = UnbuiltObject()
self._built_result = None
self._built = False
self._successors_ids: Optional[List[str]] = None
self.artifacts: Dict[str, Any] = {}
self.artifacts_raw: Dict[str, Any] = {}
self.artifacts_type: Dict[str, str] = {}
self.steps: List[Callable] = [self._build]
self.steps_ran: List[Callable] = []
self.task_id: Optional[str] = None
self._successors_ids: list[str] | None = None
self.artifacts: dict[str, Any] = {}
self.artifacts_raw: dict[str, Any] = {}
self.artifacts_type: dict[str, str] = {}
self.steps: list[Callable] = [self._build]
self.steps_ran: list[Callable] = []
self.task_id: str | None = None
self.is_task = is_task
self.params = params or {}
self.parent_node_id: Optional[str] = self._data.get("parent_node_id")
self.load_from_db_fields: List[str] = []
self.parent_node_id: str | None = self._data.get("parent_node_id")
self.load_from_db_fields: list[str] = []
self.parent_is_top_level = False
self.layer = None
self.result: Optional[ResultData] = None
self.results: Dict[str, Any] = {}
self.outputs_logs: Dict[str, OutputValue] = {}
self.logs: Dict[str, Log] = {}
self.result: ResultData | None = None
self.results: dict[str, Any] = {}
self.outputs_logs: dict[str, OutputValue] = {}
self.logs: dict[str, Log] = {}
try:
self.is_interface_component = self.vertex_type in InterfaceComponentTypes
except ValueError:
self.is_interface_component = False
self.use_result = False
self.build_times: List[float] = []
self.build_times: list[float] = []
self.state = VertexStates.ACTIVE
def set_input_value(self, name: str, value: Any):
@ -168,31 +169,31 @@ class Vertex:
pass
@property
def edges(self) -> List["CycleEdge"]:
def edges(self) -> list["CycleEdge"]:
return self.graph.get_vertex_edges(self.id)
@property
def outgoing_edges(self) -> List["CycleEdge"]:
def outgoing_edges(self) -> list["CycleEdge"]:
return [edge for edge in self.edges if edge.source_id == self.id]
@property
def incoming_edges(self) -> List["CycleEdge"]:
def incoming_edges(self) -> list["CycleEdge"]:
return [edge for edge in self.edges if edge.target_id == self.id]
@property
def edges_source_names(self) -> Set[str | None]:
def edges_source_names(self) -> set[str | None]:
return {edge.source_handle.name for edge in self.edges}
@property
def predecessors(self) -> List["Vertex"]:
def predecessors(self) -> list["Vertex"]:
return self.graph.get_predecessors(self)
@property
def successors(self) -> List["Vertex"]:
def successors(self) -> list["Vertex"]:
return self.graph.get_successors(self)
@property
def successors_ids(self) -> List[str]:
def successors_ids(self) -> list[str]:
return self.graph.successor_map.get(self.id, [])
def __getstate__(self):
@ -208,7 +209,7 @@ class Vertex:
self._built_object = state.get("_built_object") or UnbuiltObject()
self._built_result = state.get("_built_result") or UnbuiltResult()
def set_top_level(self, top_level_vertices: List[str]) -> None:
def set_top_level(self, top_level_vertices: list[str]) -> None:
self.parent_is_top_level = self.parent_node_id in top_level_vertices
def _parse_data(self) -> None:
@ -478,7 +479,7 @@ class Vertex:
self._built = True
def extract_messages_from_artifacts(self, artifacts: Dict[str, Any]) -> List[dict]:
def extract_messages_from_artifacts(self, artifacts: dict[str, Any]) -> list[dict]:
"""
Extracts messages from the artifacts.
@ -565,7 +566,7 @@ class Vertex:
async def _build_dict_and_update_params(
self,
key,
vertices_dict: Dict[str, "Vertex"],
vertices_dict: dict[str, "Vertex"],
):
"""
Iterates over a dictionary of vertices, builds each and updates the params dictionary.
@ -589,7 +590,7 @@ class Vertex:
"""
return all(self._is_vertex(vertex) for vertex in value)
async def get_result(self, requester: "Vertex", target_handle_name: Optional[str] = None) -> Any:
async def get_result(self, requester: "Vertex", target_handle_name: str | None = None) -> Any:
"""
Retrieves the result of the vertex.
@ -601,7 +602,7 @@ class Vertex:
async with self._lock:
return await self._get_result(requester, target_handle_name)
async def _get_result(self, requester: "Vertex", target_handle_name: Optional[str] = None) -> Any:
async def _get_result(self, requester: "Vertex", target_handle_name: str | None = None) -> Any:
"""
Retrieves the result of the built component.
@ -635,7 +636,7 @@ class Vertex:
async def _build_list_of_vertices_and_update_params(
self,
key,
vertices: List["Vertex"],
vertices: list["Vertex"],
):
"""
Iterates over a list of vertices, builds each and updates the params dictionary.
@ -737,7 +738,7 @@ class Vertex:
if self.display_name in ["Text Output"]:
raise ValueError(f"You are trying to stream to a {self.display_name}. Try using a Chat Output instead.")
def _reset(self, params_update: Optional[Dict[str, Any]] = None):
def _reset(self, params_update: dict[str, Any] | None = None):
self._built = False
self._built_object = UnbuiltObject()
self._built_result = UnbuiltResult()
@ -757,8 +758,8 @@ class Vertex:
async def build(
self,
user_id=None,
inputs: Optional[Dict[str, Any]] = None,
files: Optional[list[str]] = None,
inputs: dict[str, Any] | None = None,
files: list[str] | None = None,
requester: Optional["Vertex"] = None,
**kwargs,
) -> Any:

View file

@ -1,5 +1,3 @@
from typing import Dict
from typing_extensions import NotRequired, TypedDict
@ -10,7 +8,7 @@ class Position(TypedDict):
class NodeData(TypedDict):
id: str
data: Dict
data: dict
dragging: NotRequired[bool]
height: NotRequired[int]
width: NotRequired[int]

View file

@ -1,6 +1,7 @@
import asyncio
import json
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Generator, Iterator, List, cast
from typing import TYPE_CHECKING, Any, cast
from collections.abc import AsyncIterator, Generator, Iterator
import yaml
from langchain_core.messages import AIMessage, AIMessageChunk
@ -141,7 +142,7 @@ class ComponentVertex(Vertex):
asyncio.create_task(log_transaction(source=self, target=requester, flow_id=str(flow_id), status="success"))
return result
def extract_messages_from_artifacts(self, artifacts: Dict[str, Any]) -> List[dict]:
def extract_messages_from_artifacts(self, artifacts: dict[str, Any]) -> list[dict]:
"""
Extracts messages from the artifacts.
@ -454,7 +455,7 @@ class StateVertex(ComponentVertex):
self.is_state = False
@property
def successors_ids(self) -> List[str]:
def successors_ids(self) -> list[str]:
if self._successors_ids is None:
self.is_state = False
return super().successors_ids

View file

@ -6,7 +6,8 @@ import shutil
import tempfile
from contextlib import contextmanager, suppress
from pathlib import Path
from typing import TYPE_CHECKING, AsyncGenerator
from typing import TYPE_CHECKING
from collections.abc import AsyncGenerator
import orjson
import pytest
@ -155,7 +156,7 @@ def get_graph(_type="basic"):
elif _type == "openapi":
path = pytest.OPENAPI_EXAMPLE_PATH
with open(path, "r") as f:
with open(path) as f:
flow_graph = json.load(f)
data_graph = flow_graph["data"]
nodes = data_graph["nodes"]
@ -167,7 +168,7 @@ def get_graph(_type="basic"):
@pytest.fixture
def basic_graph_data():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
with open(pytest.BASIC_EXAMPLE_PATH) as f:
return json.load(f)
@ -188,55 +189,55 @@ def openapi_graph():
@pytest.fixture
def json_flow():
with open(pytest.BASIC_EXAMPLE_PATH, "r") as f:
with open(pytest.BASIC_EXAMPLE_PATH) as f:
return f.read()
@pytest.fixture
def grouped_chat_json_flow():
with open(pytest.GROUPED_CHAT_EXAMPLE_PATH, "r") as f:
with open(pytest.GROUPED_CHAT_EXAMPLE_PATH) as f:
return f.read()
@pytest.fixture
def one_grouped_chat_json_flow():
with open(pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH, "r") as f:
with open(pytest.ONE_GROUPED_CHAT_EXAMPLE_PATH) as f:
return f.read()
@pytest.fixture
def vector_store_grouped_json_flow():
with open(pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH, "r") as f:
with open(pytest.VECTOR_STORE_GROUPED_EXAMPLE_PATH) as f:
return f.read()
@pytest.fixture
def json_flow_with_prompt_and_history():
with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY, "r") as f:
with open(pytest.BASIC_CHAT_WITH_PROMPT_AND_HISTORY) as f:
return f.read()
@pytest.fixture
def json_simple_api_test():
with open(pytest.SIMPLE_API_TEST, "r") as f:
with open(pytest.SIMPLE_API_TEST) as f:
return f.read()
@pytest.fixture
def json_vector_store():
with open(pytest.VECTOR_STORE_PATH, "r") as f:
with open(pytest.VECTOR_STORE_PATH) as f:
return f.read()
@pytest.fixture
def json_webhook_test():
with open(pytest.WEBHOOK_TEST, "r") as f:
with open(pytest.WEBHOOK_TEST) as f:
return f.read()
@pytest.fixture
def json_memory_chatbot_no_llm():
with open(pytest.MEMORY_CHATBOT_NO_LLM, "r") as f:
with open(pytest.MEMORY_CHATBOT_NO_LLM) as f:
return f.read()
@ -345,13 +346,13 @@ def flow(client, json_flow: str, active_user):
@pytest.fixture
def json_chat_input():
with open(pytest.CHAT_INPUT, "r") as f:
with open(pytest.CHAT_INPUT) as f:
return f.read()
@pytest.fixture
def json_two_outputs():
with open(pytest.TWO_OUTPUTS, "r") as f:
with open(pytest.TWO_OUTPUTS) as f:
return f.read()

View file

@ -1,4 +1,5 @@
from typing import Sequence, Union
from typing import Union
from collections.abc import Sequence
import pytest
from pydantic import ValidationError