Refactor code and update dependencies

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-08 08:46:17 -03:00
commit 103dda198a
9 changed files with 135 additions and 82 deletions

View file

@ -54,10 +54,10 @@ class APIRequest(CustomComponent):
raise ValueError(f"Unsupported method: {method}")
data = body if body else None
data = json.dumps(data)
payload = json.dumps(data)
try:
response = await client.request(
method, url, headers=headers, content=data, timeout=timeout
method, url, headers=headers, content=payload, timeout=timeout
)
try:
result = response.json()
@ -93,14 +93,13 @@ class APIRequest(CustomComponent):
async def build(
self,
method: str,
url: List[str],
urls: List[str],
headers: Optional[dict] = None,
body: Optional[List[Record]] = None,
timeout: int = 5,
) -> List[Record]:
if headers is None:
headers = {}
urls = url if isinstance(url, list) else [url]
bodies = []
if body:
if isinstance(body, list):

View file

@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Optional
from langflow import CustomComponent
from langflow.schema import Record
@ -62,7 +62,10 @@ class TextToRecordComponent(CustomComponent):
build_config[field.name] = field.to_dict()
def update_build_config(
self, build_config: dict, field_name: str, field_value: Any
self,
build_config: dict,
field_value: Any,
field_name: Optional[str] = None,
):
if field_name == "mode":
build_config["mode"]["value"] = field_value

View file

@ -1,5 +1,5 @@
import uuid
from typing import Any, Text
from typing import Any, Optional
from langflow import CustomComponent
@ -10,7 +10,10 @@ class UUIDGeneratorComponent(CustomComponent):
description = "Generates a unique ID."
def update_build_config(
self, build_config: dict, field_name: Text, field_value: Any
self,
build_config: dict,
field_value: Any,
field_name: Optional[str] = None,
):
if field_name == "unique_id":
build_config[field_name]["value"] = str(uuid.uuid4())

View file

