diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 8a3f67ddf..2dc79e85a 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -3,11 +3,15 @@ from typing import Annotated, Any, List, Optional, Union import sqlalchemy as sa from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status +from loguru import logger +from sqlmodel import select + from langflow.api.utils import update_frontend_node_with_template_values from langflow.api.v1.schemas import ( CustomComponentCode, PreloadResponse, ProcessResponse, + RunResponse, TaskResponse, TaskStatusResponse, UploadFileResponse, @@ -15,15 +19,23 @@ from langflow.api.v1.schemas import ( from langflow.interface.custom.custom_component import CustomComponent from langflow.interface.custom.directory_reader import DirectoryReader from langflow.interface.custom.utils import build_custom_component_template -from langflow.processing.process import build_graph_and_generate_result, process_graph_cached, process_tweaks +from langflow.processing.process import ( + build_graph_and_generate_result, + process_graph_cached, + process_tweaks, + run_graph, +) from langflow.services.auth.utils import api_key_security, get_current_active_user from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow from langflow.services.database.models.user.model import User -from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service +from langflow.services.deps import ( + get_session, + get_session_service, + get_settings_service, + get_task_service, +) from langflow.services.session.service import SessionService -from loguru import logger -from sqlmodel import select try: from langflow.worker import process_graph_cached_task @@ -33,9 +45,10 @@ except ImportError: raise NotImplementedError("Celery is not installed") -from langflow.services.task.service import TaskService from sqlmodel import Session +from langflow.services.task.service import TaskService + # build router router = APIRouter(tags=["Base"]) @@ -80,9 +93,15 @@ async def process_graph_data( ) if session_id is None: # Generate a session ID - session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data) + session_id = get_session_service().generate_key( + session_id=session_id, data_graph=graph_data + ) task_id, task = await task_service.launch_task( - process_graph_cached_task if task_service.use_celery else process_graph_cached, + ( + process_graph_cached_task + if task_service.use_celery + else process_graph_cached + ), graph_data, inputs, clear_cache, @@ -176,7 +195,11 @@ async def preload_flow( else: if session_id is None: session_id = flow_id - flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first() + flow = session.exec( + select(Flow) + .where(Flow.id == flow_id) + .where(Flow.user_id == api_key_user.id) + ).first() if flow is None: raise ValueError(f"Flow {flow_id} not found") @@ -197,6 +220,76 @@ async def preload_flow( raise HTTPException(status_code=500, detail=str(exc)) from exc +@router.post("/run/{flow_id}", response_model=ProcessResponse) +async def run_flow_with_caching( + session: Annotated[Session, Depends(get_session)], + flow_id: str, + inputs: Optional[Union[List[dict], dict]] = None, + tweaks: Optional[dict] = None, + session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 + api_key_user: User = Depends(api_key_security), + session_service: SessionService = Depends(get_session_service), +): + try: + if session_id: + session_data = await session_service.load_session(session_id) + graph, artifacts = session_data if session_data else (None, None) + task_result: Any = None + task_status = None + if not graph: + raise ValueError("Graph not found in the session") + task_result = await run_graph( + graph, + session_id, + inputs, + artifacts=artifacts, + session_service=session_service, + ) + + else: + # Get the flow that matches the flow_id and belongs to the user + # flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() + flow = session.exec( + select(Flow) + .where(Flow.id == flow_id) + .where(Flow.user_id == api_key_user.id) + ).first() + if flow is None: + raise ValueError(f"Flow {flow_id} not found") + + if flow.data is None: + raise ValueError(f"Flow {flow_id} has no data") + graph_data = flow.data + graph_data = process_tweaks(graph_data, tweaks) + task_result = await run_graph( + graph_data, + inputs, + tweaks, + session_id, + session_service=session_service, + ) + + return RunResponse( + outputs=task_result, session_id=session_id, status=task_status + ) + except sa.exc.StatementError as exc: + # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') + if "badly formed hexadecimal UUID string" in str(exc): + # This means the Flow ID is not a valid UUID which means it can't find the flow + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + except ValueError as exc: + if f"Flow {flow_id} not found" in str(exc): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc + else: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc) + ) from exc + + @router.post( "/predict/{flow_id}", response_model=ProcessResponse, @@ -269,7 +362,11 @@ async def process( # Get the flow that matches the flow_id and belongs to the user # flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() - flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first() + flow = session.exec( + select(Flow) + .where(Flow.id == flow_id) + .where(Flow.user_id == api_key_user.id) + ).first() if flow is None: raise ValueError(f"Flow {flow_id} not found") @@ -289,12 +386,18 @@ async def process( # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') if "badly formed hexadecimal UUID string" in str(exc): # This means the Flow ID is not a valid UUID which means it can't find the flow - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc except ValueError as exc: if f"Flow {flow_id} not found" in str(exc): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) + ) from exc else: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc) + ) from exc except Exception as e: # Log stack trace logger.exception(e) @@ -364,12 +467,16 @@ async def custom_component( built_frontend_node = build_custom_component_template(component, user_id=user.id) - built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code.frontend_node) + built_frontend_node = update_frontend_node_with_template_values( + built_frontend_node, raw_code.frontend_node + ) return built_frontend_node @router.post("/custom_component/reload", status_code=HTTPStatus.OK) -async def reload_custom_component(path: str, user: User = Depends(get_current_active_user)): +async def reload_custom_component( + path: str, user: User = Depends(get_current_active_user) +): from langflow.interface.custom.utils import build_custom_component_template try: @@ -391,6 +498,8 @@ async def custom_component_update( ): component = CustomComponent(code=raw_code.code) - component_node = build_custom_component_template(component, user_id=user.id, update_field=raw_code.field) + component_node = build_custom_component_template( + component, user_id=user.id, update_field=raw_code.field + ) # Update the field return component_node diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index adb26202a..38db8f85a 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -66,6 +66,14 @@ class ProcessResponse(BaseModel): backend: Optional[str] = None +class RunResponse(BaseModel): + """Run response schema.""" + + outputs: Optional[List[Any]] = None + status: Optional[str] = None + session_id: Optional[str] = None + + class PreloadResponse(BaseModel): """Preload response schema.""" @@ -73,9 +81,6 @@ class PreloadResponse(BaseModel): is_clear: Optional[bool] = None -# TaskStatusResponse( -# status=task.status, result=task.result if task.ready() else None -# ) class TaskStatusResponse(BaseModel): """Task status response schema.""" diff --git a/src/backend/langflow/components/chains/ConversationChain.py b/src/backend/langflow/components/chains/ConversationChain.py index 3183954a3..7d9d28dcc 100644 --- a/src/backend/langflow/components/chains/ConversationChain.py +++ b/src/backend/langflow/components/chains/ConversationChain.py @@ -23,7 +23,7 @@ class ConversationChainComponent(CustomComponent): def build( self, - inputs: str, + input_value: str, llm: BaseLanguageModel, memory: Optional[BaseMemory] = None, ) -> Text: diff --git a/src/backend/langflow/components/chains/LLMCheckerChain.py b/src/backend/langflow/components/chains/LLMCheckerChain.py index bfee0b5a9..15a540311 100644 --- a/src/backend/langflow/components/chains/LLMCheckerChain.py +++ b/src/backend/langflow/components/chains/LLMCheckerChain.py @@ -18,7 +18,7 @@ class LLMCheckerChainComponent(CustomComponent): def build( self, - inputs: str, + input_value: str, llm: BaseLanguageModel, ) -> Text: diff --git a/src/backend/langflow/components/chains/LLMMathChain.py b/src/backend/langflow/components/chains/LLMMathChain.py index 919de34e6..7fb253b83 100644 --- a/src/backend/langflow/components/chains/LLMMathChain.py +++ b/src/backend/langflow/components/chains/LLMMathChain.py @@ -24,7 +24,7 @@ class LLMMathChainComponent(CustomComponent): def build( self, - inputs: Text, + input_value: Text, llm: BaseLanguageModel, llm_chain: LLMChain, input_key: str = "question", diff --git a/src/backend/langflow/components/chains/RetrievalQA.py b/src/backend/langflow/components/chains/RetrievalQA.py index 2fe31353e..4968afe87 100644 --- a/src/backend/langflow/components/chains/RetrievalQA.py +++ b/src/backend/langflow/components/chains/RetrievalQA.py @@ -27,7 +27,7 @@ class RetrievalQAComponent(CustomComponent): self, combine_documents_chain: BaseCombineDocumentsChain, retriever: BaseRetriever, - inputs: str = "", + input_value: str = "", memory: Optional[BaseMemory] = None, input_key: str = "query", output_key: str = "result", diff --git a/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py b/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py index faf3ab7dd..8be64c631 100644 --- a/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py +++ b/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py @@ -26,7 +26,7 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent): def build( self, - inputs: str, + input_value: str, retriever: BaseRetriever, llm: BaseLanguageModel, chain_type: str, diff --git a/src/backend/langflow/components/chains/SQLGenerator.py b/src/backend/langflow/components/chains/SQLGenerator.py index ea22a6de0..39b8fe394 100644 --- a/src/backend/langflow/components/chains/SQLGenerator.py +++ b/src/backend/langflow/components/chains/SQLGenerator.py @@ -28,7 +28,7 @@ class SQLGeneratorComponent(CustomComponent): def build( self, - inputs: Text, + input_value: Text, db: SQLDatabase, llm: BaseLanguageModel, top_k: int = 5, diff --git a/src/backend/langflow/components/io/ChatInput.py b/src/backend/langflow/components/io/ChatInput.py index 6d96a6b96..0666f92d1 100644 --- a/src/backend/langflow/components/io/ChatInput.py +++ b/src/backend/langflow/components/io/ChatInput.py @@ -11,7 +11,7 @@ class ChatInput(CustomComponent): def build_config(self): return { - "message": { + "input_value": { "input_types": ["Text"], "display_name": "Message", "multiline": True, @@ -35,26 +35,26 @@ class ChatInput(CustomComponent): self, sender: Optional[str] = "User", sender_name: Optional[str] = "User", - message: Optional[str] = None, + input_value: Optional[str] = None, session_id: Optional[str] = None, return_record: Optional[bool] = False, ) -> Union[Text, Record]: if return_record: - if isinstance(message, Record): + if isinstance(input_value, Record): # Update the data of the record - message.data["sender"] = sender - message.data["sender_name"] = sender_name - message.data["session_id"] = session_id + input_value.data["sender"] = sender + input_value.data["sender_name"] = sender_name + input_value.data["session_id"] = session_id else: - message = Record( - text=message, + input_value = Record( + text=input_value, data={ "sender": sender, "sender_name": sender_name, "session_id": session_id, }, ) - if not message: - message = "" - self.status = message - return message + if not input_value: + input_value = "" + self.status = input_value + return input_value diff --git a/src/backend/langflow/components/io/ChatOutput.py b/src/backend/langflow/components/io/ChatOutput.py index 05639cdb2..72667374f 100644 --- a/src/backend/langflow/components/io/ChatOutput.py +++ b/src/backend/langflow/components/io/ChatOutput.py @@ -17,7 +17,7 @@ class ChatOutput(CustomComponent): def build_config(self): return { - "message": {"input_types": ["Text"], "display_name": "Message"}, + "input_value": {"input_types": ["Text"], "display_name": "Message"}, "sender": { "options": ["Machine", "User"], "display_name": "Sender Type", @@ -39,25 +39,25 @@ class ChatOutput(CustomComponent): sender: Optional[str] = "Machine", sender_name: Optional[str] = "AI", session_id: Optional[str] = None, - message: Optional[str] = None, + input_value: Optional[str] = None, return_record: Optional[bool] = False, ) -> Union[Text, Record]: if return_record: - if isinstance(message, Record): + if isinstance(input_value, Record): # Update the data of the record - message.data["sender"] = sender - message.data["sender_name"] = sender_name - message.data["session_id"] = session_id + input_value.data["sender"] = sender + input_value.data["sender_name"] = sender_name + input_value.data["session_id"] = session_id else: - message = Record( - text=message, + input_value = Record( + text=input_value, data={ "sender": sender, "sender_name": sender_name, "session_id": session_id, }, ) - if not message: - message = "" - self.status = message - return message + if not input_value: + input_value = "" + self.status = input_value + return input_value diff --git a/src/backend/langflow/components/models/AmazonBedrockModel.py b/src/backend/langflow/components/models/AmazonBedrockModel.py index a2e008e2e..68e404773 100644 --- a/src/backend/langflow/components/models/AmazonBedrockModel.py +++ b/src/backend/langflow/components/models/AmazonBedrockModel.py @@ -39,7 +39,7 @@ class AmazonBedrockComponent(CustomComponent): def build( self, - inputs: str, + input_value: str, model_id: str = "anthropic.claude-instant-v1", credentials_profile_name: Optional[str] = None, region_name: Optional[str] = None, diff --git a/src/backend/langflow/components/models/AnthropicModel.py b/src/backend/langflow/components/models/AnthropicModel.py index 793bec46a..be6e46d9a 100644 --- a/src/backend/langflow/components/models/AnthropicModel.py +++ b/src/backend/langflow/components/models/AnthropicModel.py @@ -9,7 +9,9 @@ from langflow.field_typing import Text class AnthropicLLM(CustomComponent): display_name: str = "AnthropicModel" - description: str = "Generate text using Anthropic Chat&Completion large language models." + description: str = ( + "Generate text using Anthropic Chat&Completion large language models." + ) def build_config(self): return { @@ -53,7 +55,7 @@ class AnthropicLLM(CustomComponent): def build( self, model: str, - inputs: str, + input_value: str, anthropic_api_key: Optional[str] = None, max_tokens: Optional[int] = None, temperature: Optional[float] = None, @@ -66,7 +68,9 @@ class AnthropicLLM(CustomComponent): try: output = ChatAnthropic( model_name=model, - anthropic_api_key=(SecretStr(anthropic_api_key) if anthropic_api_key else None), + anthropic_api_key=( + SecretStr(anthropic_api_key) if anthropic_api_key else None + ), max_tokens_to_sample=max_tokens, # type: ignore temperature=temperature, anthropic_api_url=api_endpoint, diff --git a/src/backend/langflow/components/models/AzureOpenAIModel.py b/src/backend/langflow/components/models/AzureOpenAIModel.py index 1e646e43a..be1f724bf 100644 --- a/src/backend/langflow/components/models/AzureOpenAIModel.py +++ b/src/backend/langflow/components/models/AzureOpenAIModel.py @@ -9,7 +9,9 @@ from langflow import CustomComponent class AzureChatOpenAIComponent(CustomComponent): display_name: str = "AzureOpenAI Model" description: str = "Generate text using LLM model from Azure OpenAI." - documentation: str = "https://python.langchain.com/docs/integrations/llms/azure_openai" + documentation: str = ( + "https://python.langchain.com/docs/integrations/llms/azure_openai" + ) beta = False AZURE_OPENAI_MODELS = [ @@ -78,7 +80,7 @@ class AzureChatOpenAIComponent(CustomComponent): self, model: str, azure_endpoint: str, - inputs: str, + input_value: str, azure_deployment: str, api_key: str, api_version: str, diff --git a/src/backend/langflow/components/models/BaiduQianfanChatModel.py b/src/backend/langflow/components/models/BaiduQianfanChatModel.py index 88051d0e9..9eadb7013 100644 --- a/src/backend/langflow/components/models/BaiduQianfanChatModel.py +++ b/src/backend/langflow/components/models/BaiduQianfanChatModel.py @@ -73,7 +73,7 @@ class QianfanChatEndpointComponent(CustomComponent): def build( self, - inputs: str, + input_value: str, model: str = "ERNIE-Bot-turbo", qianfan_ak: Optional[str] = None, qianfan_sk: Optional[str] = None, diff --git a/src/backend/langflow/components/models/CTransformersModel.py b/src/backend/langflow/components/models/CTransformersModel.py index 932b1b351..60cc2eb12 100644 --- a/src/backend/langflow/components/models/CTransformersModel.py +++ b/src/backend/langflow/components/models/CTransformersModel.py @@ -35,11 +35,13 @@ class CTransformersComponent(CustomComponent): self, model: str, model_file: str, - inputs: str, + input_value: str, model_type: str, config: Optional[Dict] = None, ) -> Text: - output = CTransformers(model=model, model_file=model_file, model_type=model_type, config=config) + output = CTransformers( + model=model, model_file=model_file, model_type=model_type, config=config + ) message = output.invoke(inputs) result = message.content if hasattr(message, "content") else message self.status = result diff --git a/src/backend/langflow/components/models/CohereModel.py b/src/backend/langflow/components/models/CohereModel.py index 3912cb855..28b198ec1 100644 --- a/src/backend/langflow/components/models/CohereModel.py +++ b/src/backend/langflow/components/models/CohereModel.py @@ -34,7 +34,7 @@ class CohereComponent(CustomComponent): def build( self, cohere_api_key: str, - inputs: str, + input_value: str, max_tokens: int = 256, temperature: float = 0.75, ) -> Text: diff --git a/src/backend/langflow/components/models/GoogleGenerativeAIModel.py b/src/backend/langflow/components/models/GoogleGenerativeAIModel.py index ce967bd57..2ff01c4c7 100644 --- a/src/backend/langflow/components/models/GoogleGenerativeAIModel.py +++ b/src/backend/langflow/components/models/GoogleGenerativeAIModel.py @@ -57,7 +57,7 @@ class GoogleGenerativeAIComponent(CustomComponent): self, google_api_key: str, model: str, - inputs: str, + input_value: str, max_output_tokens: Optional[int] = None, temperature: float = 0.1, top_k: Optional[int] = None, diff --git a/src/backend/langflow/components/models/HuggingFaceModel.py b/src/backend/langflow/components/models/HuggingFaceModel.py index 4357ede61..394938344 100644 --- a/src/backend/langflow/components/models/HuggingFaceModel.py +++ b/src/backend/langflow/components/models/HuggingFaceModel.py @@ -4,7 +4,6 @@ from langchain_community.chat_models.huggingface import ChatHuggingFace from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint from langflow import CustomComponent - from langflow.field_typing import Text @@ -30,7 +29,7 @@ class HuggingFaceEndpointsComponent(CustomComponent): def build( self, - inputs: str, + input_value: str, endpoint_url: str, task: str = "text2text-generation", huggingfacehub_api_token: Optional[str] = None, diff --git a/src/backend/langflow/components/models/LlamaCppModel.py b/src/backend/langflow/components/models/LlamaCppModel.py index af0de5159..53a6f8ace 100644 --- a/src/backend/langflow/components/models/LlamaCppModel.py +++ b/src/backend/langflow/components/models/LlamaCppModel.py @@ -62,7 +62,7 @@ class LlamaCppComponent(CustomComponent): def build( self, model_path: str, - inputs: str, + input_value: str, grammar: Optional[str] = None, cache: Optional[bool] = None, client: Optional[Any] = None, diff --git a/src/backend/langflow/components/models/OllamaModel.py b/src/backend/langflow/components/models/OllamaModel.py index 129f96482..3dc8dacab 100644 --- a/src/backend/langflow/components/models/OllamaModel.py +++ b/src/backend/langflow/components/models/OllamaModel.py @@ -171,7 +171,7 @@ class ChatOllamaComponent(CustomComponent): self, base_url: Optional[str], model: str, - inputs: str, + input_value: str, mirostat: Optional[str], mirostat_eta: Optional[float] = None, mirostat_tau: Optional[float] = None, diff --git a/src/backend/langflow/components/models/OpenAIModel.py b/src/backend/langflow/components/models/OpenAIModel.py index 1cc352b20..07ba7013c 100644 --- a/src/backend/langflow/components/models/OpenAIModel.py +++ b/src/backend/langflow/components/models/OpenAIModel.py @@ -1,6 +1,7 @@ from typing import Optional from langchain_openai import ChatOpenAI + from langflow import CustomComponent from langflow.field_typing import NestedDict, Text @@ -60,7 +61,7 @@ class OpenAIModelComponent(CustomComponent): def build( self, - inputs: Text, + input_value: Text, max_tokens: Optional[int] = 256, model_kwargs: NestedDict = {}, model_name: str = "gpt-4-1106-preview", diff --git a/src/backend/langflow/components/models/VertexAiModel.py b/src/backend/langflow/components/models/VertexAiModel.py index eee804e02..81338f723 100644 --- a/src/backend/langflow/components/models/VertexAiModel.py +++ b/src/backend/langflow/components/models/VertexAiModel.py @@ -62,7 +62,7 @@ class ChatVertexAIComponent(CustomComponent): def build( self, - inputs: str, + input_value: str, credentials: Optional[str], project: str, examples: Optional[List[BaseMessage]] = [], diff --git a/src/backend/langflow/components/utilities/RunnableExecutor.py b/src/backend/langflow/components/utilities/RunnableExecutor.py index f83f352b4..5533e6d1d 100644 --- a/src/backend/langflow/components/utilities/RunnableExecutor.py +++ b/src/backend/langflow/components/utilities/RunnableExecutor.py @@ -32,7 +32,7 @@ class RunnableExecComponent(CustomComponent): def build( self, input_key: str, - inputs: str, + input_value: str, runnable: Runnable, output_key: str = "output", ) -> Text: diff --git a/src/backend/langflow/components/vectorstores/ChromaSearch.py b/src/backend/langflow/components/vectorstores/ChromaSearch.py index c6eb1ebac..5dd33abf2 100644 --- a/src/backend/langflow/components/vectorstores/ChromaSearch.py +++ b/src/backend/langflow/components/vectorstores/ChromaSearch.py @@ -2,6 +2,7 @@ from typing import List, Optional import chromadb # type: ignore from langchain_community.vectorstores.chroma import Chroma + from langflow import CustomComponent from langflow.field_typing import Embeddings, Text from langflow.schema import Record, docs_to_records @@ -57,7 +58,7 @@ class ChromaSearchComponent(CustomComponent): def build( self, - inputs: Text, + input_value: Text, search_type: str, collection_name: str, embedding: Embeddings, @@ -92,7 +93,8 @@ class ChromaSearchComponent(CustomComponent): if chroma_server_host is not None: chroma_settings = chromadb.config.Settings( - chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None, + chroma_server_cors_allow_origins=chroma_server_cors_allow_origins + or None, chroma_server_host=chroma_server_host, chroma_server_port=chroma_server_port or None, chroma_server_grpc_port=chroma_server_grpc_port or None, diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 68e16bed8..341a4729c 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -1,3 +1,4 @@ +import asyncio from collections import defaultdict, deque from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union @@ -40,8 +41,10 @@ class Graph: self._runs = 0 self._updates = 0 self.flow_id = flow_id - self._inputs = [] - self._outputs = [] + self._is_input_vertices = [] + self._is_output_vertices = [] + self._has_session_id_vertices = [] + self._sorted_vertices_layers = [] self.top_level_vertices = [] for vertex in self._vertices: @@ -54,38 +57,37 @@ class Graph: self.inactive_vertices = set() self._build_graph() self.build_graph_maps() - self.define_inputs_and_outputs() + self.define_vertices_lists() - def define_inputs_and_outputs(self): + @property + def sorted_vertices_layers(self): + if not self._sorted_vertices_layers: + self.sort_vertices() + return self._sorted_vertices_layers + + def define_vertices_lists(self): """ - Defines the input and output vertices of the graph. + Defines the lists of vertices that are inputs, outputs, and have session_id. """ + attributes = ["is_input", "is_output", "has_session_id"] for vertex in self.vertices: - if vertex.is_input: - self._inputs.append(vertex.id) - if vertex.is_output: - self._outputs.append(vertex.id) + for attribute in attributes: + if getattr(vertex, attribute): + getattr(self, f"_{attribute}_vertices").append(vertex.id) - def run(self, inputs: Dict[str, str]) -> List["ResultData"]: + async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]: """Runs the graph with the given inputs.""" - - # inputs is {"message": "Hello, world!"} - # we need to go through self.inputs and update the self._raw_params - # of the vertices that are inputs - for vertex_id in self.inputs: vertex = self.get_vertex(vertex_id) if vertex is None: raise ValueError(f"Vertex {vertex_id} not found") vertex.update_raw_params(inputs) try: - self.build() + await self.process() self.increment_run_count() except Exception as exc: logger.exception(exc) raise ValueError(f"Error running graph: {exc}") from exc - - # Now we get the outputs from the self.outputs outputs = [] for vertex_id in self.outputs: vertex = self.get_vertex(vertex_id) @@ -94,6 +96,23 @@ class Graph: outputs.append(vertex.result) return outputs + async def run(self, inputs: Dict[str, Union[str, list[str]]]) -> List["ResultData"]: + """Runs the graph with the given inputs.""" + + # inputs is {"message": "Hello, world!"} + # we need to go through self.inputs and update the self._raw_params + # of the vertices that are inputs + # if the value is a list, we need to run multiple times + outputs = [] + inputs_values = inputs.get("input_value") + if not isinstance(inputs_values, list): + inputs_values = [inputs_values] + for input_value in inputs_values: + run_outputs = await self._run({"input_value": input_value}) + logger.debug(f"Run outputs: {run_outputs}") + outputs.extend(run_outputs) + return outputs + @property def metadata(self): return { @@ -404,6 +423,36 @@ class Graph: raise ValueError("No root vertex found") return await root_vertex.build() + async def process(self) -> "Graph": + """Processes the graph with vertices in each layer run in parallel.""" + vertices_layers = self.sorted_vertices_layers + + for layer_index, layer in enumerate(vertices_layers): + tasks = [] + for vertex_id in layer: + vertex = self.get_vertex(vertex_id) + task = asyncio.create_task( + vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}" + ) + tasks.append(task) + logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks") + await self._execute_tasks(tasks) + logger.debug("Graph processing complete") + return self + + async def _execute_tasks(self, tasks): + """Executes tasks in parallel, handling exceptions for each task.""" + results = [] + for task in asyncio.as_completed(tasks): + try: + result = await task + results.append(result) + except Exception as e: + # Log the exception along with the task name for easier debugging + task_name = task.get_name() + logger.error(f"Task {task_name} failed with exception: {e}") + return results + def topological_sort(self) -> List[Vertex]: """ Performs a topological sort of the vertices in the graph. @@ -671,6 +720,7 @@ class Graph: vertices_layers = self.sort_by_avg_build_time(vertices_layers) vertices_layers = self.sort_chat_inputs_first(vertices_layers) self.increment_run_count() + self._sorted_vertices_layers = vertices_layers return vertices_layers def sort_interface_components_first( diff --git a/src/backend/langflow/graph/vertex/base.py b/src/backend/langflow/graph/vertex/base.py index e87389c7b..3e1133491 100644 --- a/src/backend/langflow/graph/vertex/base.py +++ b/src/backend/langflow/graph/vertex/base.py @@ -25,7 +25,7 @@ if TYPE_CHECKING: from langflow.graph.graph.base import Graph -class VertexStates(Enum): +class VertexStates(str, Enum): """Vertex are related to it being active, inactive, or in an error state.""" ACTIVE = "active" @@ -53,6 +53,7 @@ class Vertex: output_component_name in self.id for output_component_name in OUTPUT_COMPONENTS ) + self.has_session_id = None self._custom_component = None self.has_external_input = False self.has_external_output = False @@ -223,6 +224,8 @@ class Vertex: if isinstance(value, dict) } + self.has_session_id = "session_id" in template_dicts + self.required_inputs = [ template_dicts[key]["type"] for key, value in template_dicts.items() diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index 549c0dad3..a8c81f041 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -47,10 +47,10 @@ class CustomComponent(Component): """The icon of the component. It should be an emoji. Defaults to None.""" is_input: Optional[bool] = None """The input state of the component. Defaults to None. - If True, the component must have a field named 'message'.""" + If True, the component must have a field named 'input_value'.""" is_output: Optional[bool] = None """The output state of the component. Defaults to None. - If True, the component must have a field named 'message'.""" + If True, the component must have a field named 'input_value'.""" code: Optional[str] = None """The code of the component. Defaults to None.""" field_config: dict = {} diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index ad4f8fb78..69e47b242 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -7,6 +7,9 @@ from langchain.schema import AgentAction, Document from langchain_community.vectorstores import VectorStore from langchain_core.messages import AIMessage from langchain_core.runnables.base import Runnable +from loguru import logger +from pydantic import BaseModel + from langflow.graph.graph.base import Graph from langflow.graph.vertex.base import Vertex from langflow.interface.custom.custom_component import CustomComponent @@ -17,8 +20,6 @@ from langflow.interface.run import ( ) from langflow.services.deps import get_session_service from langflow.services.session.service import SessionService -from loguru import logger -from pydantic import BaseModel def fix_memory_inputs(langchain_object): @@ -146,7 +147,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]): elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"): result = await runnable.ainvoke(inputs) else: - raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}") + raise ValueError( + f"Runnable {runnable} does not support inputs of type {type(inputs)}" + ) # Check if the result is a list of AIMessages if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result): result = [r.content for r in result] @@ -155,7 +158,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]): return result -async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict): +async def process_inputs_dict( + built_object: Union[Chain, VectorStore, Runnable], inputs: dict +): if isinstance(built_object, Chain): if inputs is None: raise ValueError("Inputs must be provided for a Chain") @@ -190,7 +195,9 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]): return await process_runnable(built_object, inputs) -async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]): +async def generate_result( + built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]] +): if isinstance(inputs, dict): result = await process_inputs_dict(built_object, inputs) elif isinstance(inputs, List) and isinstance(built_object, Runnable): @@ -222,7 +229,9 @@ async def process_graph_cached( if clear_cache: session_service.clear_session(session_id) if session_id is None: - session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph) + session_id = session_service.generate_key( + session_id=session_id, data_graph=data_graph + ) # Load the graph using SessionService session = await session_service.load_session(session_id, data_graph) graph, artifacts = session if session else (None, None) @@ -258,14 +267,34 @@ async def build_graph_and_generate_result( return Result(result=result, session_id=session_id) -def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]: +async def run_graph( + graph: Union["Graph", dict], + session_id: str, + inputs: Optional[Union[dict, List[dict]]] = None, + artifacts: Optional[Dict[str, Any]] = None, + session_service: Optional[SessionService] = None, +): + """Run the graph and generate the result""" + if isinstance(graph, dict): + graph = Graph.from_payload(graph) + outputs = await graph.run(inputs) + if session_id and session_service: + session_service.update_session(session_id, (graph, artifacts)) + return outputs + + +def validate_input( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> List[Dict[str, Any]]: if not isinstance(graph_data, dict) or not isinstance(tweaks, dict): raise ValueError("graph_data and tweaks should be dictionaries") nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes") if not isinstance(nodes, list): - raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key") + raise ValueError( + "graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key" + ) return nodes @@ -274,7 +303,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None: template_data = node.get("data", {}).get("node", {}).get("template") if not isinstance(template_data, dict): - logger.warning(f"Template data for node {node.get('id')} should be a dictionary") + logger.warning( + f"Template data for node {node.get('id')} should be a dictionary" + ) return for tweak_name, tweak_value in node_tweaks.items(): @@ -289,7 +320,9 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None: vertex.params[tweak_name] = tweak_value -def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]: +def process_tweaks( + graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] +) -> Dict[str, Any]: """ This function is used to tweak the graph data using the node id and the tweaks dict. @@ -310,7 +343,9 @@ def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]] if node_tweaks := tweaks.get(node_id): apply_tweaks(node, node_tweaks) else: - logger.warning("Each node should be a dictionary with an 'id' key of type str") + logger.warning( + "Each node should be a dictionary with an 'id' key of type str" + ) return graph_data @@ -322,6 +357,8 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]): if node_tweaks := tweaks.get(node_id): apply_tweaks_on_vertex(vertex, node_tweaks) else: - logger.warning("Each node should be a Vertex with an 'id' attribute of type str") + logger.warning( + "Each node should be a Vertex with an 'id' attribute of type str" + ) return graph diff --git a/src/backend/langflow/template/frontend_node/base.py b/src/backend/langflow/template/frontend_node/base.py index 9f62c7054..cc8356104 100644 --- a/src/backend/langflow/template/frontend_node/base.py +++ b/src/backend/langflow/template/frontend_node/base.py @@ -49,10 +49,10 @@ class FrontendNode(BaseModel): """Icon of the frontend node.""" is_input: Optional[bool] = None """Whether the frontend node is used as an input when processing the Graph. - If True, there should be a field named 'message'.""" + If True, there should be a field named 'input_value'.""" is_output: Optional[bool] = None """Whether the frontend node is used as an output when processing the Graph. - If True, there should be a field named 'message'.""" + If True, there should be a field named 'input_value'.""" is_composition: Optional[bool] = None """Whether the frontend node is used for composition.""" base_classes: List[str]