Merge branch 'zustand/io/migration' of github.com:logspace-ai/langflow into zustand/io/migration

This commit is contained in:
igorrCarvalho 2024-02-27 14:09:48 -03:00
commit 952561d628
29 changed files with 317 additions and 103 deletions

View file

@ -3,11 +3,15 @@ from typing import Annotated, Any, List, Optional, Union
import sqlalchemy as sa
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
from loguru import logger
from sqlmodel import select
from langflow.api.utils import update_frontend_node_with_template_values
from langflow.api.v1.schemas import (
CustomComponentCode,
PreloadResponse,
ProcessResponse,
RunResponse,
TaskResponse,
TaskStatusResponse,
UploadFileResponse,
@ -15,15 +19,23 @@ from langflow.api.v1.schemas import (
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.custom.directory_reader import DirectoryReader
from langflow.interface.custom.utils import build_custom_component_template
from langflow.processing.process import build_graph_and_generate_result, process_graph_cached, process_tweaks
from langflow.processing.process import (
build_graph_and_generate_result,
process_graph_cached,
process_tweaks,
run_graph,
)
from langflow.services.auth.utils import api_key_security, get_current_active_user
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
from langflow.services.database.models.user.model import User
from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service
from langflow.services.deps import (
get_session,
get_session_service,
get_settings_service,
get_task_service,
)
from langflow.services.session.service import SessionService
from loguru import logger
from sqlmodel import select
try:
from langflow.worker import process_graph_cached_task
@ -33,9 +45,10 @@ except ImportError:
raise NotImplementedError("Celery is not installed")
from langflow.services.task.service import TaskService
from sqlmodel import Session
from langflow.services.task.service import TaskService
# build router
router = APIRouter(tags=["Base"])
@ -80,9 +93,15 @@ async def process_graph_data(
)
if session_id is None:
# Generate a session ID
session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data)
session_id = get_session_service().generate_key(
session_id=session_id, data_graph=graph_data
)
task_id, task = await task_service.launch_task(
process_graph_cached_task if task_service.use_celery else process_graph_cached,
(
process_graph_cached_task
if task_service.use_celery
else process_graph_cached
),
graph_data,
inputs,
clear_cache,
@ -176,7 +195,11 @@ async def preload_flow(
else:
if session_id is None:
session_id = flow_id
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
flow = session.exec(
select(Flow)
.where(Flow.id == flow_id)
.where(Flow.user_id == api_key_user.id)
).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
@ -197,6 +220,76 @@ async def preload_flow(
raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post("/run/{flow_id}", response_model=ProcessResponse)
async def run_flow_with_caching(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
inputs: Optional[Union[List[dict], dict]] = None,
tweaks: Optional[dict] = None,
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
api_key_user: User = Depends(api_key_security),
session_service: SessionService = Depends(get_session_service),
):
try:
if session_id:
session_data = await session_service.load_session(session_id)
graph, artifacts = session_data if session_data else (None, None)
task_result: Any = None
task_status = None
if not graph:
raise ValueError("Graph not found in the session")
task_result = await run_graph(
graph,
session_id,
inputs,
artifacts=artifacts,
session_service=session_service,
)
else:
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
flow = session.exec(
select(Flow)
.where(Flow.id == flow_id)
.where(Flow.user_id == api_key_user.id)
).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
if flow.data is None:
raise ValueError(f"Flow {flow_id} has no data")
graph_data = flow.data
graph_data = process_tweaks(graph_data, tweaks)
task_result = await run_graph(
graph_data,
inputs,
tweaks,
session_id,
session_service=session_service,
)
return RunResponse(
outputs=task_result, session_id=session_id, status=task_status
)
except sa.exc.StatementError as exc:
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
# This means the Flow ID is not a valid UUID which means it can't find the flow
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
) from exc
@router.post(
"/predict/{flow_id}",
response_model=ProcessResponse,
@ -269,7 +362,11 @@ async def process(
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
flow = session.exec(
select(Flow)
.where(Flow.id == flow_id)
.where(Flow.user_id == api_key_user.id)
).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
@ -289,12 +386,18 @@ async def process(
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
# This means the Flow ID is not a valid UUID which means it can't find the flow
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
else:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
) from exc
except Exception as e:
# Log stack trace
logger.exception(e)
@ -364,12 +467,16 @@ async def custom_component(
built_frontend_node = build_custom_component_template(component, user_id=user.id)
built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code.frontend_node)
built_frontend_node = update_frontend_node_with_template_values(
built_frontend_node, raw_code.frontend_node
)
return built_frontend_node
@router.post("/custom_component/reload", status_code=HTTPStatus.OK)
async def reload_custom_component(path: str, user: User = Depends(get_current_active_user)):
async def reload_custom_component(
path: str, user: User = Depends(get_current_active_user)
):
from langflow.interface.custom.utils import build_custom_component_template
try:
@ -391,6 +498,8 @@ async def custom_component_update(
):
component = CustomComponent(code=raw_code.code)
component_node = build_custom_component_template(component, user_id=user.id, update_field=raw_code.field)
component_node = build_custom_component_template(
component, user_id=user.id, update_field=raw_code.field
)
# Update the field
return component_node

View file

@ -66,6 +66,14 @@ class ProcessResponse(BaseModel):
backend: Optional[str] = None
class RunResponse(BaseModel):
    """Run response schema."""

    # Results produced by the graph run (one entry per output vertex);
    # None when the run yielded nothing.
    outputs: Optional[List[Any]] = None
    # Status label set by the endpoint — NOTE(review): the /run handler
    # currently always passes None here; confirm intended semantics.
    status: Optional[str] = None
    # Session key under which the built graph is cached for reuse.
    session_id: Optional[str] = None
class PreloadResponse(BaseModel):
"""Preload response schema."""
@ -73,9 +81,6 @@ class PreloadResponse(BaseModel):
is_clear: Optional[bool] = None
# TaskStatusResponse(
# status=task.status, result=task.result if task.ready() else None
# )
class TaskStatusResponse(BaseModel):
"""Task status response schema."""

View file

@ -23,7 +23,7 @@ class ConversationChainComponent(CustomComponent):
def build(
self,
inputs: str,
input_value: str,
llm: BaseLanguageModel,
memory: Optional[BaseMemory] = None,
) -> Text:

View file

@ -18,7 +18,7 @@ class LLMCheckerChainComponent(CustomComponent):
def build(
self,
inputs: str,
input_value: str,
llm: BaseLanguageModel,
) -> Text:

View file

@ -24,7 +24,7 @@ class LLMMathChainComponent(CustomComponent):
def build(
self,
inputs: Text,
input_value: Text,
llm: BaseLanguageModel,
llm_chain: LLMChain,
input_key: str = "question",

View file

@ -27,7 +27,7 @@ class RetrievalQAComponent(CustomComponent):
self,
combine_documents_chain: BaseCombineDocumentsChain,
retriever: BaseRetriever,
inputs: str = "",
input_value: str = "",
memory: Optional[BaseMemory] = None,
input_key: str = "query",
output_key: str = "result",

View file

@ -26,7 +26,7 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent):
def build(
self,
inputs: str,
input_value: str,
retriever: BaseRetriever,
llm: BaseLanguageModel,
chain_type: str,

View file

@ -28,7 +28,7 @@ class SQLGeneratorComponent(CustomComponent):
def build(
self,
inputs: Text,
input_value: Text,
db: SQLDatabase,
llm: BaseLanguageModel,
top_k: int = 5,

View file

@ -11,7 +11,7 @@ class ChatInput(CustomComponent):
def build_config(self):
return {
"message": {
"input_value": {
"input_types": ["Text"],
"display_name": "Message",
"multiline": True,
@ -35,26 +35,26 @@ class ChatInput(CustomComponent):
self,
sender: Optional[str] = "User",
sender_name: Optional[str] = "User",
message: Optional[str] = None,
input_value: Optional[str] = None,
session_id: Optional[str] = None,
return_record: Optional[bool] = False,
) -> Union[Text, Record]:
if return_record:
if isinstance(message, Record):
if isinstance(input_value, Record):
# Update the data of the record
message.data["sender"] = sender
message.data["sender_name"] = sender_name
message.data["session_id"] = session_id
input_value.data["sender"] = sender
input_value.data["sender_name"] = sender_name
input_value.data["session_id"] = session_id
else:
message = Record(
text=message,
input_value = Record(
text=input_value,
data={
"sender": sender,
"sender_name": sender_name,
"session_id": session_id,
},
)
if not message:
message = ""
self.status = message
return message
if not input_value:
input_value = ""
self.status = input_value
return input_value

View file

@ -17,7 +17,7 @@ class ChatOutput(CustomComponent):
def build_config(self):
return {
"message": {"input_types": ["Text"], "display_name": "Message"},
"input_value": {"input_types": ["Text"], "display_name": "Message"},
"sender": {
"options": ["Machine", "User"],
"display_name": "Sender Type",
@ -39,25 +39,25 @@ class ChatOutput(CustomComponent):
sender: Optional[str] = "Machine",
sender_name: Optional[str] = "AI",
session_id: Optional[str] = None,
message: Optional[str] = None,
input_value: Optional[str] = None,
return_record: Optional[bool] = False,
) -> Union[Text, Record]:
if return_record:
if isinstance(message, Record):
if isinstance(input_value, Record):
# Update the data of the record
message.data["sender"] = sender
message.data["sender_name"] = sender_name
message.data["session_id"] = session_id
input_value.data["sender"] = sender
input_value.data["sender_name"] = sender_name
input_value.data["session_id"] = session_id
else:
message = Record(
text=message,
input_value = Record(
text=input_value,
data={
"sender": sender,
"sender_name": sender_name,
"session_id": session_id,
},
)
if not message:
message = ""
self.status = message
return message
if not input_value:
input_value = ""
self.status = input_value
return input_value

View file

@ -39,7 +39,7 @@ class AmazonBedrockComponent(CustomComponent):
def build(
self,
inputs: str,
input_value: str,
model_id: str = "anthropic.claude-instant-v1",
credentials_profile_name: Optional[str] = None,
region_name: Optional[str] = None,

View file

@ -9,7 +9,9 @@ from langflow.field_typing import Text
class AnthropicLLM(CustomComponent):
display_name: str = "AnthropicModel"
description: str = "Generate text using Anthropic Chat&Completion large language models."
description: str = (
"Generate text using Anthropic Chat&Completion large language models."
)
def build_config(self):
return {
@ -53,7 +55,7 @@ class AnthropicLLM(CustomComponent):
def build(
self,
model: str,
inputs: str,
input_value: str,
anthropic_api_key: Optional[str] = None,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
@ -66,7 +68,9 @@ class AnthropicLLM(CustomComponent):
try:
output = ChatAnthropic(
model_name=model,
anthropic_api_key=(SecretStr(anthropic_api_key) if anthropic_api_key else None),
anthropic_api_key=(
SecretStr(anthropic_api_key) if anthropic_api_key else None
),
max_tokens_to_sample=max_tokens, # type: ignore
temperature=temperature,
anthropic_api_url=api_endpoint,

View file

@ -9,7 +9,9 @@ from langflow import CustomComponent
class AzureChatOpenAIComponent(CustomComponent):
display_name: str = "AzureOpenAI Model"
description: str = "Generate text using LLM model from Azure OpenAI."
documentation: str = "https://python.langchain.com/docs/integrations/llms/azure_openai"
documentation: str = (
"https://python.langchain.com/docs/integrations/llms/azure_openai"
)
beta = False
AZURE_OPENAI_MODELS = [
@ -78,7 +80,7 @@ class AzureChatOpenAIComponent(CustomComponent):
self,
model: str,
azure_endpoint: str,
inputs: str,
input_value: str,
azure_deployment: str,
api_key: str,
api_version: str,

View file

@ -73,7 +73,7 @@ class QianfanChatEndpointComponent(CustomComponent):
def build(
self,
inputs: str,
input_value: str,
model: str = "ERNIE-Bot-turbo",
qianfan_ak: Optional[str] = None,
qianfan_sk: Optional[str] = None,

View file

@ -35,11 +35,13 @@ class CTransformersComponent(CustomComponent):
self,
model: str,
model_file: str,
inputs: str,
input_value: str,
model_type: str,
config: Optional[Dict] = None,
) -> Text:
output = CTransformers(model=model, model_file=model_file, model_type=model_type, config=config)
output = CTransformers(
model=model, model_file=model_file, model_type=model_type, config=config
)
message = output.invoke(inputs)
result = message.content if hasattr(message, "content") else message
self.status = result

View file

@ -34,7 +34,7 @@ class CohereComponent(CustomComponent):
def build(
self,
cohere_api_key: str,
inputs: str,
input_value: str,
max_tokens: int = 256,
temperature: float = 0.75,
) -> Text:

View file

@ -57,7 +57,7 @@ class GoogleGenerativeAIComponent(CustomComponent):
self,
google_api_key: str,
model: str,
inputs: str,
input_value: str,
max_output_tokens: Optional[int] = None,
temperature: float = 0.1,
top_k: Optional[int] = None,

View file

@ -4,7 +4,6 @@ from langchain_community.chat_models.huggingface import ChatHuggingFace
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langflow import CustomComponent
from langflow.field_typing import Text
@ -30,7 +29,7 @@ class HuggingFaceEndpointsComponent(CustomComponent):
def build(
self,
inputs: str,
input_value: str,
endpoint_url: str,
task: str = "text2text-generation",
huggingfacehub_api_token: Optional[str] = None,

View file

@ -62,7 +62,7 @@ class LlamaCppComponent(CustomComponent):
def build(
self,
model_path: str,
inputs: str,
input_value: str,
grammar: Optional[str] = None,
cache: Optional[bool] = None,
client: Optional[Any] = None,

View file

@ -171,7 +171,7 @@ class ChatOllamaComponent(CustomComponent):
self,
base_url: Optional[str],
model: str,
inputs: str,
input_value: str,
mirostat: Optional[str],
mirostat_eta: Optional[float] = None,
mirostat_tau: Optional[float] = None,

View file

@ -1,6 +1,7 @@
from typing import Optional
from langchain_openai import ChatOpenAI
from langflow import CustomComponent
from langflow.field_typing import NestedDict, Text
@ -60,7 +61,7 @@ class OpenAIModelComponent(CustomComponent):
def build(
self,
inputs: Text,
input_value: Text,
max_tokens: Optional[int] = 256,
model_kwargs: NestedDict = {},
model_name: str = "gpt-4-1106-preview",

View file

@ -62,7 +62,7 @@ class ChatVertexAIComponent(CustomComponent):
def build(
self,
inputs: str,
input_value: str,
credentials: Optional[str],
project: str,
examples: Optional[List[BaseMessage]] = [],

View file

@ -32,7 +32,7 @@ class RunnableExecComponent(CustomComponent):
def build(
self,
input_key: str,
inputs: str,
input_value: str,
runnable: Runnable,
output_key: str = "output",
) -> Text:

View file

@ -2,6 +2,7 @@ from typing import List, Optional
import chromadb # type: ignore
from langchain_community.vectorstores.chroma import Chroma
from langflow import CustomComponent
from langflow.field_typing import Embeddings, Text
from langflow.schema import Record, docs_to_records
@ -57,7 +58,7 @@ class ChromaSearchComponent(CustomComponent):
def build(
self,
inputs: Text,
input_value: Text,
search_type: str,
collection_name: str,
embedding: Embeddings,
@ -92,7 +93,8 @@ class ChromaSearchComponent(CustomComponent):
if chroma_server_host is not None:
chroma_settings = chromadb.config.Settings(
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
or None,
chroma_server_host=chroma_server_host,
chroma_server_port=chroma_server_port or None,
chroma_server_grpc_port=chroma_server_grpc_port or None,

View file

@ -1,3 +1,4 @@
import asyncio
from collections import defaultdict, deque
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
@ -40,8 +41,10 @@ class Graph:
self._runs = 0
self._updates = 0
self.flow_id = flow_id
self._inputs = []
self._outputs = []
self._is_input_vertices = []
self._is_output_vertices = []
self._has_session_id_vertices = []
self._sorted_vertices_layers = []
self.top_level_vertices = []
for vertex in self._vertices:
@ -54,38 +57,37 @@ class Graph:
self.inactive_vertices = set()
self._build_graph()
self.build_graph_maps()
self.define_inputs_and_outputs()
self.define_vertices_lists()
def define_inputs_and_outputs(self):
@property
def sorted_vertices_layers(self):
if not self._sorted_vertices_layers:
self.sort_vertices()
return self._sorted_vertices_layers
def define_vertices_lists(self):
"""
Defines the input and output vertices of the graph.
Defines the lists of vertices that are inputs, outputs, and have session_id.
"""
attributes = ["is_input", "is_output", "has_session_id"]
for vertex in self.vertices:
if vertex.is_input:
self._inputs.append(vertex.id)
if vertex.is_output:
self._outputs.append(vertex.id)
for attribute in attributes:
if getattr(vertex, attribute):
getattr(self, f"_{attribute}_vertices").append(vertex.id)
def run(self, inputs: Dict[str, str]) -> List["ResultData"]:
async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]:
"""Runs the graph with the given inputs."""
# inputs is {"message": "Hello, world!"}
# we need to go through self.inputs and update the self._raw_params
# of the vertices that are inputs
for vertex_id in self.inputs:
vertex = self.get_vertex(vertex_id)
if vertex is None:
raise ValueError(f"Vertex {vertex_id} not found")
vertex.update_raw_params(inputs)
try:
self.build()
await self.process()
self.increment_run_count()
except Exception as exc:
logger.exception(exc)
raise ValueError(f"Error running graph: {exc}") from exc
# Now we get the outputs from the self.outputs
outputs = []
for vertex_id in self.outputs:
vertex = self.get_vertex(vertex_id)
@ -94,6 +96,23 @@ class Graph:
outputs.append(vertex.result)
return outputs
async def run(self, inputs: Dict[str, Union[str, list[str]]]) -> List["ResultData"]:
    """Runs the graph with the given inputs.

    ``inputs`` looks like ``{"input_value": ...}``; when the value is a
    list, the graph is executed once per element and the outputs of every
    pass are concatenated in order.
    """
    raw_value = inputs.get("input_value")
    values = raw_value if isinstance(raw_value, list) else [raw_value]
    collected: List["ResultData"] = []
    for single_value in values:
        run_outputs = await self._run({"input_value": single_value})
        logger.debug(f"Run outputs: {run_outputs}")
        collected.extend(run_outputs)
    return collected
@property
def metadata(self):
return {
@ -404,6 +423,36 @@ class Graph:
raise ValueError("No root vertex found")
return await root_vertex.build()
async def process(self) -> "Graph":
    """Processes the graph with vertices in each layer run in parallel."""
    # Layers come from the cached topological sort (lazily computed by the
    # sorted_vertices_layers property); vertices in the same layer have no
    # mutual dependencies, so their build() calls can run concurrently.
    vertices_layers = self.sorted_vertices_layers
    for layer_index, layer in enumerate(vertices_layers):
        tasks = []
        for vertex_id in layer:
            vertex = self.get_vertex(vertex_id)
            # Name each task after its layer/vertex so failures are traceable.
            task = asyncio.create_task(
                vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}"
            )
            tasks.append(task)
        logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
        # Wait for the whole layer to finish before starting the next one.
        await self._execute_tasks(tasks)
    logger.debug("Graph processing complete")
    return self
async def _execute_tasks(self, tasks):
"""Executes tasks in parallel, handling exceptions for each task."""
results = []
for task in asyncio.as_completed(tasks):
try:
result = await task
results.append(result)
except Exception as e:
# Log the exception along with the task name for easier debugging
task_name = task.get_name()
logger.error(f"Task {task_name} failed with exception: {e}")
return results
def topological_sort(self) -> List[Vertex]:
"""
Performs a topological sort of the vertices in the graph.
@ -671,6 +720,7 @@ class Graph:
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
vertices_layers = self.sort_chat_inputs_first(vertices_layers)
self.increment_run_count()
self._sorted_vertices_layers = vertices_layers
return vertices_layers
def sort_interface_components_first(

View file

@ -25,7 +25,7 @@ if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
class VertexStates(Enum):
class VertexStates(str, Enum):
"""Vertex are related to it being active, inactive, or in an error state."""
ACTIVE = "active"
@ -53,6 +53,7 @@ class Vertex:
output_component_name in self.id
for output_component_name in OUTPUT_COMPONENTS
)
self.has_session_id = None
self._custom_component = None
self.has_external_input = False
self.has_external_output = False
@ -223,6 +224,8 @@ class Vertex:
if isinstance(value, dict)
}
self.has_session_id = "session_id" in template_dicts
self.required_inputs = [
template_dicts[key]["type"]
for key, value in template_dicts.items()

View file

@ -47,10 +47,10 @@ class CustomComponent(Component):
"""The icon of the component. It should be an emoji. Defaults to None."""
is_input: Optional[bool] = None
"""The input state of the component. Defaults to None.
If True, the component must have a field named 'message'."""
If True, the component must have a field named 'input_value'."""
is_output: Optional[bool] = None
"""The output state of the component. Defaults to None.
If True, the component must have a field named 'message'."""
If True, the component must have a field named 'input_value'."""
code: Optional[str] = None
"""The code of the component. Defaults to None."""
field_config: dict = {}

View file

@ -7,6 +7,9 @@ from langchain.schema import AgentAction, Document
from langchain_community.vectorstores import VectorStore
from langchain_core.messages import AIMessage
from langchain_core.runnables.base import Runnable
from loguru import logger
from pydantic import BaseModel
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.interface.custom.custom_component import CustomComponent
@ -17,8 +20,6 @@ from langflow.interface.run import (
)
from langflow.services.deps import get_session_service
from langflow.services.session.service import SessionService
from loguru import logger
from pydantic import BaseModel
def fix_memory_inputs(langchain_object):
@ -146,7 +147,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"):
result = await runnable.ainvoke(inputs)
else:
raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}")
raise ValueError(
f"Runnable {runnable} does not support inputs of type {type(inputs)}"
)
# Check if the result is a list of AIMessages
if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result):
result = [r.content for r in result]
@ -155,7 +158,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
return result
async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict):
async def process_inputs_dict(
built_object: Union[Chain, VectorStore, Runnable], inputs: dict
):
if isinstance(built_object, Chain):
if inputs is None:
raise ValueError("Inputs must be provided for a Chain")
@ -190,7 +195,9 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]):
return await process_runnable(built_object, inputs)
async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
async def generate_result(
built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]
):
if isinstance(inputs, dict):
result = await process_inputs_dict(built_object, inputs)
elif isinstance(inputs, List) and isinstance(built_object, Runnable):
@ -222,7 +229,9 @@ async def process_graph_cached(
if clear_cache:
session_service.clear_session(session_id)
if session_id is None:
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
session_id = session_service.generate_key(
session_id=session_id, data_graph=data_graph
)
# Load the graph using SessionService
session = await session_service.load_session(session_id, data_graph)
graph, artifacts = session if session else (None, None)
@ -258,14 +267,34 @@ async def build_graph_and_generate_result(
return Result(result=result, session_id=session_id)
def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
async def run_graph(
    graph: Union["Graph", dict],
    session_id: str,
    inputs: Optional[Union[dict, List[dict]]] = None,
    artifacts: Optional[Dict[str, Any]] = None,
    session_service: Optional[SessionService] = None,
):
    """Run the graph and generate the result"""
    # Accept either a built Graph or its dict payload.
    built = Graph.from_payload(graph) if isinstance(graph, dict) else graph
    run_outputs = await built.run(inputs)
    # Persist the built graph + artifacts back into the session, when possible.
    if session_id and session_service:
        session_service.update_session(session_id, (built, artifacts))
    return run_outputs
def validate_input(
    graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Check that graph_data and tweaks are dicts and return the node list.

    Nodes may live under graph_data["data"]["nodes"] or directly under
    graph_data["nodes"]; raises ValueError otherwise.
    """
    if not (isinstance(graph_data, dict) and isinstance(tweaks, dict)):
        raise ValueError("graph_data and tweaks should be dictionaries")
    nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes")
    if isinstance(nodes, list):
        return nodes
    raise ValueError(
        "graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key"
    )
@ -274,7 +303,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
template_data = node.get("data", {}).get("node", {}).get("template")
if not isinstance(template_data, dict):
logger.warning(f"Template data for node {node.get('id')} should be a dictionary")
logger.warning(
f"Template data for node {node.get('id')} should be a dictionary"
)
return
for tweak_name, tweak_value in node_tweaks.items():
@ -289,7 +320,9 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None:
vertex.params[tweak_name] = tweak_value
def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
def process_tweaks(
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
) -> Dict[str, Any]:
"""
This function is used to tweak the graph data using the node id and the tweaks dict.
@ -310,7 +343,9 @@ def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
if node_tweaks := tweaks.get(node_id):
apply_tweaks(node, node_tweaks)
else:
logger.warning("Each node should be a dictionary with an 'id' key of type str")
logger.warning(
"Each node should be a dictionary with an 'id' key of type str"
)
return graph_data
@ -322,6 +357,8 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]):
if node_tweaks := tweaks.get(node_id):
apply_tweaks_on_vertex(vertex, node_tweaks)
else:
logger.warning("Each node should be a Vertex with an 'id' attribute of type str")
logger.warning(
"Each node should be a Vertex with an 'id' attribute of type str"
)
return graph

View file

@ -49,10 +49,10 @@ class FrontendNode(BaseModel):
"""Icon of the frontend node."""
is_input: Optional[bool] = None
"""Whether the frontend node is used as an input when processing the Graph.
If True, there should be a field named 'message'."""
If True, there should be a field named 'input_value'."""
is_output: Optional[bool] = None
"""Whether the frontend node is used as an output when processing the Graph.
If True, there should be a field named 'message'."""
If True, there should be a field named 'input_value'."""
is_composition: Optional[bool] = None
"""Whether the frontend node is used for composition."""
base_classes: List[str]