@ -3,7 +3,6 @@ from collections import defaultdict, deque
from itertools import chain
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
from langchain.chains.base import Chain
from loguru import logger
from langflow.graph.edge.base import ContractEdge
@ -22,7 +21,6 @@ from langflow.graph.vertex.types import (
)
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.schema import Record
from langflow.utils import payload
if TYPE_CHECKING:
from langflow.graph.schema import ResultData
@ -62,7 +60,7 @@ class Graph:
self.activated_vertices: List[str] = []
self.vertices_layers: List[List[str]] = []
self.vertices_to_run: set[str] = set()
self.stop_vertex = None
self.stop_vertex: Optional[str] = None
self.inactive_vertices: set = set()
self.edges: List[ContractEdge] = []
@ -196,7 +194,7 @@ class Graph:
async def run(
self,
inputs: list[Dict[str, Union[str, list[str]]]],
outputs: list[str],
outputs: Optional[list[str]] = None,
session_id: Optional[str] = None,
stream: bool = False,
) -> List[List[Optional["ResultData"]]]:
@ -210,11 +208,26 @@ class Graph:
if not isinstance(inputs, list):
inputs = [inputs]
for input_dict in inputs:
components: list[str] = input_dict.get("components", [])
components: Union[str, list[str]] = input_dict.get("components", [])
if not isinstance(components, list):
components = [components]
if INPUT_FIELD_NAME not in input_dict:
input_value = ""
else:
_input_value = input_dict[INPUT_FIELD_NAME]
if isinstance(_input_value, str):
input_value = _input_value
else:
raise ValueError(
f"Invalid input value: {input_value}. Expected string"
)
run_outputs = await self._run(
inputs={INPUT_FIELD_NAME: input_dict.get(INPUT_FIELD_NAME, "")},
inputs={INPUT_FIELD_NAME: input_value},
input_components=components,
outputs=outputs,
outputs=outputs or [],
stream=stream,
session_id=session_id or "",
)
@ -265,7 +278,9 @@ class Graph:
def build_parent_child_map(self):
parent_child_map = defaultdict(list)
for vertex in self.vertices:
parent_child_map[vertex.id] = [child.id for child in self.get_successors(vertex)]
parent_child_map[vertex.id] = [
child.id for child in self.get_successors(vertex)
]
return parent_child_map
def increment_run_count(self):
@ -296,7 +311,7 @@ class Graph:
return predecessor_map, successor_map
@classmethod
def from_payload(cls, payload: Dict, flow_id: str) -> "Graph":
def from_payload(cls, payload: Dict, flow_id: Optional[str] = None) -> "Graph":
"""
Creates a graph from a payload.
@ -479,7 +494,11 @@ class Graph:
return
self.vertices.remove(vertex)
self.vertex_map.pop(vertex_id)
self.edges = [edge for edge in self.edges if edge.source_id != vertex_id and edge.target_id != vertex_id]
self.edges = [
edge
for edge in self.edges
if edge.source_id != vertex_id and edge.target_id != vertex_id
]
def _build_vertex_params(self) -> None:
"""Identifies and handles the LLM vertex within the graph."""
@ -500,7 +519,9 @@ class Graph:
return
for vertex in self.vertices:
if not self._validate_vertex(vertex):
raise ValueError(f"{vertex.display_name} is not connected to any other components")
raise ValueError(
f"{vertex.display_name} is not connected to any other components"
)
def _validate_vertex(self, vertex: Vertex) -> bool:
"""Validates a vertex."""
@ -541,18 +562,10 @@ class Graph:
vertices.append(vertex)
return vertices
async def build(self) -> Chain:
"""Builds the graph."""
# Get root vertex
root_vertex = payload.get_root_vertex(self)
if root_vertex is None:
raise ValueError("No root vertex found")
return await root_vertex.build()
async def process(self) -> "Graph":
"""Processes the graph with vertices in each layer run in parallel."""
vertices_layers = self.sorted_vertices_layers
vertex_task_run_count = {}
vertex_task_run_count: Dict[str, int] = {}
for layer_index, layer in enumerate(vertices_layers):
tasks = []
for vertex_id in layer:
@ -606,7 +619,9 @@ class Graph:
def dfs(vertex):
if state[vertex] == 1:
# We have a cycle
raise ValueError("Graph contains a cycle, cannot perform topological sort")
raise ValueError(
"Graph contains a cycle, cannot perform topological sort"
)
if state[vertex] == 0:
state[vertex] = 1
for edge in vertex.edges:
@ -630,7 +645,10 @@ class Graph:
def get_predecessors(self, vertex):
"""Returns the predecessors of a vertex."""
return [self.get_vertex(source_id) for source_id in self.predecessor_map.get(vertex.id, [])]
return [
self.get_vertex(source_id)
for source_id in self.predecessor_map.get(vertex.id, [])
]
def get_all_successors(self, vertex, recursive=True, flat=True):
# Recursively get the successors of the current vertex
@ -671,7 +689,10 @@ class Graph:
def get_successors(self, vertex):
"""Returns the successors of a vertex."""
return [self.get_vertex(target_id) for target_id in self.successor_map.get(vertex.id, [])]
return [
self.get_vertex(target_id)
for target_id in self.successor_map.get(vertex.id, [])
]
def get_vertex_neighbors(self, vertex: Vertex) -> Dict[Vertex, int]:
"""Returns the neighbors of a vertex."""
@ -717,7 +738,9 @@ class Graph:
edges_added.add((source.id, target.id))
return edges
def _get_vertex_class(self, node_type: str, node_base_type: str, node_id: str) -> Type[Vertex]:
def _get_vertex_class(
self, node_type: str, node_base_type: str, node_id: str
) -> Type[Vertex]:
"""Returns the node class based on the node type."""
# First we check for the node_base_type
node_name = node_id.split("-")[0]
@ -750,14 +773,18 @@ class Graph:
vertex_type: str = vertex_data["type"] # type: ignore
vertex_base_type: str = vertex_data["node"]["template"]["_type"] # type: ignore
VertexClass = self._get_vertex_class(vertex_type, vertex_base_type, vertex_data["id"])
VertexClass = self._get_vertex_class(
vertex_type, vertex_base_type, vertex_data["id"]
)
vertex_instance = VertexClass(vertex, graph=self)
vertex_instance.set_top_level(self.top_level_vertices)
vertices.append(vertex_instance)
return vertices
def get_children_by_vertex_type(self, vertex: Vertex, vertex_type: str) -> List[Vertex]:
def get_children_by_vertex_type(
self, vertex: Vertex, vertex_type: str
) -> List[Vertex]:
"""Returns the children of a vertex based on the vertex type."""
children = []
vertex_types = [vertex.data["type"]]
@ -769,7 +796,9 @@ class Graph:
def __repr__(self):
vertex_ids = [vertex.id for vertex in self.vertices]
edges_repr = "\n".join([f"{edge.source_id} --> {edge.target_id}" for edge in self.edges])
edges_repr = "\n".join(
[f"{edge.source_id} --> {edge.target_id}" for edge in self.edges]
)
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str, is_start: bool = False) -> List[Vertex]:
@ -912,7 +941,9 @@ class Graph:
return refined_layers
def sort_chat_inputs_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_chat_inputs_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
chat_inputs_first = []
for layer in vertices_layers:
for vertex_id in layer:
@ -934,7 +965,7 @@ class Graph:
) -> List[str]:
"""Sorts the vertices in the graph."""
self.mark_all_vertices("ACTIVE")
if stop_component_id:
if stop_component_id is not None:
self.stop_vertex = stop_component_id
vertices = self.sort_up_to_vertex(stop_component_id)
elif start_component_id:
@ -966,11 +997,15 @@ class Graph:
self.vertices_to_run.remove(vertex_id)
return should_run
def sort_interface_components_first(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_interface_components_first(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices containing ChatInput or ChatOutput come first."""
def contains_interface_component(vertex):
return any(component.value in vertex for component in InterfaceComponentTypes)
return any(
component.value in vertex for component in InterfaceComponentTypes
)
# Sort each inner list so that vertices containing ChatInput or ChatOutput come first
sorted_vertices = [
@ -982,16 +1017,22 @@ class Graph:
]
return sorted_vertices
def sort_by_avg_build_time(self, vertices_layers: List[List[str]]) -> List[List[str]]:
def sort_by_avg_build_time(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
def sort_layer_by_avg_build_time(vertices_ids: List[str]) -> List[str]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
if len(vertices_ids) == 1:
return vertices_ids
vertices_ids.sort(key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time)
vertices_ids.sort(
key=lambda vertex_id: self.get_vertex(vertex_id).avg_build_time
)
return vertices_ids
sorted_vertices = [sort_layer_by_avg_build_time(layer) for layer in vertices_layers]
sorted_vertices = [
sort_layer_by_avg_build_time(layer) for layer in vertices_layers
]
return sorted_vertices

View file

@ -148,8 +148,8 @@ class CustomComponent(Component):
def update_build_config(
self,
build_config: dotdict,
field_name: Optional[str],
field_value: Any,
field_name: Optional[str] = None,
):
build_config[field_name] = field_value
return build_config
@ -390,7 +390,7 @@ class CustomComponent(Component):
raise ValueError(f"Flow {flow_name} not found")
graph = await self.load_flow(flow_id, tweaks)
input_value_dict = {"input_value": input_value}
input_value_dict = [{"input_value": input_value}]
return await graph.run(input_value_dict, stream=False)
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Record]:

