Update code with new typings and bug fixes

Gabriel Luiz Freitas Almeida, 2024-02-28 17:21:05 -03:00
commit 44642b5a0e
50 changed files with 264 additions and 551 deletions

View file

@ -2,19 +2,9 @@ import time
import uuid
from typing import TYPE_CHECKING, Annotated, Optional
from fastapi import (
APIRouter,
BackgroundTasks,
Body,
Depends,
HTTPException,
WebSocket,
WebSocketException,
status,
)
from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException
from fastapi.responses import StreamingResponse
from loguru import logger
from sqlmodel import Session
from langflow.api.utils import (
build_and_cache_graph,
@ -28,11 +18,7 @@ from langflow.api.v1.schemas import (
VertexBuildResponse,
VerticesOrderResponse,
)
from langflow.graph.graph.base import Graph
from langflow.services.auth.utils import (
get_current_active_user,
get_current_user_for_websocket,
)
from langflow.services.auth.utils import get_current_active_user
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service, get_session, get_session_service
from langflow.services.monitor.utils import log_vertex_build
@ -44,47 +30,6 @@ if TYPE_CHECKING:
router = APIRouter(tags=["Chat"])
@router.websocket("/chat/{client_id}")
async def chat(
client_id: str,
websocket: WebSocket,
db: Session = Depends(get_session),
chat_service: "ChatService" = Depends(get_chat_service),
):
"""Websocket endpoint for chat."""
try:
user = await get_current_user_for_websocket(websocket, db)
await websocket.accept()
if not user:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
elif not user.is_active:
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
if client_id in chat_service.cache_service:
await chat_service.handle_websocket(client_id, websocket)
else:
# We accept the connection but close it immediately
# if the flow is not built yet
message = "Please, build the flow before sending messages"
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=message)
except WebSocketException as exc:
logger.error(f"Websocket exrror: {exc}")
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
except Exception as exc:
logger.error(f"Error in chat websocket: {exc}")
message = exc.detail if isinstance(exc, HTTPException) else str(exc)
if "Could not validate credentials" in str(exc):
await websocket.close(
code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized"
)
else:
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=message)
async def try_running_celery_task(vertex, user_id):
# Try running the task in celery
# and set the task_id to the local vertex
@ -113,7 +58,7 @@ async def get_vertices(
# First, we need to check if the flow_id is in the cache
graph = None
if cache := chat_service.get_cache(flow_id):
graph: Graph = cache.get("result")
graph = cache.get("result")
graph = build_and_cache_graph(flow_id, session, chat_service, graph)
if component_id:
try:
@ -141,7 +86,7 @@ async def build_vertex(
flow_id: str,
vertex_id: str,
background_tasks: BackgroundTasks,
inputs: Annotated[InputValueRequest, Body(embed=True)] = None,
inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None,
chat_service: "ChatService" = Depends(get_chat_service),
current_user=Depends(get_current_active_user),
):
@ -161,7 +106,7 @@ async def build_vertex(
)
else:
graph = cache.get("result")
result_data_response = {}
result_data_response = ResultDataResponse(results={})
duration = ""
vertex = graph.get_vertex(vertex_id)
@ -250,7 +195,9 @@ async def build_vertex_stream(
else:
graph = cache.get("result")
else:
session_data = await session_service.load_session(session_id)
session_data = await session_service.load_session(
session_id, flow_id=flow_id
)
graph, artifacts = session_data if session_data else (None, None)
if not graph:
raise ValueError(f"No graph found for {flow_id}.")

View file

@ -4,27 +4,20 @@ from typing import Annotated, Any, List, Optional, Union
import sqlalchemy as sa
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
from loguru import logger
from sqlmodel import select
from sqlmodel import Session, select
from langflow.api.utils import update_frontend_node_with_template_values
from langflow.api.v1.schemas import (
CustomComponentCode,
PreloadResponse,
ProcessResponse,
RunResponse,
TaskResponse,
TaskStatusResponse,
UploadFileResponse,
)
from langflow.interface.custom.custom_component import CustomComponent
from langflow.interface.custom.directory_reader import DirectoryReader
from langflow.interface.custom.utils import build_custom_component_template
from langflow.processing.process import (
build_graph_and_generate_result,
process_graph_cached,
process_tweaks,
run_graph,
)
from langflow.processing.process import process_tweaks, run_graph
from langflow.services.auth.utils import api_key_security, get_current_active_user
from langflow.services.cache.utils import save_uploaded_file
from langflow.services.database.models.flow import Flow
@ -36,98 +29,12 @@ from langflow.services.deps import (
get_task_service,
)
from langflow.services.session.service import SessionService
try:
from langflow.worker import process_graph_cached_task
except ImportError:
def process_graph_cached_task(*args, **kwargs):
raise NotImplementedError("Celery is not installed")
from sqlmodel import Session
from langflow.services.task.service import TaskService
# build router
router = APIRouter(tags=["Base"])
async def process_graph_data(
graph_data: dict,
inputs: Optional[Union[List[dict], dict]] = None,
tweaks: Optional[dict] = None,
clear_cache: bool = False,
session_id: Optional[str] = None,
task_service: "TaskService" = Depends(get_task_service),
sync: bool = True,
):
task_result: Any = None
task_status = None
if tweaks:
try:
graph_data = process_tweaks(graph_data, tweaks)
except Exception as exc:
logger.error(f"Error processing tweaks: {exc}")
if sync:
result = await process_graph_cached(
graph_data,
inputs,
clear_cache,
session_id,
)
task_id = str(id(result))
if isinstance(result, dict) and "result" in result:
task_result = result["result"]
session_id = result["session_id"]
elif hasattr(result, "result") and hasattr(result, "session_id"):
task_result = result.result
session_id = result.session_id
else:
task_result = result
else:
logger.warning(
"This is an experimental feature and may not work as expected."
"Please report any issues to our GitHub repository."
)
if session_id is None:
# Generate a session ID
session_id = get_session_service().generate_key(
session_id=session_id, data_graph=graph_data
)
task_id, task = await task_service.launch_task(
(
process_graph_cached_task
if task_service.use_celery
else process_graph_cached
),
graph_data,
inputs,
clear_cache,
session_id,
)
task_status = task.status
if task.status == "FAILURE":
logger.error(f"Task {task_id} failed: {task.traceback}")
task_result = str(task._exception)
else:
task_result = task.result
if task_id:
task_response = TaskResponse(id=task_id, href=f"api/v1/task/{task_id}")
else:
task_response = None
return ProcessResponse(
result=task_result,
status=task_status,
task=task_response,
session_id=session_id,
backend=task_service.backend_name,
)
@router.get("/all", dependencies=[Depends(get_current_active_user)])
def get_all(
settings_service=Depends(get_settings_service),
@ -141,85 +48,6 @@ def get_all(
raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post("/process/json", response_model=ProcessResponse)
async def process_json(
session: Annotated[Session, Depends(get_session)],
data: dict,
inputs: Optional[dict] = None,
tweaks: Optional[dict] = None,
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
task_service: "TaskService" = Depends(get_task_service),
sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821
):
try:
return await process_graph_data(
graph_data=data,
inputs=inputs,
tweaks=tweaks,
clear_cache=clear_cache,
session_id=session_id,
task_service=task_service,
sync=sync,
)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
# Endpoint to preload a graph
@router.post("/process/preload/{flow_id}", response_model=PreloadResponse)
async def preload_flow(
session: Annotated[Session, Depends(get_session)],
flow_id: str,
session_id: Optional[str] = None,
session_service: SessionService = Depends(get_session_service),
api_key_user: User = Depends(api_key_security),
clear_session: Annotated[bool, Body(embed=True)] = False, # noqa: F821
):
try:
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
if clear_session:
session_service.clear_session(session_id)
# Check if the session exists
session_data = await session_service.load_session(session_id)
# Session data is a tuple of (graph, artifacts)
# or (None, None) if the session is empty
if isinstance(session_data, tuple):
graph, artifacts = session_data
is_clear = graph is None and artifacts is None
else:
is_clear = session_data is None
return PreloadResponse(session_id=session_id, is_clear=is_clear)
else:
if session_id is None:
session_id = flow_id
flow = session.exec(
select(Flow)
.where(Flow.id == flow_id)
.where(Flow.user_id == api_key_user.id)
).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
if flow.data is None:
raise ValueError(f"Flow {flow_id} has no data")
graph_data = flow.data
session_service.clear_session(session_id)
# Load the graph using SessionService
session_data = await session_service.load_session(session_id, graph_data)
graph, artifacts = session_data if session_data else (None, None)
if not graph:
raise ValueError("Graph not found in the session")
_ = await graph.build()
session_service.update_session(session_id, (graph, artifacts))
return PreloadResponse(session_id=session_id)
except Exception as exc:
logger.exception(exc)
raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post(
"/run/{flow_id}", response_model=RunResponse, response_model_exclude_none=True
)
@ -235,7 +63,9 @@ async def run_flow_with_caching(
):
try:
if session_id:
session_data = await session_service.load_session(session_id)
session_data = await session_service.load_session(
session_id, flow_id=flow_id
)
graph, artifacts = session_data if session_data else (None, None)
task_result: Any = None
if not graph:
@ -264,7 +94,7 @@ async def run_flow_with_caching(
if flow.data is None:
raise ValueError(f"Flow {flow_id} has no data")
graph_data = flow.data
graph_data = process_tweaks(graph_data, tweaks)
graph_data = process_tweaks(graph_data, tweaks or {})
task_result, session_id = await run_graph(
graph=graph_data,
flow_id=flow_id,
@ -318,94 +148,16 @@ async def process(
"""
Endpoint to process an input with a given flow_id.
"""
try:
if session_id:
session_data = await session_service.load_session(session_id)
graph, artifacts = session_data if session_data else (None, None)
task_result: Any = None
task_status = None
task_id = None
if not graph:
raise ValueError("Graph not found in the session")
result = await build_graph_and_generate_result(
graph=graph,
inputs=inputs,
artifacts=artifacts,
session_id=session_id,
session_service=session_service,
)
task_id = str(id(result))
if isinstance(result, dict) and "result" in result:
task_result = result["result"]
session_id = result["session_id"]
elif hasattr(result, "result") and hasattr(result, "session_id"):
task_result = result.result
session_id = result.session_id
else:
task_result = result
if task_id:
task_response = TaskResponse(id=task_id, href=f"api/v1/task/{task_id}")
else:
task_response = None
return ProcessResponse(
result=task_result,
status=task_status,
task=task_response,
session_id=session_id,
backend=task_service.backend_name,
)
else:
if api_key_user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API Key",
)
# Get the flow that matches the flow_id and belongs to the user
# flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first()
flow = session.exec(
select(Flow)
.where(Flow.id == flow_id)
.where(Flow.user_id == api_key_user.id)
).first()
if flow is None:
raise ValueError(f"Flow {flow_id} not found")
if flow.data is None:
raise ValueError(f"Flow {flow_id} has no data")
graph_data = flow.data
return await process_graph_data(
graph_data=graph_data,
inputs=inputs,
tweaks=tweaks,
clear_cache=clear_cache,
session_id=session_id,
task_service=task_service,
sync=sync,
)
except sa.exc.StatementError as exc:
# StatementError('(builtins.ValueError) badly formed hexadecimal UUID string')
if "badly formed hexadecimal UUID string" in str(exc):
# The flow ID is not a valid UUID, so the corresponding flow cannot be found
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
except ValueError as exc:
if f"Flow {flow_id} not found" in str(exc):
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)
) from exc
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc)
) from exc
except Exception as e:
# Log stack trace
logger.exception(e)
raise HTTPException(status_code=500, detail=str(e)) from e
# Raise a deprecation warning
logger.warning(
"The /process endpoint is deprecated and will be removed in a future version. "
"Please use /run instead."
)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The /process endpoint is deprecated and will be removed in a future version. "
"Please use /run instead.",
)
@router.get("/task/{task_id}", response_model=TaskStatusResponse)

View file

@ -31,17 +31,18 @@ class ConversationChainComponent(CustomComponent):
chain = ConversationChain(llm=llm)
else:
chain = ConversationChain(llm=llm, memory=memory)
result = chain.invoke(input_value)
result = chain.invoke({chain.input_key: input_value})
# result is an AIMessage which is a subclass of BaseMessage
# We need to check if it is a string or a BaseMessage
result_str = ""
if hasattr(result, "content") and isinstance(result.content, str):
self.status = "is message"
result = result.content
result_str = result.content
elif isinstance(result, str):
self.status = "is_string"
result = result
result_str = result
else:
# is dict
result = result.get("response")
self.status = result
return result
result_str = result.get("response")
self.status = result_str
return result_str
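
The AIMessage/str/dict juggling above generalizes; a minimal helper showing the same normalization (the function name is ours, not the component's):

def to_text(result) -> str:
    # chain.invoke() may hand back an AIMessage, a plain string, or a
    # dict, depending on the chain; normalize all three to str.
    if hasattr(result, "content") and isinstance(result.content, str):
        return result.content
    if isinstance(result, str):
        return result
    if isinstance(result, dict):
        return str(result.get("response", ""))
    return str(result)

print(to_text({"response": "hi"}))  # hi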

View file

@ -23,7 +23,8 @@ class LLMCheckerChainComponent(CustomComponent):
) -> Text:
chain = LLMCheckerChain.from_llm(llm=llm)
response = chain.invoke({chain.input_key: inputs})
result = response.get(chain.output_key)
self.status = result
return result
response = chain.invoke({chain.input_key: input_value})
result = response.get(chain.output_key, "")
result_str = str(result)
self.status = result_str
return result_str

View file

@ -38,7 +38,8 @@ class LLMMathChainComponent(CustomComponent):
output_key=output_key,
memory=memory,
)
response = chain.invoke({input_key: inputs})
response = chain.invoke({input_key: input_value})
result = response.get(output_key)
self.status = result
return result
result_str = str(result)
self.status = result_str
return result_str

View file

@ -1,7 +1,7 @@
from typing import Callable, Optional, Union
from typing import Optional
from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
from langchain.chains.retrieval_qa.base import BaseRetrievalQA, RetrievalQA
from langchain.chains.retrieval_qa.base import RetrievalQA
from langchain_core.documents import Document
from langflow import CustomComponent
@ -35,7 +35,7 @@ class RetrievalQAComponent(CustomComponent):
input_key: str = "query",
output_key: str = "result",
return_source_documents: bool = True,
) -> Union[BaseRetrievalQA, Callable, Text]:
) -> Text:
runnable = RetrievalQA(
combine_documents_chain=combine_documents_chain,
retriever=retriever,
@ -44,10 +44,10 @@ class RetrievalQAComponent(CustomComponent):
output_key=output_key,
return_source_documents=return_source_documents,
)
if isinstance(inputs, Document):
inputs = inputs.page_content
if isinstance(input_value, Document):
input_value = input_value.page_content
self.status = runnable
result = runnable.invoke({input_key: inputs})
result = runnable.invoke({input_key: input_value})
result = result.content if hasattr(result, "content") else result
# Result is a dict with keys "query", "result" and "source_documents"
# for now we just return the result
@ -55,7 +55,8 @@ class RetrievalQAComponent(CustomComponent):
references_str = ""
if return_source_documents:
references_str = self.create_references_from_records(records)
result_str = result.get("result")
final_result = "\n".join([result_str, references_str])
result_str = result.get("result", "")
final_result = "\n".join([str(result_str), references_str])
self.status = final_result
return final_result

View file

@ -40,11 +40,11 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent):
return_source_documents=return_source_documents,
retriever=retriever,
)
if isinstance(inputs, Document):
inputs = inputs.page_content
if isinstance(input_value, Document):
input_value = input_value.page_content
self.status = runnable
input_key = runnable.input_keys[0]
result = runnable.invoke({input_key: inputs})
result = runnable.invoke({input_key: input_value})
result = result.content if hasattr(result, "content") else result
# Result is a dict with keys "query", "result" and "source_documents"
# for now we just return the result
@ -52,7 +52,7 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent):
references_str = ""
if return_source_documents:
references_str = self.create_references_from_records(records)
result_str = result.get("answer")
result_str = str(result.get("answer", ""))
final_result = "\n".join([result_str, references_str])
self.status = final_result
return final_result

View file

@ -3,6 +3,7 @@ from typing import Optional
from langchain.chains import create_sql_query_chain
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import Runnable
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel, Text
@ -39,33 +40,27 @@ class SQLGeneratorComponent(CustomComponent):
else:
prompt_template = None
if top_k > 0:
kwargs = {
"k": top_k,
}
if top_k < 1:
raise ValueError("Top K must be greater than 0.")
if not prompt_template:
sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs)
sql_query_chain = create_sql_query_chain(llm=llm, db=db, k=top_k)
else:
template = (
prompt_template.template
if hasattr(prompt, "template")
else prompt_template
)
# Check if {question} is in the prompt
if (
"{question}" not in template
or "question" not in template.input_variables
"{question}" not in prompt_template.template
or "question" not in prompt_template.input_variables
):
raise ValueError(
"Prompt must contain `{question}` to be used with Natural Language to SQL."
)
sql_query_chain = create_sql_query_chain(
llm=llm, db=db, prompt=prompt_template, **kwargs
llm=llm, db=db, prompt=prompt_template, k=top_k
)
query_writer = sql_query_chain | {
query_writer: Runnable = sql_query_chain | {
"query": lambda x: x.replace("SQLQuery:", "").strip()
}
response = query_writer.invoke({"question": inputs})
response = query_writer.invoke({"question": input_value})
query = response.get("query")
self.status = query
return query
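
Piping a runnable into a dict, as above, is LangChain's shorthand for a RunnableParallel whose values each receive the upstream output. A toy sketch with a stubbed chain (the lambda stands in for the LLM call):

from langchain_core.runnables import Runnable, RunnableLambda

# Stand-in for the LLM-backed SQL chain; any Runnable slots in here.
sql_query_chain = RunnableLambda(lambda _: "SQLQuery: SELECT * FROM users")

query_writer: Runnable = sql_query_chain | {
    "query": lambda x: x.replace("SQLQuery:", "").strip()
}

print(query_writer.invoke({"question": "list all users"}))
# {'query': 'SELECT * FROM users'}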

View file

@ -76,9 +76,11 @@ class GatherRecordsComponent(CustomComponent):
return file_paths
def parse_file_to_record(self, file_path: str, silent_errors: bool) -> Record:
def parse_file_to_record(
self, file_path: str, silent_errors: bool
) -> Optional[Record]:
# Use the partition function to load the file
from unstructured.partition.auto import partition
from unstructured.partition.auto import partition # type: ignore
try:
elements = partition(file_path)
@ -115,13 +117,14 @@ class GatherRecordsComponent(CustomComponent):
def parallel_load_records(
self, file_paths: List[str], silent_errors: bool, max_concurrency: int
) -> List[Record]:
) -> List[Optional[Record]]:
with futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
loaded_files = executor.map(
lambda file_path: self.parse_file_to_record(file_path, silent_errors),
file_paths,
)
return loaded_files
# loaded_files is an iterator, so we need to convert it to a list
return list(loaded_files)
def build(
self,
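
The `list(loaded_files)` fix above matters because executor.map is lazy. A self-contained illustration (the parse function is a stand-in for parse_file_to_record):

from concurrent.futures import ThreadPoolExecutor

def parse(path: str) -> str:
    return path.upper()

with ThreadPoolExecutor(max_workers=2) as executor:
    loaded = executor.map(parse, ["a.txt", "b.txt"])

# executor.map returns a one-shot iterator, not a list; materializing it
# makes the declared List[...] return type true and lets callers iterate
# the results more than once.
records = list(loaded)
print(records)  # ['A.TXT', 'B.TXT']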

View file

@ -48,11 +48,15 @@ class StoreMessages(CustomComponent):
# and the other parameters
if not texts and not records:
raise ValueError("Either texts or records must be provided.")
if not texts:
texts = []
if not records:
records = []
if not session_id or not sender or not sender_name:
raise ValueError("If passing texts, session_id, sender, and sender_name must be provided.")
raise ValueError(
"If passing texts, session_id, sender, and sender_name must be provided."
)
for text in texts:
record = Record(
text=text,
@ -68,4 +72,4 @@ class StoreMessages(CustomComponent):
self.status = records
records = add_messages(records)
return records
return records or []

View file

@ -35,7 +35,7 @@ class ChatComponent(CustomComponent):
def store_message(
self,
message: Union[Text, Record],
message: Union[str, Text, Record],
session_id: Optional[str] = None,
sender: Optional[str] = None,
sender_name: Optional[str] = None,

View file

@ -1,6 +1,7 @@
from typing import Any, Callable, Dict, Optional, Union
from langchain_community.chat_models.litellm import ChatLiteLLM, ChatLiteLLMException
from langflow import CustomComponent
from langflow.field_typing import BaseLanguageModel
@ -126,7 +127,8 @@ class ChatLiteLLMComponent(CustomComponent):
litellm.set_verbose = verbose
except ImportError:
raise ChatLiteLLMException(
"Could not import litellm python package. " "Please install it with `pip install litellm`"
"Could not import litellm python package. "
"Please install it with `pip install litellm`"
)
provider_map = {
"OpenAI": "openai_api_key",
@ -137,11 +139,17 @@ class ChatLiteLLMComponent(CustomComponent):
"OpenRouter": "openrouter_api_key",
}
# Set the API key based on the provider
kwarg = {provider_map[provider]: api_key}
api_keys = {v: None for v in provider_map.values()}
if variable_name := provider_map.get(provider):
api_keys[variable_name] = api_key
else:
raise ChatLiteLLMException(
f"Provider {provider} is not supported. Supported providers are: {', '.join(provider_map.keys())}"
)
LLM = ChatLiteLLM(
model=model,
client=None,
streaming=streaming,
temperature=temperature,
model_kwargs=model_kwargs if model_kwargs is not None else {},
@ -150,6 +158,11 @@ class ChatLiteLLMComponent(CustomComponent):
n=n,
max_tokens=max_tokens,
max_retries=max_retries,
**kwarg,
openai_api_key=api_keys["openai_api_key"],
azure_api_key=api_keys["azure_api_key"],
anthropic_api_key=api_keys["anthropic_api_key"],
replicate_api_key=api_keys["replicate_api_key"],
cohere_api_key=api_keys["cohere_api_key"],
openrouter_api_key=api_keys["openrouter_api_key"],
)
return LLM
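
The api_keys rewrite above swaps a bare `provider_map[provider]` lookup for a validated one. A trimmed sketch of the idea (only two providers shown for brevity):

provider_map = {
    "OpenAI": "openai_api_key",
    "Anthropic": "anthropic_api_key",
}

def build_api_keys(provider: str, api_key: str) -> dict:
    # Every known kwarg starts as None; only the chosen provider is
    # filled in, and unknown providers fail loudly up front.
    api_keys = {name: None for name in provider_map.values()}
    variable_name = provider_map.get(provider)
    if variable_name is None:
        raise ValueError(
            f"Provider {provider} is not supported. "
            f"Supported providers are: {', '.join(provider_map)}"
        )
    api_keys[variable_name] = api_key
    return api_keys

print(build_api_keys("OpenAI", "sk-test"))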

View file

@ -4,6 +4,7 @@ from langchain.llms.base import BaseLanguageModel
from langchain_openai import AzureChatOpenAI
from langflow.components.models.base.model import LCModelComponent
from pydantic.v1 import SecretStr
class AzureChatOpenAIComponent(LCModelComponent):
@ -93,13 +94,14 @@ class AzureChatOpenAIComponent(LCModelComponent):
max_tokens: Optional[int] = 1000,
stream: bool = False,
) -> BaseLanguageModel:
secret_api_key = SecretStr(api_key)
try:
output = AzureChatOpenAI(
model=model,
azure_endpoint=azure_endpoint,
azure_deployment=azure_deployment,
api_version=api_version,
api_key=api_key,
api_key=secret_api_key,
temperature=temperature,
max_tokens=max_tokens,
)
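
Wrapping the key in SecretStr, as above, keeps it out of logs and tracebacks; the commit imports it from pydantic's v1 shim, but the v2 class behaves the same. A quick demonstration:

from pydantic import SecretStr

secret_api_key = SecretStr("sk-not-a-real-key")
print(secret_api_key)                     # **********  (safe to log)
print(secret_api_key.get_secret_value())  # the raw key, only on explicit request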

View file

@ -41,11 +41,15 @@ class CTransformersComponent(LCModelComponent):
model_file: str,
input_value: str,
model_type: str,
stream: bool = False,
config: Optional[Dict] = None,
stream: Optional[bool] = False,
) -> Text:
output = CTransformers(
model=model, model_file=model_file, model_type=model_type, config=config
client=None,
model=model,
model_file=model_file,
model_type=model_type,
config=config, # noqa
)
return self.get_result(output=output, stream=stream, input_value=input_value)

View file

@ -41,13 +41,11 @@ class CohereComponent(LCModelComponent):
self,
cohere_api_key: str,
input_value: str,
max_tokens: int = 256,
temperature: float = 0.75,
stream: bool = False,
) -> Text:
output = ChatCohere(
output = ChatCohere( # type: ignore
cohere_api_key=cohere_api_key,
max_tokens=max_tokens,
temperature=temperature,
)
return self.get_result(output=output, stream=stream, input_value=input_value)

View file

@ -36,17 +36,19 @@ class HuggingFaceEndpointsComponent(LCModelComponent):
self,
input_value: str,
endpoint_url: str,
model: Optional[str] = None,
task: str = "text2text-generation",
huggingfacehub_api_token: Optional[str] = None,
model_kwargs: Optional[dict] = None,
stream: bool = False,
) -> Text:
try:
llm = HuggingFaceEndpoint(
llm = HuggingFaceEndpoint( # type: ignore
endpoint_url=endpoint_url,
task=task,
huggingfacehub_api_token=huggingfacehub_api_token,
model_kwargs=model_kwargs,
model_kwargs=model_kwargs or {},
model=model or "",
)
except Exception as e:
raise ValueError("Could not connect to HuggingFace Endpoints API.") from e

View file

@ -203,7 +203,7 @@ class ChatOllamaComponent(LCModelComponent):
timeout: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[int] = None,
stream: Optional[bool] = False,
stream: bool = False,
) -> Text:
if not base_url:
base_url = "http://localhost:11434"

View file

@ -1,6 +1,7 @@
from typing import Optional
from langchain_openai import ChatOpenAI
from pydantic.v1 import SecretStr
from langflow.components.models.base.model import LCModelComponent
from langflow.field_typing import NestedDict, Text
@ -73,16 +74,20 @@ class OpenAIModelComponent(LCModelComponent):
openai_api_base: Optional[str] = None,
openai_api_key: Optional[str] = None,
temperature: float = 0.7,
stream: Optional[bool] = False,
stream: bool = False,
) -> Text:
if not openai_api_base:
openai_api_base = "https://api.openai.com/v1"
if openai_api_key:
secret_key = SecretStr(openai_api_key)
else:
secret_key = None
output = ChatOpenAI(
max_tokens=max_tokens,
model_kwargs=model_kwargs,
model=model_name,
base_url=openai_api_base,
api_key=openai_api_key,
api_key=secret_key,
temperature=temperature,
)

View file

@ -82,7 +82,7 @@ class ChatVertexAIComponent(LCModelComponent):
stream: bool = False,
) -> Text:
try:
from langchain_google_vertexai import ChatVertexAI
from langchain_google_vertexai import ChatVertexAI # type: ignore
except ImportError:
raise ImportError(
"To use the ChatVertexAI model, you need to install the langchain-google-vertexai package."

View file

@ -1,4 +1,5 @@
from langchain_core.prompts import PromptTemplate
from langflow import CustomComponent
from langflow.field_typing import Prompt, TemplateField, Text
@ -19,7 +20,7 @@ class PromptComponent(CustomComponent):
template: Prompt,
**kwargs,
) -> Text:
prompt_template = PromptTemplate.from_template(template)
prompt_template = PromptTemplate.from_template(str(template))
attributes_to_check = ["text", "page_content"]
for key, value in kwargs.items():

View file

@ -23,7 +23,7 @@ class ShouldRunNext(CustomComponent):
def build(self, template: Prompt, llm: BaseLanguageModel, **kwargs) -> dict:
# Render the prompt and ask the LLM for a true/false decision
prompt_template = PromptTemplate.from_template(template)
prompt_template = PromptTemplate.from_template(str(template))
attributes_to_check = ["text", "page_content"]
for key, value in kwargs.items():
@ -41,7 +41,9 @@ class ShouldRunNext(CustomComponent):
result = result.get("response")
if result.lower() not in ["true", "false"]:
raise ValueError("The prompt should generate a boolean response (True or False).")
raise ValueError(
"The prompt should generate a boolean response (True or False)."
)
# The string should be the words true or false
# if not raise an error
bool_result = result.lower() == "true"

View file

@ -95,8 +95,8 @@ class ChromaComponent(CustomComponent):
# If documents, then we need to create a Chroma instance using .from_documents
# Check index_directory and expand it if it is a relative path
index_directory = self.resolve_path(index_directory)
if index_directory is not None:
index_directory = self.resolve_path(index_directory)
if documents is not None and embedding is not None:
if len(documents) == 0:

View file

@ -100,7 +100,8 @@ class ChromaSearchComponent(LCVectorStoreComponent):
chroma_server_grpc_port=chroma_server_grpc_port or None,
chroma_server_ssl_enabled=chroma_server_ssl_enabled,
)
index_directory = self.resolve_path(index_directory)
if index_directory:
index_directory = self.resolve_path(index_directory)
vector_store = Chroma(
embedding_function=embedding,
collection_name=collection_name,

View file

@ -36,3 +36,4 @@ class FAISSComponent(CustomComponent):
raise ValueError("Folder path is required to save the FAISS index.")
path = self.resolve_path(folder_path)
vector_store.save_local(str(path), index_name)
return vector_store

View file

@ -27,7 +27,7 @@ class MongoDBAtlasComponent(CustomComponent):
def build(
self,
embedding: Embeddings,
documents: List[Document] = None,
documents: List[Document],
collection_name: str = "",
db_name: str = "",
index_name: str = "",
@ -35,11 +35,22 @@ class MongoDBAtlasComponent(CustomComponent):
search_kwargs: Optional[NestedDict] = None,
) -> MongoDBAtlasVectorSearch:
search_kwargs = search_kwargs or {}
try:
from pymongo import MongoClient
except ImportError:
raise ImportError(
"Please install pymongo to use MongoDB Atlas Vector Store"
)
try:
mongo_client: MongoClient = MongoClient(mongodb_atlas_cluster_uri)
collection = mongo_client[db_name][collection_name]
except Exception as e:
raise ValueError(f"Failed to connect to MongoDB Atlas: {e}")
if documents:
vector_store = MongoDBAtlasVectorSearch.from_documents(
documents=documents,
embedding=embedding,
collection_name=collection_name,
collection=collection,
db_name=db_name,
index_name=index_name,
mongodb_atlas_cluster_uri=mongodb_atlas_cluster_uri,
@ -48,10 +59,7 @@ class MongoDBAtlasComponent(CustomComponent):
else:
vector_store = MongoDBAtlasVectorSearch(
embedding=embedding,
collection_name=collection_name,
db_name=db_name,
collection=collection,
index_name=index_name,
mongodb_atlas_cluster_uri=mongodb_atlas_cluster_uri,
search_kwargs=search_kwargs,
)
return vector_store

View file

@ -25,7 +25,7 @@ class MongoDBAtlasSearchComponent(MongoDBAtlasComponent, LCVectorStoreComponent)
"search_kwargs": {"display_name": "Search Kwargs", "advanced": True},
}
def build(
def build( # type: ignore[override]
self,
input_value: str,
search_type: str,
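
The recurring `# type: ignore[override]` annotations in these search components acknowledge a deliberate Liskov violation: the subclasses change build()'s signature. A reduced example of the situation being silenced (class names here are illustrative):

class VectorStoreComponent:
    def build(self, embedding: str) -> str:
        return f"store({embedding})"

class SearchComponent(VectorStoreComponent):
    # mypy flags overrides whose parameters or return type are not
    # compatible with the base class; when the divergence is by design,
    # the error code is suppressed at the definition.
    def build(self, input_value: str, embedding: str) -> list:  # type: ignore[override]
        return [input_value, embedding]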

View file

@ -40,7 +40,7 @@ class PineconeSearchComponent(PineconeComponent, LCVectorStoreComponent):
},
}
def build(
def build( # type: ignore[override]
self,
input_value: str,
embedding: Embeddings,
@ -51,7 +51,7 @@ class PineconeSearchComponent(PineconeComponent, LCVectorStoreComponent):
pinecone_api_key: Optional[str] = None,
namespace: Optional[str] = "default",
search_type: str = "similarity",
) -> List[Record]:
) -> List[Record]: # type: ignore[override]
vector_store = super().build(
embedding=embedding,
pinecone_env=pinecone_env,

View file

@ -44,7 +44,7 @@ class QdrantSearchComponent(QdrantComponent, LCVectorStoreComponent):
"url": {"display_name": "URL", "advanced": True},
}
def build(
def build( # type: ignore[override]
self,
input_value: str,
embedding: Embeddings,
@ -65,7 +65,7 @@ class QdrantSearchComponent(QdrantComponent, LCVectorStoreComponent):
search_kwargs: Optional[NestedDict] = None,
timeout: Optional[int] = None,
url: Optional[str] = None,
) -> List[Record]:
) -> List[Record]: # type: ignore[override]
vector_store = super().build(
embedding=embedding,
collection_name=collection_name,

View file

@ -42,7 +42,7 @@ class RedisSearchComponent(RedisComponent, LCVectorStoreComponent):
"redis_index_name": {"display_name": "Redis Index", "advanced": False},
}
def build(
def build( # type: ignore[override]
self,
input_value: str,
search_type: str,

View file

@ -42,7 +42,7 @@ class VectaraSearchComponent(VectaraComponent, LCVectorStoreComponent):
},
}
def build(
def build( # type: ignore[override]
self,
input_value: str,
search_type: str,

View file

@ -55,7 +55,7 @@ class WeaviateSearchVectorStore(WeaviateVectorStoreComponent, LCVectorStoreCompo
"code": {"show": False},
}
def build(
def build( # type: ignore[override]
self,
input_value: str,
search_type: str,

View file

@ -1,9 +1,10 @@
from typing import List
from typing import List, Union
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.vectorstores import VectorStore
from langflow import CustomComponent
from langflow.field_typing import Text
from langflow.schema import Record, docs_to_records
@ -14,7 +15,10 @@ class LCVectorStoreComponent(CustomComponent):
beta: bool = True
def search_with_vector_store(
self, input_value: Text, search_type: str, vector_store: VectorStore
self,
input_value: str,
search_type: str,
vector_store: Union[VectorStore, BaseRetriever],
) -> List[Record]:
"""
Search for records in the vector store based on the input value and search type.
@ -31,8 +35,12 @@ class LCVectorStoreComponent(CustomComponent):
ValueError: If invalid inputs are provided.
"""
docs = []
if input_value and isinstance(input_value, str):
docs: List[Document] = []
if (
input_value
and isinstance(input_value, str)
and hasattr(vector_store, "search")
):
docs = vector_store.search(
query=input_value, search_type=search_type.lower()
)

View file

@ -1,4 +1,4 @@
from typing import List, Optional
from typing import List
from langchain.embeddings.base import Embeddings
@ -40,13 +40,13 @@ class PGVectorSearchComponent(PGVectorComponent, LCVectorStoreComponent):
"input_value": {"display_name": "Input"},
}
def build(
def build( # type: ignore[override]
self,
input_value: str,
embedding: Embeddings,
search_type: str,
pg_server_url: str,
collection_name: str,
search_type: Optional[str] = None,
) -> List[Record]:
"""
Builds the Vector Store or BaseRetriever object.

View file

@ -22,7 +22,8 @@ class Object:
pass
class Text:
# Text = NewType("Text", str)
class Text(str):
pass
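
Making Text a subclass of str (rather than an empty class) keeps plain strings and Text values interchangeable wherever the annotation is used; a two-line check:

class Text(str):
    pass

t = Text("hello")
print(isinstance(t, str), t.upper())  # True HELLO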

View file

@ -39,10 +39,10 @@ class Graph:
self._runs = 0
self._updates = 0
self.flow_id = flow_id
self._is_input_vertices = []
self._is_output_vertices = []
self._has_session_id_vertices = []
self._sorted_vertices_layers = []
self._is_input_vertices: List[str] = []
self._is_output_vertices: List[str] = []
self._has_session_id_vertices: List[str] = []
self._sorted_vertices_layers: List[List[str]] = []
self.top_level_vertices = []
for vertex in self._vertices:
@ -73,7 +73,9 @@ class Graph:
if getattr(vertex, attribute):
getattr(self, f"_{attribute}_vertices").append(vertex.id)
async def _run(self, inputs: Dict[str, str], stream: bool) -> List["ResultData"]:
async def _run(
self, inputs: Dict[str, str], stream: bool
) -> List[Optional["ResultData"]]:
"""Runs the graph with the given inputs."""
for vertex_id in self._is_input_vertices:
vertex = self.get_vertex(vertex_id)
@ -363,10 +365,10 @@ class Graph:
# All vertices that do not have edges are invalid
return len(self.get_vertex_edges(vertex.id)) > 0
def get_vertex(self, vertex_id: str) -> Union[None, Vertex]:
def get_vertex(self, vertex_id: str) -> Vertex:
"""Returns a vertex by id."""
try:
return self.vertex_map.get(vertex_id)
return self.vertex_map[vertex_id]
except KeyError:
raise ValueError(f"Vertex {vertex_id} not found")
@ -590,7 +592,7 @@ class Graph:
)
return f"Graph:\nNodes: {vertex_ids}\nConnections:\n{edges_repr}"
def sort_up_to_vertex(self, vertex_id: str) -> "Graph":
def sort_up_to_vertex(self, vertex_id: str) -> List[Vertex]:
"""Cuts the graph up to a given vertex and sorts the resulting subgraph."""
# Initial setup
visited = set() # To keep track of visited vertices
@ -727,7 +729,9 @@ class Graph:
]
return sorted_vertices
def sort_by_avg_build_time(self, vertices_layers: List[str]) -> List[str]:
def sort_by_avg_build_time(
self, vertices_layers: List[List[str]]
) -> List[List[str]]:
"""Sorts the vertices in the graph so that vertices with the lowest average build time come first."""
def sort_layer_by_avg_build_time(vertices_ids: List[str]) -> List[str]:

View file

@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage
from loguru import logger
from langflow.graph.schema import INPUT_FIELD_NAME
from langflow.graph.utils import UnbuiltObject, flatten_list
from langflow.graph.utils import UnbuiltObject, flatten_list, serialize_field
from langflow.graph.vertex.base import StatefulVertex, StatelessVertex
from langflow.interface.utils import extract_input_variables_from_prompt
from langflow.schema import Record
@ -483,7 +483,6 @@ class RoutingVertex(StatelessVertex):
def dict_to_codeblock(d: dict) -> str:
from langflow.api.utils import serialize_field
serialized = {key: serialize_field(val) for key, val in d.items()}
json_str = json.dumps(serialized, indent=4)

View file

@ -6,20 +6,17 @@ from loguru import logger
from langflow.graph import Graph
async def build_sorted_vertices(data_graph, user_id: Optional[Union[str, UUID]] = None) -> Tuple[Graph, Dict]:
async def build_sorted_vertices(
data_graph, flow_id: Optional[Union[str, UUID]] = None
) -> Tuple[Graph, Dict]:
"""
Build langchain object from data_graph.
"""
logger.debug("Building langchain object")
graph = Graph.from_payload(data_graph)
sorted_vertices = graph.topological_sort()
artifacts = {}
for vertex in sorted_vertices:
await vertex.build(user_id=user_id)
if vertex.artifacts:
artifacts.update(vertex.artifacts)
return graph, artifacts
graph = Graph.from_payload(data_graph, flow_id=flow_id)
return graph, {}
def get_memory_key(langchain_object):

View file

@ -3,7 +3,7 @@ from pathlib import Path
from typing import Optional
from urllib.parse import urlencode
import socketio
import socketio # type: ignore
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
@ -18,7 +18,9 @@ from langflow.utils.logger import configure
def get_lifespan(fix_migration=False, socketio_server=None):
@asynccontextmanager
async def lifespan(app: FastAPI):
initialize_services(fix_migration=fix_migration, socketio_server=socketio_server)
initialize_services(
fix_migration=fix_migration, socketio_server=socketio_server
)
setup_llm_caching()
LangfuseInstance.update()
yield
@ -31,7 +33,9 @@ def create_app():
"""Create the FastAPI app and include the router."""
configure()
socketio_server = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*", logger=True)
socketio_server = socketio.AsyncServer(
async_mode="asgi", cors_allowed_origins="*", logger=True
)
lifespan = get_lifespan(socketio_server=socketio_server)
app = FastAPI(lifespan=lifespan)
origins = ["*"]
@ -98,7 +102,9 @@ def get_static_files_dir():
return frontend_path / "frontend"
def setup_app(static_files_dir: Optional[Path] = None, backend_only: bool = False) -> FastAPI:
def setup_app(
static_files_dir: Optional[Path] = None, backend_only: bool = False
) -> FastAPI:
"""Setup the FastAPI app."""
# get the directory of the current file
if not static_files_dir:

View file

@ -98,22 +98,6 @@ def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]:
return list(inputs.values())[0] if len(inputs) == 1 else None
def get_build_result(data_graph, session_id):
# If session_id is provided, load the langchain_object from the session
# using build_sorted_vertices_with_caching.get_result_by_session_id
# if it returns something other than None, return it
# otherwise, build the graph and return the result
if session_id:
logger.debug(f"Loading LangChain object from session {session_id}")
result = build_sorted_vertices(data_graph=data_graph)
if result is not None:
logger.debug("Loaded LangChain object")
return result
logger.debug("Building langchain object")
return build_sorted_vertices(data_graph)
def process_inputs(
inputs: Optional[Union[dict, List[dict]]] = None,
artifacts: Optional[Dict[str, Any]] = None,
@ -233,7 +217,9 @@ async def process_graph_cached(
session_id=session_id, data_graph=data_graph
)
# Load the graph using SessionService
session = await session_service.load_session(session_id, data_graph)
session = await session_service.load_session(
session_id, data_graph, flow_id=flow_id
)
graph, artifacts = session if session else (None, None)
if not graph:
raise ValueError("Graph not found in the session")
@ -270,8 +256,8 @@ async def build_graph_and_generate_result(
async def run_graph(
graph: Union["Graph", dict],
flow_id: str,
session_id: str,
stream: bool,
session_id: Optional[str] = None,
inputs: Optional[Union[dict, List[dict]]] = None,
artifacts: Optional[Dict[str, Any]] = None,
session_service: Optional[SessionService] = None,
@ -282,10 +268,12 @@ async def run_graph(
graph = Graph.from_payload(graph, flow_id=flow_id)
else:
graph_data = graph._graph_data
if not session_id:
if not session_id and session_service is not None:
session_id = session_service.generate_key(
session_id=flow_id, data_graph=graph_data
)
if inputs is None:
inputs = {}
outputs = await graph.run(inputs, stream=stream)
if session_id and session_service:

View file

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
from langchain_core.documents import Document
from pydantic import BaseModel
@ -14,7 +14,7 @@ class Record(BaseModel):
"""
text: str
data: Optional[dict] = None
data: dict = {}
@classmethod
def from_document(cls, document: Document) -> "Record":
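
Changing `data` from Optional[dict] = None to `dict = {}` is safe here because pydantic copies field defaults per instance, unlike plain Python class attributes. A sanity check:

from pydantic import BaseModel

class Record(BaseModel):
    text: str
    data: dict = {}

a, b = Record(text="a"), Record(text="b")
a.data["k"] = 1
print(b.data)  # {} -- instances do not share the default dict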

View file

@ -8,7 +8,6 @@ from loguru import logger
from langflow.api.v1.schemas import ChatMessage
from langflow.interface.utils import try_setting_streaming_options
from langflow.processing.base import get_result_and_steps
from langflow.utils.chat import ChatDefinition
LANGCHAIN_RUNNABLES = (Chain, Runnable, AgentExecutor)
@ -24,7 +23,9 @@ async def process_graph(
if build_result is None:
# Raise user facing error
raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.")
raise ValueError(
"There was an error loading the langchain_object. Please, check all the nodes and try again."
)
# Generate result and thought
try:
@ -40,20 +41,7 @@ async def process_graph(
client_id=client_id,
session_id=session_id,
)
elif isinstance(build_result, ChatDefinition):
raw_output = await run_build_result(
build_result,
chat_inputs,
client_id=client_id,
session_id=session_id,
)
if isinstance(raw_output, dict):
if not build_result.output_key:
raise ValueError("No output key provided to ChatDefinition when returning a dict.")
result = raw_output[build_result.output_key]
else:
result = raw_output
intermediate_steps = []
else:
raise TypeError(f"Unknown type {type(build_result)}")
logger.debug("Generated result and intermediate_steps")
@ -64,5 +52,7 @@ async def process_graph(
raise e
async def run_build_result(build_result: Any, chat_inputs: ChatMessage, client_id: str, session_id: str):
async def run_build_result(
build_result: Any, chat_inputs: ChatMessage, client_id: str, session_id: str
):
return build_result(inputs=chat_inputs.message)

View file

@ -1,6 +1,4 @@
import orjson
from pydantic import ConfigDict
from sqlmodel import SQLModel
def orjson_dumps(v, *, default=None, sort_keys=False, indent_2=True):
@ -17,7 +15,3 @@ def orjson_dumps(v, *, default=None, sort_keys=False, indent_2=True):
if default is None:
return orjson.dumps(v, option=option).decode()
return orjson.dumps(v, default=default, option=option).decode()
class SQLModelSerializable(SQLModel):
model_config = ConfigDict(from_attributes=True)

View file

@ -10,7 +10,9 @@ if TYPE_CHECKING:
class TransactionModel(BaseModel):
id: Optional[int] = Field(default=None, alias="id")
timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp")
timestamp: Optional[datetime] = Field(
default_factory=datetime.now, alias="timestamp"
)
source: str
target: str
target_args: dict
@ -51,8 +53,12 @@ class MessageModel(BaseModel):
@classmethod
def from_record(cls, record: "Record"):
# first check if the record has all the required fields
if "sender" not in record.data and "sender_name" not in record.data:
raise ValueError("The record does not have the required fields 'sender' and 'sender_name' in the data.")
if not record.data or (
"sender" not in record.data and "sender_name" not in record.data
):
raise ValueError(
"The record does not have the required fields 'sender' and 'sender_name' in the data."
)
return cls(
sender=record.data["sender"],
sender_name=record.data["sender_name"],

View file

@ -1,8 +1,12 @@
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional, Type, Union
import duckdb
from loguru import logger
from platformdirs import user_cache_dir
from pydantic import BaseModel
from langflow.services.base import Service
from langflow.services.monitor.schema import (
MessageModel,
@ -13,8 +17,6 @@ from langflow.services.monitor.utils import (
add_row_to_table,
drop_and_create_table_if_schema_mismatch,
)
from loguru import logger
from platformdirs import user_cache_dir
if TYPE_CHECKING:
from langflow.services.settings.manager import SettingsService
@ -43,7 +45,9 @@ class MonitorService(Service):
def ensure_tables_exist(self):
for table_name, model in self.table_map.items():
drop_and_create_table_if_schema_mismatch(str(self.db_path), table_name, model)
drop_and_create_table_if_schema_mismatch(
str(self.db_path), table_name, model
)
def add_row(
self,
@ -52,7 +56,7 @@ class MonitorService(Service):
):
# Make sure the model passed matches the table
model = self.table_map.get(table_name)
model: Type[BaseModel] = self.table_map.get(table_name)
if model is None:
raise ValueError(f"Unknown table name: {table_name}")

View file

@ -14,7 +14,9 @@ class SessionService(Service):
def __init__(self, cache_service):
self.cache_service: "BaseCacheService" = cache_service
async def load_session(self, key, data_graph: Optional[dict] = None):
async def load_session(
self, key, data_graph: Optional[dict] = None, flow_id: Optional[str] = None
):
# Check if the data is cached
if key in self.cache_service:
return self.cache_service.get(key)
@ -24,7 +26,7 @@ class SessionService(Service):
if data_graph is None:
return (None, None)
# If not cached, build the graph and cache it
graph, artifacts = await build_sorted_vertices(data_graph)
graph, artifacts = await build_sorted_vertices(data_graph, flow_id)
self.cache_service.set(key, (graph, artifacts))

View file

@ -1,9 +1,11 @@
import os
import yaml
from loguru import logger
from langflow.services.base import Service
from langflow.services.settings.auth import AuthSettings
from langflow.services.settings.base import Settings
from loguru import logger
import os
import yaml
class SettingsService(Service):
@ -28,9 +30,11 @@ class SettingsService(Service):
settings_dict = {k.upper(): v for k, v in settings_dict.items()}
for key in settings_dict:
if key not in Settings.__fields__.keys():
if key not in Settings.model_fields.keys():
raise KeyError(f"Key {key} not found in settings")
logger.debug(f"Loading {len(settings_dict[key])} {key} from {file_path}")
logger.debug(
f"Loading {len(settings_dict[key])} {key} from {file_path}"
)
settings = Settings(**settings_dict)
if not settings.CONFIG_DIR:
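
The `__fields__` to `model_fields` migration above follows pydantic v2, where declared fields are exposed as a class-level dict attribute (not a method). For example (a plain BaseModel stands in for the real Settings class):

from pydantic import BaseModel

class Settings(BaseModel):
    config_dir: str = "/tmp/langflow"

# pydantic v2 replaces the v1 `__fields__` attribute with
# `model_fields`, a plain dict keyed by field name.
print(Settings.model_fields.keys())  # dict_keys(['config_dir'])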

View file

@ -25,7 +25,7 @@ class SocketIOService(Service):
self.sio.on("message")(self.message)
self.sio.on("get_vertices")(self.on_get_vertices)
self.sio.on("build_vertex")(self.on_build_vertex)
self.sessions = {}
self.sessions = {} # type: dict[str, dict]
async def emit_error(self, sid, error):
await self.sio.emit("error", to=sid, data=error)

View file

@ -11,32 +11,34 @@ if TYPE_CHECKING:
class StorageService(Service):
name = "storage_service"
def __init__(self, session_service: "SessionService", settings_service: "SettingsService"):
def __init__(
self, session_service: "SessionService", settings_service: "SettingsService"
):
self.settings_service = settings_service
self.session_service = session_service
self.set_ready()
def build_full_path(self, flow_id: str, file_name: str) -> str:
pass
raise NotImplementedError
def set_ready(self):
self.ready = True
@abstractmethod
async def save_file(self, flow_id: str, file_name: str, data) -> None:
pass
raise NotImplementedError
@abstractmethod
async def get_file(self, flow_id: str, file_name: str) -> bytes:
pass
raise NotImplementedError
@abstractmethod
async def list_files(self, flow_id: str) -> list[str]:
pass
raise NotImplementedError
@abstractmethod
async def delete_file(self, flow_id: str, file_name: str) -> bool:
pass
raise NotImplementedError
def teardown(self):
pass
raise NotImplementedError
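
Swapping `pass` for `raise NotImplementedError` in these abstract methods turns a forgotten override (or a stray super() call) into an immediate error instead of a silent None. The pattern in isolation, with a hypothetical subclass:

from abc import ABC, abstractmethod

class StorageService(ABC):
    @abstractmethod
    async def save_file(self, flow_id: str, file_name: str, data) -> None:
        # Reached only via super().save_file(...); fail loudly.
        raise NotImplementedError

class LocalStorageService(StorageService):
    async def save_file(self, flow_id: str, file_name: str, data) -> None:
        print(f"saving {file_name} for flow {flow_id}")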

View file

@ -59,7 +59,7 @@ class ChainFrontendNode(FrontendNode):
field.required = False
field.advanced = False
if "key" in field.name:
if "key" in str(field.name):
field.password = False
field.show = False
if field.name in ["input_key", "output_key"]:
@ -216,7 +216,9 @@ class MidJourneyPromptChainNode(FrontendNode):
),
],
)
description: str = "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
description: str = (
"MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts."
)
base_classes: list[str] = [
"LLMChain",
"BaseCustomChain",

View file

@ -2,12 +2,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from asgiref.sync import async_to_sync
from celery.exceptions import SoftTimeLimitExceeded # type: ignore
from langflow.core.celery_app import celery_app
from langflow.processing.process import Result, generate_result, process_inputs
from langflow.services.deps import get_session_service
from langflow.services.manager import initialize_session_service
from loguru import logger
from rich import print
if TYPE_CHECKING:
from langflow.graph.vertex.base import Vertex
@ -28,7 +24,9 @@ def build_vertex(self, vertex: "Vertex") -> "Vertex":
async_to_sync(vertex.build)()
return vertex
except SoftTimeLimitExceeded as e:
raise self.retry(exc=SoftTimeLimitExceeded("Task took too long"), countdown=2) from e
raise self.retry(
exc=SoftTimeLimitExceeded("Task took too long"), countdown=2
) from e
@celery_app.task(acks_late=True)
@ -38,38 +36,4 @@ def process_graph_cached_task(
clear_cache=False,
session_id=None,
) -> Dict[str, Any]:
try:
initialize_session_service()
session_service = get_session_service()
if clear_cache:
session_service.clear_session(session_id)
if session_id is None:
session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph)
# Use async_to_sync to handle the asynchronous part of the session service
session_data = async_to_sync(session_service.load_session, force_new_loop=True)(session_id, data_graph)
logger.warning(f"session_data: {session_data}")
graph, artifacts = session_data if session_data else (None, None)
if not graph:
raise ValueError("Graph not found in the session")
# Use async_to_sync for the asynchronous build method
built_object = async_to_sync(graph.build, force_new_loop=True)()
logger.debug(f"Built object: {built_object}")
processed_inputs = process_inputs(inputs, artifacts or {})
result = async_to_sync(generate_result, force_new_loop=True)(built_object, processed_inputs)
# Update the session with the new data
session_service.update_session(session_id, (graph, artifacts))
result_object = Result(result=result, session_id=session_id).model_dump()
print(f"Result object: {result_object}")
return result_object
except Exception as e:
logger.error(f"Error in process_graph_cached_task: {e}")
# Handle the exception as needed, maybe re-raise or return an error message
raise
raise NotImplementedError("This task is not implemented yet")