Merge branch 'zustand/io/migration' of github.com:logspace-ai/langflow into zustand/io/migration
This commit is contained in:
commit
952561d628
29 changed files with 317 additions and 103 deletions
|
|
@ -3,11 +3,15 @@ from typing import Annotated, Any, List, Optional, Union
|
|||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
from langflow.api.utils import update_frontend_node_with_template_values
|
||||
from langflow.api.v1.schemas import (
|
||||
CustomComponentCode,
|
||||
PreloadResponse,
|
||||
ProcessResponse,
|
||||
RunResponse,
|
||||
TaskResponse,
|
||||
TaskStatusResponse,
|
||||
UploadFileResponse,
|
||||
|
|
@ -15,15 +19,23 @@ from langflow.api.v1.schemas import (
|
|||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.interface.custom.utils import build_custom_component_template
|
||||
from langflow.processing.process import build_graph_and_generate_result, process_graph_cached, process_tweaks
|
||||
from langflow.processing.process import (
|
||||
build_graph_and_generate_result,
|
||||
process_graph_cached,
|
||||
process_tweaks,
|
||||
run_graph,
|
||||
)
|
||||
from langflow.services.auth.utils import api_key_security, get_current_active_user
|
||||
from langflow.services.cache.utils import save_uploaded_file
|
||||
from langflow.services.database.models.flow import Flow
|
||||
from langflow.services.database.models.user.model import User
|
||||
from langflow.services.deps import get_session, get_session_service, get_settings_service, get_task_service
|
||||
from langflow.services.deps import (
|
||||
get_session,
|
||||
get_session_service,
|
||||
get_settings_service,
|
||||
get_task_service,
|
||||
)
|
||||
from langflow.services.session.service import SessionService
|
||||
from loguru import logger
|
||||
from sqlmodel import select
|
||||
|
||||
try:
|
||||
from langflow.worker import process_graph_cached_task
|
||||
|
|
@ -33,9 +45,10 @@ except ImportError:
|
|||
raise NotImplementedError("Celery is not installed")
|
||||
|
||||
|
||||
from langflow.services.task.service import TaskService
|
||||
from sqlmodel import Session
|
||||
|
||||
from langflow.services.task.service import TaskService
|
||||
|
||||
# build router
|
||||
router = APIRouter(tags=["Base"])
|
||||
|
||||
|
|
@ -80,9 +93,15 @@ async def process_graph_data(
|
|||
)
|
||||
if session_id is None:
|
||||
# Generate a session ID
|
||||
session_id = get_session_service().generate_key(session_id=session_id, data_graph=graph_data)
|
||||
session_id = get_session_service().generate_key(
|
||||
session_id=session_id, data_graph=graph_data
|
||||
)
|
||||
task_id, task = await task_service.launch_task(
|
||||
process_graph_cached_task if task_service.use_celery else process_graph_cached,
|
||||
(
|
||||
process_graph_cached_task
|
||||
if task_service.use_celery
|
||||
else process_graph_cached
|
||||
),
|
||||
graph_data,
|
||||
inputs,
|
||||
clear_cache,
|
||||
|
|
@ -176,7 +195,11 @@ async def preload_flow(
|
|||
else:
|
||||
if session_id is None:
|
||||
session_id = flow_id
|
||||
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
|
||||
flow = session.exec(
|
||||
select(Flow)
|
||||
.where(Flow.id == flow_id)
|
||||
.where(Flow.user_id == api_key_user.id)
|
||||
).first()
|
||||
if flow is None:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
|
||||
|
|
@ -197,6 +220,76 @@ async def preload_flow(
|
|||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
@router.post("/run/{flow_id}", response_model=ProcessResponse)
|
||||
async def run_flow_with_caching(
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
flow_id: str,
|
||||
inputs: Optional[Union[List[dict], dict]] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
api_key_user: User = Depends(api_key_security),
|
||||
session_service: SessionService = Depends(get_session_service),
|
||||
):
|
||||
try:
|
||||
if session_id:
|
||||
session_data = await session_service.load_session(session_id)
|
||||
graph, artifacts = session_data if session_data else (None, None)
|
||||
task_result: Any = None
|
||||
task_status = None
|
||||
if not graph:
|
||||
raise ValueError("Graph not found in the session")
|
||||
task_result = await run_graph(
|
||||
graph,
|
||||
session_id,
|
||||
inputs,
|
||||
artifacts=artifacts,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
else:
|
||||
# Get the flow that matches the flow_id and belongs to the user
|
||||
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
|
||||
flow = session.exec(
|
||||
select(Flow)
|
||||
.where(Flow.id == flow_id)
|
||||
.where(Flow.user_id == api_key_user.id)
|
||||
).first()
|
||||
if flow is None:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
|
||||
if flow.data is None:
|
||||
raise ValueError(f"Flow {flow_id} has no data")
|
||||
graph_data = flow.data
|
||||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
task_result = await run_graph(
|
||||
graph_data,
|
||||
inputs,
|
||||
tweaks,
|
||||
session_id,
|
||||
session_service=session_service,
|
||||
)
|
||||
|
||||
return RunResponse(
|
||||
outputs=task_result, session_id=session_id, status=task_status
|
||||
)
|
||||
except sa.exc.StatementError as exc:
|
||||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
# This means the Flow ID is not a valid UUID which means it can't find the flow
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
if f"Flow {flow_id} not found" in str(exc):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/predict/{flow_id}",
|
||||
response_model=ProcessResponse,
|
||||
|
|
@ -269,7 +362,11 @@ async def process(
|
|||
|
||||
# Get the flow that matches the flow_id and belongs to the user
|
||||
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
|
||||
flow = session.exec(select(Flow).where(Flow.id == flow_id).where(Flow.user_id == api_key_user.id)).first()
|
||||
flow = session.exec(
|
||||
select(Flow)
|
||||
.where(Flow.id == flow_id)
|
||||
.where(Flow.user_id == api_key_user.id)
|
||||
).first()
|
||||
if flow is None:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
|
||||
|
|
@ -289,12 +386,18 @@ async def process(
|
|||
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
|
||||
if "badly formed hexadecimal UUID string" in str(exc):
|
||||
# This means the Flow ID is not a valid UUID which means it can't find the flow
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
except ValueError as exc:
|
||||
if f"Flow {flow_id} not found" in str(exc):
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
|
||||
) from exc
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)) from exc
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
|
||||
) from exc
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
@ -364,12 +467,16 @@ async def custom_component(
|
|||
|
||||
built_frontend_node = build_custom_component_template(component, user_id=user.id)
|
||||
|
||||
built_frontend_node = update_frontend_node_with_template_values(built_frontend_node, raw_code.frontend_node)
|
||||
built_frontend_node = update_frontend_node_with_template_values(
|
||||
built_frontend_node, raw_code.frontend_node
|
||||
)
|
||||
return built_frontend_node
|
||||
|
||||
|
||||
@router.post("/custom_component/reload", status_code=HTTPStatus.OK)
|
||||
async def reload_custom_component(path: str, user: User = Depends(get_current_active_user)):
|
||||
async def reload_custom_component(
|
||||
path: str, user: User = Depends(get_current_active_user)
|
||||
):
|
||||
from langflow.interface.custom.utils import build_custom_component_template
|
||||
|
||||
try:
|
||||
|
|
@ -391,6 +498,8 @@ async def custom_component_update(
|
|||
):
|
||||
component = CustomComponent(code=raw_code.code)
|
||||
|
||||
component_node = build_custom_component_template(component, user_id=user.id, update_field=raw_code.field)
|
||||
component_node = build_custom_component_template(
|
||||
component, user_id=user.id, update_field=raw_code.field
|
||||
)
|
||||
# Update the field
|
||||
return component_node
|
||||
|
|
|
|||
|
|
@ -66,6 +66,14 @@ class ProcessResponse(BaseModel):
|
|||
backend: Optional[str] = None
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
"""Run response schema."""
|
||||
|
||||
outputs: Optional[List[Any]] = None
|
||||
status: Optional[str] = None
|
||||
session_id: Optional[str] = None
|
||||
|
||||
|
||||
class PreloadResponse(BaseModel):
|
||||
"""Preload response schema."""
|
||||
|
||||
|
|
@ -73,9 +81,6 @@ class PreloadResponse(BaseModel):
|
|||
is_clear: Optional[bool] = None
|
||||
|
||||
|
||||
# TaskStatusResponse(
|
||||
# status=task.status, result=task.result if task.ready() else None
|
||||
# )
|
||||
class TaskStatusResponse(BaseModel):
|
||||
"""Task status response schema."""
|
||||
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ class ConversationChainComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
llm: BaseLanguageModel,
|
||||
memory: Optional[BaseMemory] = None,
|
||||
) -> Text:
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class LLMCheckerChainComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
llm: BaseLanguageModel,
|
||||
) -> Text:
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ class LLMMathChainComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: Text,
|
||||
input_value: Text,
|
||||
llm: BaseLanguageModel,
|
||||
llm_chain: LLMChain,
|
||||
input_key: str = "question",
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class RetrievalQAComponent(CustomComponent):
|
|||
self,
|
||||
combine_documents_chain: BaseCombineDocumentsChain,
|
||||
retriever: BaseRetriever,
|
||||
inputs: str = "",
|
||||
input_value: str = "",
|
||||
memory: Optional[BaseMemory] = None,
|
||||
input_key: str = "query",
|
||||
output_key: str = "result",
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
retriever: BaseRetriever,
|
||||
llm: BaseLanguageModel,
|
||||
chain_type: str,
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class SQLGeneratorComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: Text,
|
||||
input_value: Text,
|
||||
db: SQLDatabase,
|
||||
llm: BaseLanguageModel,
|
||||
top_k: int = 5,
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ class ChatInput(CustomComponent):
|
|||
|
||||
def build_config(self):
|
||||
return {
|
||||
"message": {
|
||||
"input_value": {
|
||||
"input_types": ["Text"],
|
||||
"display_name": "Message",
|
||||
"multiline": True,
|
||||
|
|
@ -35,26 +35,26 @@ class ChatInput(CustomComponent):
|
|||
self,
|
||||
sender: Optional[str] = "User",
|
||||
sender_name: Optional[str] = "User",
|
||||
message: Optional[str] = None,
|
||||
input_value: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
return_record: Optional[bool] = False,
|
||||
) -> Union[Text, Record]:
|
||||
if return_record:
|
||||
if isinstance(message, Record):
|
||||
if isinstance(input_value, Record):
|
||||
# Update the data of the record
|
||||
message.data["sender"] = sender
|
||||
message.data["sender_name"] = sender_name
|
||||
message.data["session_id"] = session_id
|
||||
input_value.data["sender"] = sender
|
||||
input_value.data["sender_name"] = sender_name
|
||||
input_value.data["session_id"] = session_id
|
||||
else:
|
||||
message = Record(
|
||||
text=message,
|
||||
input_value = Record(
|
||||
text=input_value,
|
||||
data={
|
||||
"sender": sender,
|
||||
"sender_name": sender_name,
|
||||
"session_id": session_id,
|
||||
},
|
||||
)
|
||||
if not message:
|
||||
message = ""
|
||||
self.status = message
|
||||
return message
|
||||
if not input_value:
|
||||
input_value = ""
|
||||
self.status = input_value
|
||||
return input_value
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ class ChatOutput(CustomComponent):
|
|||
|
||||
def build_config(self):
|
||||
return {
|
||||
"message": {"input_types": ["Text"], "display_name": "Message"},
|
||||
"input_value": {"input_types": ["Text"], "display_name": "Message"},
|
||||
"sender": {
|
||||
"options": ["Machine", "User"],
|
||||
"display_name": "Sender Type",
|
||||
|
|
@ -39,25 +39,25 @@ class ChatOutput(CustomComponent):
|
|||
sender: Optional[str] = "Machine",
|
||||
sender_name: Optional[str] = "AI",
|
||||
session_id: Optional[str] = None,
|
||||
message: Optional[str] = None,
|
||||
input_value: Optional[str] = None,
|
||||
return_record: Optional[bool] = False,
|
||||
) -> Union[Text, Record]:
|
||||
if return_record:
|
||||
if isinstance(message, Record):
|
||||
if isinstance(input_value, Record):
|
||||
# Update the data of the record
|
||||
message.data["sender"] = sender
|
||||
message.data["sender_name"] = sender_name
|
||||
message.data["session_id"] = session_id
|
||||
input_value.data["sender"] = sender
|
||||
input_value.data["sender_name"] = sender_name
|
||||
input_value.data["session_id"] = session_id
|
||||
else:
|
||||
message = Record(
|
||||
text=message,
|
||||
input_value = Record(
|
||||
text=input_value,
|
||||
data={
|
||||
"sender": sender,
|
||||
"sender_name": sender_name,
|
||||
"session_id": session_id,
|
||||
},
|
||||
)
|
||||
if not message:
|
||||
message = ""
|
||||
self.status = message
|
||||
return message
|
||||
if not input_value:
|
||||
input_value = ""
|
||||
self.status = input_value
|
||||
return input_value
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class AmazonBedrockComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
model_id: str = "anthropic.claude-instant-v1",
|
||||
credentials_profile_name: Optional[str] = None,
|
||||
region_name: Optional[str] = None,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ from langflow.field_typing import Text
|
|||
|
||||
class AnthropicLLM(CustomComponent):
|
||||
display_name: str = "AnthropicModel"
|
||||
description: str = "Generate text using Anthropic Chat&Completion large language models."
|
||||
description: str = (
|
||||
"Generate text using Anthropic Chat&Completion large language models."
|
||||
)
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
|
|
@ -53,7 +55,7 @@ class AnthropicLLM(CustomComponent):
|
|||
def build(
|
||||
self,
|
||||
model: str,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
anthropic_api_key: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
|
|
@ -66,7 +68,9 @@ class AnthropicLLM(CustomComponent):
|
|||
try:
|
||||
output = ChatAnthropic(
|
||||
model_name=model,
|
||||
anthropic_api_key=(SecretStr(anthropic_api_key) if anthropic_api_key else None),
|
||||
anthropic_api_key=(
|
||||
SecretStr(anthropic_api_key) if anthropic_api_key else None
|
||||
),
|
||||
max_tokens_to_sample=max_tokens, # type: ignore
|
||||
temperature=temperature,
|
||||
anthropic_api_url=api_endpoint,
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ from langflow import CustomComponent
|
|||
class AzureChatOpenAIComponent(CustomComponent):
|
||||
display_name: str = "AzureOpenAI Model"
|
||||
description: str = "Generate text using LLM model from Azure OpenAI."
|
||||
documentation: str = "https://python.langchain.com/docs/integrations/llms/azure_openai"
|
||||
documentation: str = (
|
||||
"https://python.langchain.com/docs/integrations/llms/azure_openai"
|
||||
)
|
||||
beta = False
|
||||
|
||||
AZURE_OPENAI_MODELS = [
|
||||
|
|
@ -78,7 +80,7 @@ class AzureChatOpenAIComponent(CustomComponent):
|
|||
self,
|
||||
model: str,
|
||||
azure_endpoint: str,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
azure_deployment: str,
|
||||
api_key: str,
|
||||
api_version: str,
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class QianfanChatEndpointComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
model: str = "ERNIE-Bot-turbo",
|
||||
qianfan_ak: Optional[str] = None,
|
||||
qianfan_sk: Optional[str] = None,
|
||||
|
|
|
|||
|
|
@ -35,11 +35,13 @@ class CTransformersComponent(CustomComponent):
|
|||
self,
|
||||
model: str,
|
||||
model_file: str,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
model_type: str,
|
||||
config: Optional[Dict] = None,
|
||||
) -> Text:
|
||||
output = CTransformers(model=model, model_file=model_file, model_type=model_type, config=config)
|
||||
output = CTransformers(
|
||||
model=model, model_file=model_file, model_type=model_type, config=config
|
||||
)
|
||||
message = output.invoke(inputs)
|
||||
result = message.content if hasattr(message, "content") else message
|
||||
self.status = result
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class CohereComponent(CustomComponent):
|
|||
def build(
|
||||
self,
|
||||
cohere_api_key: str,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
max_tokens: int = 256,
|
||||
temperature: float = 0.75,
|
||||
) -> Text:
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ class GoogleGenerativeAIComponent(CustomComponent):
|
|||
self,
|
||||
google_api_key: str,
|
||||
model: str,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
temperature: float = 0.1,
|
||||
top_k: Optional[int] = None,
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from langchain_community.chat_models.huggingface import ChatHuggingFace
|
|||
from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
from langflow.field_typing import Text
|
||||
|
||||
|
||||
|
|
@ -30,7 +29,7 @@ class HuggingFaceEndpointsComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
endpoint_url: str,
|
||||
task: str = "text2text-generation",
|
||||
huggingfacehub_api_token: Optional[str] = None,
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class LlamaCppComponent(CustomComponent):
|
|||
def build(
|
||||
self,
|
||||
model_path: str,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
grammar: Optional[str] = None,
|
||||
cache: Optional[bool] = None,
|
||||
client: Optional[Any] = None,
|
||||
|
|
|
|||
|
|
@ -171,7 +171,7 @@ class ChatOllamaComponent(CustomComponent):
|
|||
self,
|
||||
base_url: Optional[str],
|
||||
model: str,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
mirostat: Optional[str],
|
||||
mirostat_eta: Optional[float] = None,
|
||||
mirostat_tau: Optional[float] = None,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Optional
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import NestedDict, Text
|
||||
|
||||
|
|
@ -60,7 +61,7 @@ class OpenAIModelComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: Text,
|
||||
input_value: Text,
|
||||
max_tokens: Optional[int] = 256,
|
||||
model_kwargs: NestedDict = {},
|
||||
model_name: str = "gpt-4-1106-preview",
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ class ChatVertexAIComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
credentials: Optional[str],
|
||||
project: str,
|
||||
examples: Optional[List[BaseMessage]] = [],
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class RunnableExecComponent(CustomComponent):
|
|||
def build(
|
||||
self,
|
||||
input_key: str,
|
||||
inputs: str,
|
||||
input_value: str,
|
||||
runnable: Runnable,
|
||||
output_key: str = "output",
|
||||
) -> Text:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from typing import List, Optional
|
|||
|
||||
import chromadb # type: ignore
|
||||
from langchain_community.vectorstores.chroma import Chroma
|
||||
|
||||
from langflow import CustomComponent
|
||||
from langflow.field_typing import Embeddings, Text
|
||||
from langflow.schema import Record, docs_to_records
|
||||
|
|
@ -57,7 +58,7 @@ class ChromaSearchComponent(CustomComponent):
|
|||
|
||||
def build(
|
||||
self,
|
||||
inputs: Text,
|
||||
input_value: Text,
|
||||
search_type: str,
|
||||
collection_name: str,
|
||||
embedding: Embeddings,
|
||||
|
|
@ -92,7 +93,8 @@ class ChromaSearchComponent(CustomComponent):
|
|||
|
||||
if chroma_server_host is not None:
|
||||
chroma_settings = chromadb.config.Settings(
|
||||
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None,
|
||||
chroma_server_cors_allow_origins=chroma_server_cors_allow_origins
|
||||
or None,
|
||||
chroma_server_host=chroma_server_host,
|
||||
chroma_server_port=chroma_server_port or None,
|
||||
chroma_server_grpc_port=chroma_server_grpc_port or None,
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
from collections import defaultdict, deque
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List, Optional, Type, Union
|
||||
|
||||
|
|
@ -40,8 +41,10 @@ class Graph:
|
|||
self._runs = 0
|
||||
self._updates = 0
|
||||
self.flow_id = flow_id
|
||||
self._inputs = []
|
||||
self._outputs = []
|
||||
self._is_input_vertices = []
|
||||
self._is_output_vertices = []
|
||||
self._has_session_id_vertices = []
|
||||
self._sorted_vertices_layers = []
|
||||
|
||||
self.top_level_vertices = []
|
||||
for vertex in self._vertices:
|
||||
|
|
@ -54,38 +57,37 @@ class Graph:
|
|||
self.inactive_vertices = set()
|
||||
self._build_graph()
|
||||
self.build_graph_maps()
|
||||
self.define_inputs_and_outputs()
|
||||
self.define_vertices_lists()
|
||||
|
||||
def define_inputs_and_outputs(self):
|
||||
@property
|
||||
def sorted_vertices_layers(self):
|
||||
if not self._sorted_vertices_layers:
|
||||
self.sort_vertices()
|
||||
return self._sorted_vertices_layers
|
||||
|
||||
def define_vertices_lists(self):
|
||||
"""
|
||||
Defines the input and output vertices of the graph.
|
||||
Defines the lists of vertices that are inputs, outputs, and have session_id.
|
||||
"""
|
||||
attributes = ["is_input", "is_output", "has_session_id"]
|
||||
for vertex in self.vertices:
|
||||
if vertex.is_input:
|
||||
self._inputs.append(vertex.id)
|
||||
if vertex.is_output:
|
||||
self._outputs.append(vertex.id)
|
||||
for attribute in attributes:
|
||||
if getattr(vertex, attribute):
|
||||
getattr(self, f"_{attribute}_vertices").append(vertex.id)
|
||||
|
||||
def run(self, inputs: Dict[str, str]) -> List["ResultData"]:
|
||||
async def _run(self, inputs: Dict[str, str]) -> List["ResultData"]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
|
||||
# inputs is {"message": "Hello, world!"}
|
||||
# we need to go through self.inputs and update the self._raw_params
|
||||
# of the vertices that are inputs
|
||||
|
||||
for vertex_id in self.inputs:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
if vertex is None:
|
||||
raise ValueError(f"Vertex {vertex_id} not found")
|
||||
vertex.update_raw_params(inputs)
|
||||
try:
|
||||
self.build()
|
||||
await self.process()
|
||||
self.increment_run_count()
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise ValueError(f"Error running graph: {exc}") from exc
|
||||
|
||||
# Now we get the outputs from the self.outputs
|
||||
outputs = []
|
||||
for vertex_id in self.outputs:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
|
|
@ -94,6 +96,23 @@ class Graph:
|
|||
outputs.append(vertex.result)
|
||||
return outputs
|
||||
|
||||
async def run(self, inputs: Dict[str, Union[str, list[str]]]) -> List["ResultData"]:
|
||||
"""Runs the graph with the given inputs."""
|
||||
|
||||
# inputs is {"message": "Hello, world!"}
|
||||
# we need to go through self.inputs and update the self._raw_params
|
||||
# of the vertices that are inputs
|
||||
# if the value is a list, we need to run multiple times
|
||||
outputs = []
|
||||
inputs_values = inputs.get("input_value")
|
||||
if not isinstance(inputs_values, list):
|
||||
inputs_values = [inputs_values]
|
||||
for input_value in inputs_values:
|
||||
run_outputs = await self._run({"input_value": input_value})
|
||||
logger.debug(f"Run outputs: {run_outputs}")
|
||||
outputs.extend(run_outputs)
|
||||
return outputs
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return {
|
||||
|
|
@ -404,6 +423,36 @@ class Graph:
|
|||
raise ValueError("No root vertex found")
|
||||
return await root_vertex.build()
|
||||
|
||||
async def process(self) -> "Graph":
|
||||
"""Processes the graph with vertices in each layer run in parallel."""
|
||||
vertices_layers = self.sorted_vertices_layers
|
||||
|
||||
for layer_index, layer in enumerate(vertices_layers):
|
||||
tasks = []
|
||||
for vertex_id in layer:
|
||||
vertex = self.get_vertex(vertex_id)
|
||||
task = asyncio.create_task(
|
||||
vertex.build(), name=f"layer-{layer_index}-vertex-{vertex_id}"
|
||||
)
|
||||
tasks.append(task)
|
||||
logger.debug(f"Running layer {layer_index} with {len(tasks)} tasks")
|
||||
await self._execute_tasks(tasks)
|
||||
logger.debug("Graph processing complete")
|
||||
return self
|
||||
|
||||
async def _execute_tasks(self, tasks):
|
||||
"""Executes tasks in parallel, handling exceptions for each task."""
|
||||
results = []
|
||||
for task in asyncio.as_completed(tasks):
|
||||
try:
|
||||
result = await task
|
||||
results.append(result)
|
||||
except Exception as e:
|
||||
# Log the exception along with the task name for easier debugging
|
||||
task_name = task.get_name()
|
||||
logger.error(f"Task {task_name} failed with exception: {e}")
|
||||
return results
|
||||
|
||||
def topological_sort(self) -> List[Vertex]:
|
||||
"""
|
||||
Performs a topological sort of the vertices in the graph.
|
||||
|
|
@ -671,6 +720,7 @@ class Graph:
|
|||
vertices_layers = self.sort_by_avg_build_time(vertices_layers)
|
||||
vertices_layers = self.sort_chat_inputs_first(vertices_layers)
|
||||
self.increment_run_count()
|
||||
self._sorted_vertices_layers = vertices_layers
|
||||
return vertices_layers
|
||||
|
||||
def sort_interface_components_first(
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ if TYPE_CHECKING:
|
|||
from langflow.graph.graph.base import Graph
|
||||
|
||||
|
||||
class VertexStates(Enum):
|
||||
class VertexStates(str, Enum):
|
||||
"""Vertex are related to it being active, inactive, or in an error state."""
|
||||
|
||||
ACTIVE = "active"
|
||||
|
|
@ -53,6 +53,7 @@ class Vertex:
|
|||
output_component_name in self.id
|
||||
for output_component_name in OUTPUT_COMPONENTS
|
||||
)
|
||||
self.has_session_id = None
|
||||
self._custom_component = None
|
||||
self.has_external_input = False
|
||||
self.has_external_output = False
|
||||
|
|
@ -223,6 +224,8 @@ class Vertex:
|
|||
if isinstance(value, dict)
|
||||
}
|
||||
|
||||
self.has_session_id = "session_id" in template_dicts
|
||||
|
||||
self.required_inputs = [
|
||||
template_dicts[key]["type"]
|
||||
for key, value in template_dicts.items()
|
||||
|
|
|
|||
|
|
@ -47,10 +47,10 @@ class CustomComponent(Component):
|
|||
"""The icon of the component. It should be an emoji. Defaults to None."""
|
||||
is_input: Optional[bool] = None
|
||||
"""The input state of the component. Defaults to None.
|
||||
If True, the component must have a field named 'message'."""
|
||||
If True, the component must have a field named 'input_value'."""
|
||||
is_output: Optional[bool] = None
|
||||
"""The output state of the component. Defaults to None.
|
||||
If True, the component must have a field named 'message'."""
|
||||
If True, the component must have a field named 'input_value'."""
|
||||
code: Optional[str] = None
|
||||
"""The code of the component. Defaults to None."""
|
||||
field_config: dict = {}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@ from langchain.schema import AgentAction, Document
|
|||
from langchain_community.vectorstores import VectorStore
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
|
@ -17,8 +20,6 @@ from langflow.interface.run import (
|
|||
)
|
||||
from langflow.services.deps import get_session_service
|
||||
from langflow.services.session.service import SessionService
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def fix_memory_inputs(langchain_object):
|
||||
|
|
@ -146,7 +147,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
|
|||
elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"):
|
||||
result = await runnable.ainvoke(inputs)
|
||||
else:
|
||||
raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}")
|
||||
raise ValueError(
|
||||
f"Runnable {runnable} does not support inputs of type {type(inputs)}"
|
||||
)
|
||||
# Check if the result is a list of AIMessages
|
||||
if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result):
|
||||
result = [r.content for r in result]
|
||||
|
|
@ -155,7 +158,9 @@ async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
|
|||
return result
|
||||
|
||||
|
||||
async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict):
|
||||
async def process_inputs_dict(
|
||||
built_object: Union[Chain, VectorStore, Runnable], inputs: dict
|
||||
):
|
||||
if isinstance(built_object, Chain):
|
||||
if inputs is None:
|
||||
raise ValueError("Inputs must be provided for a Chain")
|
||||
|
|
@ -190,7 +195,9 @@ async def process_inputs_list(built_object: Runnable, inputs: List[dict]):
|
|||
return await process_runnable(built_object, inputs)
|
||||
|
||||
|
||||
async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
|
||||
async def generate_result(
|
||||
built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]
|
||||
):
|
||||
if isinstance(inputs, dict):
|
||||
result = await process_inputs_dict(built_object, inputs)
|
||||
elif isinstance(inputs, List) and isinstance(built_object, Runnable):
|
||||
|
|
@ -222,7 +229,9 @@ async def process_graph_cached(
|
|||
if clear_cache:
|
||||
session_service.clear_session(session_id)
|
||||
if session_id is None:
|
||||
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
|
||||
session_id = session_service.generate_key(
|
||||
session_id=session_id, data_graph=data_graph
|
||||
)
|
||||
# Load the graph using SessionService
|
||||
session = await session_service.load_session(session_id, data_graph)
|
||||
graph, artifacts = session if session else (None, None)
|
||||
|
|
@ -258,14 +267,34 @@ async def build_graph_and_generate_result(
|
|||
return Result(result=result, session_id=session_id)
|
||||
|
||||
|
||||
def validate_input(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
async def run_graph(
|
||||
graph: Union["Graph", dict],
|
||||
session_id: str,
|
||||
inputs: Optional[Union[dict, List[dict]]] = None,
|
||||
artifacts: Optional[Dict[str, Any]] = None,
|
||||
session_service: Optional[SessionService] = None,
|
||||
):
|
||||
"""Run the graph and generate the result"""
|
||||
if isinstance(graph, dict):
|
||||
graph = Graph.from_payload(graph)
|
||||
outputs = await graph.run(inputs)
|
||||
if session_id and session_service:
|
||||
session_service.update_session(session_id, (graph, artifacts))
|
||||
return outputs
|
||||
|
||||
|
||||
def validate_input(
|
||||
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
|
||||
raise ValueError("graph_data and tweaks should be dictionaries")
|
||||
|
||||
nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes")
|
||||
|
||||
if not isinstance(nodes, list):
|
||||
raise ValueError("graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key")
|
||||
raise ValueError(
|
||||
"graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key"
|
||||
)
|
||||
|
||||
return nodes
|
||||
|
||||
|
|
@ -274,7 +303,9 @@ def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
|
|||
template_data = node.get("data", {}).get("node", {}).get("template")
|
||||
|
||||
if not isinstance(template_data, dict):
|
||||
logger.warning(f"Template data for node {node.get('id')} should be a dictionary")
|
||||
logger.warning(
|
||||
f"Template data for node {node.get('id')} should be a dictionary"
|
||||
)
|
||||
return
|
||||
|
||||
for tweak_name, tweak_value in node_tweaks.items():
|
||||
|
|
@ -289,7 +320,9 @@ def apply_tweaks_on_vertex(vertex: Vertex, node_tweaks: Dict[str, Any]) -> None:
|
|||
vertex.params[tweak_name] = tweak_value
|
||||
|
||||
|
||||
def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
|
||||
def process_tweaks(
|
||||
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
This function is used to tweak the graph data using the node id and the tweaks dict.
|
||||
|
||||
|
|
@ -310,7 +343,9 @@ def process_tweaks(graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
|
|||
if node_tweaks := tweaks.get(node_id):
|
||||
apply_tweaks(node, node_tweaks)
|
||||
else:
|
||||
logger.warning("Each node should be a dictionary with an 'id' key of type str")
|
||||
logger.warning(
|
||||
"Each node should be a dictionary with an 'id' key of type str"
|
||||
)
|
||||
|
||||
return graph_data
|
||||
|
||||
|
|
@ -322,6 +357,8 @@ def process_tweaks_on_graph(graph: Graph, tweaks: Dict[str, Dict[str, Any]]):
|
|||
if node_tweaks := tweaks.get(node_id):
|
||||
apply_tweaks_on_vertex(vertex, node_tweaks)
|
||||
else:
|
||||
logger.warning("Each node should be a Vertex with an 'id' attribute of type str")
|
||||
logger.warning(
|
||||
"Each node should be a Vertex with an 'id' attribute of type str"
|
||||
)
|
||||
|
||||
return graph
|
||||
|
|
|
|||
|
|
@ -49,10 +49,10 @@ class FrontendNode(BaseModel):
|
|||
"""Icon of the frontend node."""
|
||||
is_input: Optional[bool] = None
|
||||
"""Whether the frontend node is used as an input when processing the Graph.
|
||||
If True, there should be a field named 'message'."""
|
||||
If True, there should be a field named 'input_value'."""
|
||||
is_output: Optional[bool] = None
|
||||
"""Whether the frontend node is used as an output when processing the Graph.
|
||||
If True, there should be a field named 'message'."""
|
||||
If True, there should be a field named 'input_value'."""
|
||||
is_composition: Optional[bool] = None
|
||||
"""Whether the frontend node is used for composition."""
|
||||
base_classes: List[str]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue