Makes build method async to support async in CC

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-11-23 09:36:15 -03:00
commit 842ba2835a
14 changed files with 189 additions and 184 deletions

View file

@ -8,21 +8,20 @@ from fastapi import (
status,
)
from fastapi.responses import StreamingResponse
from loguru import logger
from sqlmodel import Session
from langflow.api.utils import build_input_keys_response
from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData
from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import (
get_current_active_user,
get_current_user_by_jwt,
)
from langflow.services.cache.utils import update_build_status
from loguru import logger
from langflow.services.deps import get_chat_service, get_session, get_cache_service
from sqlmodel import Session
from langflow.services.chat.service import ChatService
from langflow.services.cache.service import BaseCacheService
from langflow.services.cache.utils import update_build_status
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_cache_service, get_chat_service, get_session
router = APIRouter(tags=["Chat"])
@ -164,9 +163,9 @@ async def stream_build(
}
yield str(StreamData(event="log", data=log_dict))
if vertex.is_task:
vertex = try_running_celery_task(vertex, user_id)
vertex = await try_running_celery_task(vertex, user_id)
else:
vertex.build(user_id=user_id)
await vertex.build(user_id=user_id)
params = vertex._built_object_repr()
valid = True
logger.debug(f"Building node {str(vertex.vertex_type)}")
@ -193,7 +192,7 @@ async def stream_build(
yield str(StreamData(event="message", data=response))
langchain_object = graph.build()
langchain_object = await graph.build()
# Now we need to check the input_keys to send them to the client
if hasattr(langchain_object, "input_keys"):
input_keys_response = build_input_keys_response(langchain_object, artifacts)
@ -224,7 +223,7 @@ async def stream_build(
raise HTTPException(status_code=500, detail=str(exc))
def try_running_celery_task(vertex, user_id):
async def try_running_celery_task(vertex, user_id):
# Try running the task in celery
# and set the task_id to the local vertex
# if it fails, run the task locally
@ -236,5 +235,5 @@ def try_running_celery_task(vertex, user_id):
except Exception as exc:
logger.debug(f"Error running task in celery: {exc}")
vertex.task_id = None
vertex.build(user_id=user_id)
await vertex.build(user_id=user_id)
return vertex

View file

@ -1,6 +1,8 @@
from typing import Dict, Generator, List, Type, Union
from langchain.chains.base import Chain
from loguru import logger
from langflow.graph.edge.base import Edge
from langflow.graph.graph.constants import lazy_load_vertex_dict
from langflow.graph.graph.utils import process_flow
@ -8,7 +10,6 @@ from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import FileToolVertex, LLMVertex, ToolkitVertex
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.utils import payload
from loguru import logger
class Graph:
@ -116,13 +117,13 @@ class Graph:
connected_nodes: List[Vertex] = [edge.source for edge in self.edges if edge.target == node]
return connected_nodes
def build(self) -> Chain:
async def build(self) -> Chain:
"""Builds the graph."""
# Get root node
root_node = payload.get_root_node(self)
if root_node is None:
raise ValueError("No root node found")
return root_node.build()
return await root_node.build()
def topological_sort(self) -> List[Vertex]:
"""

View file

@ -1,20 +1,18 @@
import ast
import inspect
import pickle
import types
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from loguru import logger
from langflow.graph.utils import UnbuiltObject
from langflow.graph.vertex.utils import is_basic_type
from langflow.interface.initialize import loading
from langflow.interface.listing import lazy_load_dict
from langflow.utils.constants import DIRECT_TYPES
from loguru import logger
from langflow.utils.util import sync_to_async
import inspect
import types
from typing import Any, Dict, List, Optional
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from langflow.graph.edge.base import Edge
@ -216,18 +214,18 @@ class Vertex:
self._raw_params = params
self.params = params
def _build(self, user_id=None):
async def _build(self, user_id=None):
"""
Initiate the build process.
"""
logger.debug(f"Building {self.vertex_type}")
self._build_each_node_in_params_dict(user_id)
self._get_and_instantiate_class(user_id)
await self._build_each_node_in_params_dict(user_id)
await self._get_and_instantiate_class(user_id)
self._validate_built_object()
self._built = True
def _build_each_node_in_params_dict(self, user_id=None):
async def _build_each_node_in_params_dict(self, user_id=None):
"""
Iterates over each node in the params dictionary and builds it.
"""
@ -236,9 +234,9 @@ class Vertex:
if value == self:
del self.params[key]
continue
self._build_node_and_update_params(key, value, user_id)
await self._build_node_and_update_params(key, value, user_id)
elif isinstance(value, list) and self._is_list_of_nodes(value):
self._build_list_of_nodes_and_update_params(key, value, user_id)
await self._build_list_of_nodes_and_update_params(key, value, user_id)
def _is_node(self, value):
"""
@ -252,7 +250,7 @@ class Vertex:
"""
return all(self._is_node(node) for node in value)
def get_result(self, user_id=None, timeout=None) -> Any:
async def get_result(self, user_id=None, timeout=None) -> Any:
# Check if the Vertex was built already
if self._built:
return self._built_object
@ -268,27 +266,27 @@ class Vertex:
pass
# If there's no task_id, build the vertex locally
self.build(user_id)
await self.build(user_id)
return self._built_object
def _build_node_and_update_params(self, key, node, user_id=None):
async def _build_node_and_update_params(self, key, node, user_id=None):
"""
Builds a given node and updates the params dictionary accordingly.
"""
result = node.get_result(user_id)
result = await node.get_result(user_id)
self._handle_func(key, result)
if isinstance(result, list):
self._extend_params_list_with_result(key, result)
self.params[key] = result
def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
async def _build_list_of_nodes_and_update_params(self, key, nodes: List["Vertex"], user_id=None):
"""
Iterates over a list of nodes, builds each and updates the params dictionary.
"""
self.params[key] = []
for node in nodes:
built = node.get_result(user_id)
built = await node.get_result(user_id)
if isinstance(built, list):
if key not in self.params:
self.params[key] = []
@ -318,14 +316,14 @@ class Vertex:
if isinstance(self.params[key], list):
self.params[key].extend(result)
def _get_and_instantiate_class(self, user_id=None):
async def _get_and_instantiate_class(self, user_id=None):
"""
Gets the class from a dictionary and instantiates it with the params.
"""
if self.base_type is None:
raise ValueError(f"Base type for node {self.vertex_type} not found")
try:
result = loading.instantiate_class(
result = await loading.instantiate_class(
node_type=self.vertex_type,
base_type=self.base_type,
params=self.params,
@ -358,9 +356,9 @@ class Vertex:
raise ValueError(message)
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
if not self._built or force:
self._build(user_id, *args, **kwargs)
await self._build(user_id, *args, **kwargs)
return self._built_object

View file

@ -1,8 +1,8 @@
import ast
from typing import Any, Dict, List, Optional, Union
from langflow.graph.vertex.base import Vertex
from langflow.graph.utils import flatten_list
from langflow.graph.vertex.base import Vertex
from langflow.interface.utils import extract_input_variables_from_prompt
@ -34,18 +34,18 @@ class AgentVertex(Vertex):
elif isinstance(source_node, ChainVertex):
self.chains.append(source_node)
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
if not self._built or force:
self._set_tools_and_chains()
# First, build the tools
for tool_node in self.tools:
tool_node.build(user_id=user_id)
await tool_node.build(user_id=user_id)
# Next, build the chains and the rest
for chain_node in self.chains:
chain_node.build(tools=self.tools, user_id=user_id)
await chain_node.build(tools=self.tools, user_id=user_id)
self._build(user_id=user_id)
await self._build(user_id=user_id)
return self._built_object
@ -62,13 +62,13 @@ class LLMVertex(Vertex):
def __init__(self, data: Dict, params: Optional[Dict] = None):
super().__init__(data, base_type="llms", params=params)
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
# LLM is different because some models might take up too much memory
# or time to load. So we only load them when we need them.ß
if self.vertex_type == self.built_node_type:
return self.class_built_object
if not self._built or force:
self._build(user_id=user_id)
await self._build(user_id=user_id)
self.built_node_type = self.vertex_type
self.class_built_object = self._built_object
# Avoid deepcopying the LLM
@ -90,11 +90,11 @@ class WrapperVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="wrappers")
def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
async def build(self, force: bool = False, user_id=None, *args, **kwargs) -> Any:
if not self._built or force:
if "headers" in self.params:
self.params["headers"] = ast.literal_eval(self.params["headers"])
self._build(user_id=user_id)
await self._build(user_id=user_id)
return self._built_object
@ -193,7 +193,7 @@ class ChainVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="chains")
def build(
async def build(
self,
force: bool = False,
user_id=None,
@ -212,9 +212,9 @@ class ChainVertex(Vertex):
if isinstance(value, PromptVertex):
# Build the PromptVertex, passing the tools if available
tools = kwargs.get("tools", None)
self.params[key] = value.build(tools=tools, force=force)
self.params[key] = await value.build(tools=tools, force=force)
self._build(user_id=user_id)
await self._build(user_id=user_id)
return self._built_object
@ -223,7 +223,7 @@ class PromptVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="prompts")
def build(
async def build(
self,
force: bool = False,
user_id=None,
@ -236,7 +236,7 @@ class PromptVertex(Vertex):
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
tools = [await tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
@ -257,7 +257,7 @@ class PromptVertex(Vertex):
elif isinstance(self.params, dict):
self.params.pop("input_variables", None)
self._build(user_id=user_id)
await self._build(user_id=user_id)
return self._built_object
def _built_object_repr(self):

View file

@ -3,7 +3,6 @@ from uuid import UUID
import yaml
from fastapi import HTTPException
from langflow.field_typing.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
from langflow.interface.custom.component import Component
from langflow.interface.custom.directory_reader import DirectoryReader
@ -189,7 +188,7 @@ class CustomComponent(Component):
def get_function(self):
return validate.create_function(self.code, self.function_entrypoint_name)
def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
from langflow.processing.process import build_sorted_vertices, process_tweaks
db_service = get_db_service()
@ -199,7 +198,7 @@ class CustomComponent(Component):
raise ValueError(f"Flow {flow_id} not found")
if tweaks:
graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
return build_sorted_vertices(graph_data, self.user_id)
return await build_sorted_vertices(graph_data, self.user_id)
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
if not self.user_id:

View file

@ -1,3 +1,4 @@
import inspect
import json
from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Type
@ -10,9 +11,6 @@ from langchain.chains.base import Chain
from langchain.document_loaders.base import BaseLoader
from langchain.schema import Document
from langchain.vectorstores.base import VectorStore
from loguru import logger
from pydantic import ValidationError
from langflow.interface.custom_lists import CUSTOM_NODES
from langflow.interface.importing.utils import eval_custom_component_code, get_function, import_by_type
from langflow.interface.initialize.llm import initialize_vertexai
@ -24,6 +22,8 @@ from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.utils import load_file_into_dict
from langflow.interface.wrappers.base import wrapper_creator
from langflow.utils import validate
from loguru import logger
from pydantic import ValidationError
if TYPE_CHECKING:
from langflow import CustomComponent
@ -36,7 +36,7 @@ def build_vertex_in_params(params: Dict) -> Dict:
return {key: value.build() if isinstance(value, Vertex) else value for key, value in params.items()}
def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None) -> Any:
async def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None) -> Any:
"""Instantiate class from module type and key, and params"""
params = convert_params_to_sets(params)
params = convert_kwargs(params)
@ -48,7 +48,7 @@ def instantiate_class(node_type: str, base_type: str, params: Dict, user_id=None
return custom_node(**params)
logger.debug(f"Instantiating {node_type} of type {base_type}")
class_object = import_by_type(_type=base_type, name=node_type)
return instantiate_based_on_type(class_object, base_type, node_type, params, user_id=user_id)
return await instantiate_based_on_type(class_object, base_type, node_type, params, user_id=user_id)
def convert_params_to_sets(params):
@ -75,7 +75,7 @@ def convert_kwargs(params):
return params
def instantiate_based_on_type(class_object, base_type, node_type, params, user_id):
async def instantiate_based_on_type(class_object, base_type, node_type, params, user_id):
if base_type == "agents":
return instantiate_agent(node_type, class_object, params)
elif base_type == "prompts":
@ -109,20 +109,28 @@ def instantiate_based_on_type(class_object, base_type, node_type, params, user_i
elif base_type == "memory":
return instantiate_memory(node_type, class_object, params)
elif base_type == "custom_components":
return instantiate_custom_component(node_type, class_object, params, user_id)
return await instantiate_custom_component(node_type, class_object, params, user_id)
elif base_type == "wrappers":
return instantiate_wrapper(node_type, class_object, params)
else:
return class_object(**params)
def instantiate_custom_component(node_type, class_object, params, user_id):
# we need to make a copy of the params because we will be
# modifying it
async def instantiate_custom_component(node_type, class_object, params, user_id):
params_copy = params.copy()
class_object: "CustomComponent" = eval_custom_component_code(params_copy.pop("code"))
custom_component = class_object(user_id=user_id)
built_object = custom_component.build(**params_copy)
# Determine if the build method is asynchronous
is_async = inspect.iscoroutinefunction(custom_component.build)
if is_async:
# Await the build method directly if it's async
built_object = await custom_component.build(**params_copy)
else:
# Call the build method directly if it's sync
built_object = custom_component.build(**params_copy)
return built_object, {"repr": custom_component.custom_repr()}

View file

@ -1,10 +1,12 @@
from typing import Dict, Tuple, Optional, Union
from langflow.graph import Graph
from loguru import logger
from typing import Dict, Optional, Tuple, Union
from uuid import UUID
from loguru import logger
def build_sorted_vertices(data_graph, user_id: Optional[Union[str, UUID]] = None) -> Tuple[Graph, Dict]:
from langflow.graph import Graph
async def build_sorted_vertices(data_graph, user_id: Optional[Union[str, UUID]] = None) -> Tuple[Graph, Dict]:
"""
Build langchain object from data_graph.
"""
@ -14,28 +16,12 @@ def build_sorted_vertices(data_graph, user_id: Optional[Union[str, UUID]] = None
sorted_vertices = graph.topological_sort()
artifacts = {}
for vertex in sorted_vertices:
vertex.build(user_id=user_id)
await vertex.build(user_id=user_id)
if vertex.artifacts:
artifacts.update(vertex.artifacts)
return graph, artifacts
def build_langchain_object(data_graph):
"""
Build langchain object from data_graph.
"""
logger.debug("Building langchain object")
nodes = data_graph["nodes"]
# Add input variables
# nodes = payload.extract_input_variables(nodes)
# Nodes, edges and root node
edges = data_graph["edges"]
graph = Graph(nodes, edges)
return graph.build()
def get_memory_key(langchain_object):
"""
Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.

View file

@ -1,19 +1,15 @@
import asyncio
import json
from pathlib import Path
from langchain.schema import AgentAction
from langflow.interface.run import (
build_sorted_vertices,
get_memory_key,
update_memory_keys,
)
from typing import Any, Dict, List, Optional, Tuple, Union
from langchain.chains.base import Chain
from langchain.schema import AgentAction, Document
from langchain.vectorstores.base import VectorStore
from langflow.graph import Graph
from langflow.interface.run import build_sorted_vertices, get_memory_key, update_memory_keys
from langflow.services.deps import get_session_service
from loguru import logger
from langflow.graph import Graph
from langchain.chains.base import Chain
from langchain.vectorstores.base import VectorStore
from typing import Any, Dict, List, Optional, Tuple, Union
from langchain.schema import Document
from pydantic import BaseModel
@ -164,8 +160,8 @@ async def process_graph_cached(
if session_id is None:
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
# Load the graph using SessionService
graph, artifacts = session_service.load_session(session_id, data_graph)
built_object = graph.build()
graph, artifacts = await session_service.load_session(session_id, data_graph)
built_object = await graph.build()
processed_inputs = process_inputs(inputs, artifacts)
result = generate_result(built_object, processed_inputs)
# langchain_object is now updated with the new memory
@ -202,7 +198,7 @@ def load_flow_from_json(flow: Union[Path, str, dict], tweaks: Optional[dict] = N
graph = Graph(nodes, edges)
if build:
langchain_object = graph.build()
langchain_object = asyncio.run(graph.build())
if hasattr(langchain_object, "verbose"):
langchain_object.verbose = True

View file

@ -1,4 +1,5 @@
from typing import TYPE_CHECKING
from langflow.interface.run import build_sorted_vertices
from langflow.services.base import Service
from langflow.services.cache.utils import compute_dict_hash
@ -14,7 +15,7 @@ class SessionService(Service):
def __init__(self, cache_service):
self.cache_service: "BaseCacheService" = cache_service
def load_session(self, key, data_graph):
async def load_session(self, key, data_graph):
# Check if the data is cached
if key in self.cache_service:
return self.cache_service.get(key)
@ -23,7 +24,7 @@ class SessionService(Service):
key = self.generate_key(session_id=None, data_graph=data_graph)
# If not cached, build the graph and cache it
graph, artifacts = build_sorted_vertices(data_graph)
graph, artifacts = await build_sorted_vertices(data_graph)
self.cache_service.set(key, (graph, artifacts))

View file

@ -1,13 +1,9 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from asgiref.sync import async_to_sync
from celery.exceptions import SoftTimeLimitExceeded # type: ignore
from langflow.core.celery_app import celery_app
from langflow.processing.process import (
Result,
generate_result,
process_inputs,
)
from langflow.processing.process import Result, generate_result, process_inputs
from langflow.services.deps import get_session_service
from langflow.services.manager import initialize_session_service
@ -27,7 +23,7 @@ def build_vertex(self, vertex: "Vertex") -> "Vertex":
"""
try:
vertex.task_id = self.request.id
vertex.build()
async_to_sync(vertex.build)()
return vertex
except SoftTimeLimitExceeded as e:
raise self.retry(exc=SoftTimeLimitExceeded("Task took too long"), countdown=2) from e
@ -47,7 +43,7 @@ def process_graph_cached_task(
if session_id is None:
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
# Load the graph using SessionService
graph, artifacts = session_service.load_session(session_id, data_graph)
graph, artifacts = async_to_sync(session_service.load_session)(session_id, data_graph)
built_object = graph.build()
processed_inputs = process_inputs(inputs, artifacts)
result = generate_result(built_object, processed_inputs)