View file

@ -43,7 +43,9 @@ def add_output_types(
raise HTTPException(
status_code=400,
detail={
"error": ("Invalid return type. Please check your code and try again."),
"error": (
"Invalid return type. Please check your code and try again."
),
"traceback": traceback.format_exc(),
},
)
@ -75,14 +77,18 @@ def reorder_fields(frontend_node: CustomComponentFrontendNode, field_order: List
frontend_node.field_order = field_order
def add_base_classes(frontend_node: CustomComponentFrontendNode, return_types: List[str]):
def add_base_classes(
frontend_node: CustomComponentFrontendNode, return_types: List[str]
):
"""Add base classes to the frontend node"""
for return_type_instance in return_types:
if return_type_instance is None:
raise HTTPException(
status_code=400,
detail={
"error": ("Invalid return type. Please check your code and try again."),
"error": (
"Invalid return type. Please check your code and try again."
),
"traceback": traceback.format_exc(),
},
)
@ -170,7 +176,9 @@ def add_new_custom_field(
)
if "name" in field_config:
warnings.warn("The 'name' key in field_config is used to build the object and can't be changed.")
warnings.warn(
"The 'name' key in field_config is used to build the object and can't be changed."
)
required = field_config.pop("required", field_required)
placeholder = field_config.pop("placeholder", "")
@ -269,7 +277,9 @@ def run_build_config(
raise HTTPException(
status_code=400,
detail={
"error": ("Invalid type convertion. Please check your code and try again."),
"error": (
"Invalid type convertion. Please check your code and try again."
),
"traceback": traceback.format_exc(),
},
) from exc
@ -387,10 +397,16 @@ def build_custom_component_template(
add_extra_fields(frontend_node, field_config, entrypoint_args)
frontend_node = add_code_field(frontend_node, custom_component.code, field_config.get("code", {}))
frontend_node = add_code_field(
frontend_node, custom_component.code, field_config.get("code", {})
)
add_base_classes(frontend_node, custom_component.get_function_entrypoint_return_type)
add_output_types(frontend_node, custom_component.get_function_entrypoint_return_type)
add_base_classes(
frontend_node, custom_component.get_function_entrypoint_return_type
)
add_output_types(
frontend_node, custom_component.get_function_entrypoint_return_type
)
reorder_fields(frontend_node, custom_instance._get_field_order())
@ -439,7 +455,9 @@ def build_custom_components(components_paths: List[str]):
custom_component_dict = build_custom_component_list_from_path(path_str)
if custom_component_dict:
category = next(iter(custom_component_dict))
logger.info(f"Loading {len(custom_component_dict[category])} component(s) from category {category}")
logger.info(
f"Loading {len(custom_component_dict[category])} component(s) from category {category}"
)
custom_components_from_file = merge_nested_dicts_with_renaming(
custom_components_from_file, custom_component_dict
)
@ -467,7 +485,9 @@ def update_field_dict(
try:
dd_build_config = dotdict(build_config)
custom_component_instance.update_build_config(
dd_build_config, update_field, update_field_value
build_config=dd_build_config,
field_value=update_field,
field_name=update_field_value,
)
build_config = dd_build_config
except Exception as exc:

View file

@ -1,13 +1,14 @@
import asyncio
import json
from pathlib import Path
from typing import Optional, Union
from langflow.graph import Graph
from langflow.processing.process import fix_memory_inputs, process_tweaks
from langflow.processing.process import process_tweaks
def load_flow_from_json(flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True):
def load_flow_from_json(
flow: Union[Path, str, dict], tweaks: Optional[dict] = None
) -> Graph:
"""
Load flow from a JSON file or a JSON object.
@ -24,29 +25,13 @@ def load_flow_from_json(flow: Union[Path, str, dict], tweaks: Optional[dict] = N
elif isinstance(flow, dict):
flow_graph = flow
else:
raise TypeError("Input must be either a file path (str) or a JSON object (dict)")
raise TypeError(
"Input must be either a file path (str) or a JSON object (dict)"
)
graph_data = flow_graph["data"]
if tweaks is not None:
graph_data = process_tweaks(graph_data, tweaks)
nodes = graph_data["nodes"]
edges = graph_data["edges"]
graph = Graph(nodes, edges)
if build:
langchain_object = asyncio.run(graph.build())
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True
if hasattr(langchain_object, "return_intermediate_steps"):
# Deactivating until we have a frontend solution
# to display intermediate steps
langchain_object.return_intermediate_steps = False
fix_memory_inputs(langchain_object)
return langchain_object
return graph
graph = Graph.from_payload(graph_data)
return graph

View file

@ -224,7 +224,7 @@ async def run_graph(
if inputs is None:
inputs = [{}]
outputs = await graph.run(
run_outputs = await graph.run(
inputs,
outputs or [],
stream=stream,
@ -232,7 +232,7 @@ async def run_graph(
)
if session_id and session_service:
session_service.update_session(session_id, (graph, artifacts))
return outputs, session_id
return run_outputs, session_id
def validate_input(

View file

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Union
from typing import TYPE_CHECKING, Any, Callable, Coroutine
from loguru import logger
@ -74,11 +74,13 @@ class TaskService(Service):
result = await result
return task.id, result
async def launch_task(self, task_func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
async def launch_task(
self, task_func: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
logger.debug(f"Launching task {task_func} with args {args} and kwargs {kwargs}")
logger.debug(f"Using backend {self.backend}")
task = self.backend.launch_task(task_func, *args, **kwargs)
return await task if isinstance(task, Coroutine) else task
def get_task(self, task_id: Union[int, str]) -> Any:
def get_task(self, task_id: str) -> Any:
return self.backend.get_task(task_id)