From 44642b5a0e031d97fe98e455903ab6dc8d279bbd Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Wed, 28 Feb 2024 17:21:05 -0300 Subject: [PATCH] Update code with new typings and bug fixes --- src/backend/langflow/api/v1/chat.py | 69 +---- src/backend/langflow/api/v1/endpoints.py | 280 +----------------- .../components/chains/ConversationChain.py | 13 +- .../components/chains/LLMCheckerChain.py | 9 +- .../components/chains/LLMMathChain.py | 7 +- .../langflow/components/chains/RetrievalQA.py | 17 +- .../chains/RetrievalQAWithSourcesChain.py | 8 +- .../components/chains/SQLGenerator.py | 25 +- .../documentloaders/GatherRecords.py | 11 +- .../langflow/components/io/StoreMessages.py | 8 +- .../langflow/components/io/base/chat.py | 2 +- .../model_specs/ChatLiteLLMSpecs.py | 21 +- .../components/models/AzureOpenAIModel.py | 4 +- .../components/models/CTransformersModel.py | 8 +- .../langflow/components/models/CohereModel.py | 4 +- .../components/models/HuggingFaceModel.py | 6 +- .../langflow/components/models/OllamaModel.py | 2 +- .../langflow/components/models/OpenAIModel.py | 9 +- .../components/models/VertexAiModel.py | 2 +- .../langflow/components/prompts/Prompt.py | 3 +- .../components/utilities/ShouldRunNext.py | 6 +- .../components/vectorstores/Chroma.py | 4 +- .../components/vectorstores/ChromaSearch.py | 3 +- .../langflow/components/vectorstores/FAISS.py | 1 + .../vectorstores/MongoDBAtlasVector.py | 20 +- .../vectorstores/MongoDBAtlasVectorSearch.py | 2 +- .../components/vectorstores/PineconeSearch.py | 4 +- .../components/vectorstores/QdrantSearch.py | 4 +- .../components/vectorstores/RedisSearch.py | 2 +- .../components/vectorstores/VectaraSearch.py | 2 +- .../components/vectorstores/WeaviateSearch.py | 2 +- .../components/vectorstores/base/model.py | 18 +- .../components/vectorstores/pgvectorSearch.py | 6 +- .../langflow/field_typing/constants.py | 3 +- src/backend/langflow/graph/graph/base.py | 22 +- src/backend/langflow/graph/vertex/types.py 
| 3 +- src/backend/langflow/interface/run.py | 15 +- src/backend/langflow/main.py | 14 +- src/backend/langflow/processing/process.py | 26 +- src/backend/langflow/schema.py | 4 +- src/backend/langflow/services/chat/utils.py | 24 +- .../langflow/services/database/models/base.py | 6 - .../langflow/services/monitor/schema.py | 12 +- .../langflow/services/monitor/service.py | 14 +- .../langflow/services/session/service.py | 6 +- .../langflow/services/settings/manager.py | 14 +- .../langflow/services/socket/service.py | 2 +- .../langflow/services/storage/service.py | 16 +- .../langflow/template/frontend_node/chains.py | 6 +- src/backend/langflow/worker.py | 46 +-- 50 files changed, 264 insertions(+), 551 deletions(-) diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index 10b9d9e38..ec76a30be 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -2,19 +2,9 @@ import time import uuid from typing import TYPE_CHECKING, Annotated, Optional -from fastapi import ( - APIRouter, - BackgroundTasks, - Body, - Depends, - HTTPException, - WebSocket, - WebSocketException, - status, -) +from fastapi import APIRouter, BackgroundTasks, Body, Depends, HTTPException from fastapi.responses import StreamingResponse from loguru import logger -from sqlmodel import Session from langflow.api.utils import ( build_and_cache_graph, @@ -28,11 +18,7 @@ from langflow.api.v1.schemas import ( VertexBuildResponse, VerticesOrderResponse, ) -from langflow.graph.graph.base import Graph -from langflow.services.auth.utils import ( - get_current_active_user, - get_current_user_for_websocket, -) +from langflow.services.auth.utils import get_current_active_user from langflow.services.chat.service import ChatService from langflow.services.deps import get_chat_service, get_session, get_session_service from langflow.services.monitor.utils import log_vertex_build @@ -44,47 +30,6 @@ if TYPE_CHECKING: router = APIRouter(tags=["Chat"]) 
-@router.websocket("/chat/{client_id}") -async def chat( - client_id: str, - websocket: WebSocket, - db: Session = Depends(get_session), - chat_service: "ChatService" = Depends(get_chat_service), -): - """Websocket endpoint for chat.""" - try: - user = await get_current_user_for_websocket(websocket, db) - await websocket.accept() - if not user: - await websocket.close( - code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" - ) - elif not user.is_active: - await websocket.close( - code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" - ) - - if client_id in chat_service.cache_service: - await chat_service.handle_websocket(client_id, websocket) - else: - # We accept the connection but close it immediately - # if the flow is not built yet - message = "Please, build the flow before sending messages" - await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=message) - except WebSocketException as exc: - logger.error(f"Websocket exrror: {exc}") - await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc)) - except Exception as exc: - logger.error(f"Error in chat websocket: {exc}") - messsage = exc.detail if isinstance(exc, HTTPException) else str(exc) - if "Could not validate credentials" in str(exc): - await websocket.close( - code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" - ) - else: - await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=messsage) - - async def try_running_celery_task(vertex, user_id): # Try running the task in celery # and set the task_id to the local vertex @@ -113,7 +58,7 @@ async def get_vertices( # First, we need to check if the flow_id is in the cache graph = None if cache := chat_service.get_cache(flow_id): - graph: Graph = cache.get("result") + graph = cache.get("result") graph = build_and_cache_graph(flow_id, session, chat_service, graph) if component_id: try: @@ -141,7 +86,7 @@ async def build_vertex( flow_id: str, vertex_id: str, background_tasks: BackgroundTasks, - inputs: 
Annotated[InputValueRequest, Body(embed=True)] = None, + inputs: Annotated[Optional[InputValueRequest], Body(embed=True)] = None, chat_service: "ChatService" = Depends(get_chat_service), current_user=Depends(get_current_active_user), ): @@ -161,7 +106,7 @@ async def build_vertex( ) else: graph = cache.get("result") - result_data_response = {} + result_data_response = ResultDataResponse(results={}) duration = "" vertex = graph.get_vertex(vertex_id) @@ -250,7 +195,9 @@ async def build_vertex_stream( else: graph = cache.get("result") else: - session_data = await session_service.load_session(session_id) + session_data = await session_service.load_session( + session_id, flow_id=flow_id + ) graph, artifacts = session_data if session_data else (None, None) if not graph: raise ValueError(f"No graph found for {flow_id}.") diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 116c63b2c..d972c1b27 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -4,27 +4,20 @@ from typing import Annotated, Any, List, Optional, Union import sqlalchemy as sa from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status from loguru import logger -from sqlmodel import select +from sqlmodel import Session, select from langflow.api.utils import update_frontend_node_with_template_values from langflow.api.v1.schemas import ( CustomComponentCode, - PreloadResponse, ProcessResponse, RunResponse, - TaskResponse, TaskStatusResponse, UploadFileResponse, ) from langflow.interface.custom.custom_component import CustomComponent from langflow.interface.custom.directory_reader import DirectoryReader from langflow.interface.custom.utils import build_custom_component_template -from langflow.processing.process import ( - build_graph_and_generate_result, - process_graph_cached, - process_tweaks, - run_graph, -) +from langflow.processing.process import process_tweaks, run_graph from 
langflow.services.auth.utils import api_key_security, get_current_active_user from langflow.services.cache.utils import save_uploaded_file from langflow.services.database.models.flow import Flow @@ -36,98 +29,12 @@ from langflow.services.deps import ( get_task_service, ) from langflow.services.session.service import SessionService - -try: - from langflow.worker import process_graph_cached_task -except ImportError: - - def process_graph_cached_task(*args, **kwargs): - raise NotImplementedError("Celery is not installed") - - -from sqlmodel import Session - from langflow.services.task.service import TaskService # build router router = APIRouter(tags=["Base"]) -async def process_graph_data( - graph_data: dict, - inputs: Optional[Union[List[dict], dict]] = None, - tweaks: Optional[dict] = None, - clear_cache: bool = False, - session_id: Optional[str] = None, - task_service: "TaskService" = Depends(get_task_service), - sync: bool = True, -): - task_result: Any = None - task_status = None - if tweaks: - try: - graph_data = process_tweaks(graph_data, tweaks) - except Exception as exc: - logger.error(f"Error processing tweaks: {exc}") - if sync: - result = await process_graph_cached( - graph_data, - inputs, - clear_cache, - session_id, - ) - task_id = str(id(result)) - if isinstance(result, dict) and "result" in result: - task_result = result["result"] - session_id = result["session_id"] - elif hasattr(result, "result") and hasattr(result, "session_id"): - task_result = result.result - - session_id = result.session_id - else: - task_result = result - else: - logger.warning( - "This is an experimental feature and may not work as expected." - "Please report any issues to our GitHub repository." 
- ) - if session_id is None: - # Generate a session ID - session_id = get_session_service().generate_key( - session_id=session_id, data_graph=graph_data - ) - task_id, task = await task_service.launch_task( - ( - process_graph_cached_task - if task_service.use_celery - else process_graph_cached - ), - graph_data, - inputs, - clear_cache, - session_id, - ) - task_status = task.status - if task.status == "FAILURE": - logger.error(f"Task {task_id} failed: {task.traceback}") - task_result = str(task._exception) - else: - task_result = task.result - - if task_id: - task_response = TaskResponse(id=task_id, href=f"api/v1/task/{task_id}") - else: - task_response = None - - return ProcessResponse( - result=task_result, - status=task_status, - task=task_response, - session_id=session_id, - backend=task_service.backend_name, - ) - - @router.get("/all", dependencies=[Depends(get_current_active_user)]) def get_all( settings_service=Depends(get_settings_service), @@ -141,85 +48,6 @@ def get_all( raise HTTPException(status_code=500, detail=str(exc)) from exc -@router.post("/process/json", response_model=ProcessResponse) -async def process_json( - session: Annotated[Session, Depends(get_session)], - data: dict, - inputs: Optional[dict] = None, - tweaks: Optional[dict] = None, - clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821 - session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821 - task_service: "TaskService" = Depends(get_task_service), - sync: Annotated[bool, Body(embed=True)] = True, # noqa: F821 -): - try: - return await process_graph_data( - graph_data=data, - inputs=inputs, - tweaks=tweaks, - clear_cache=clear_cache, - session_id=session_id, - task_service=task_service, - sync=sync, - ) - except Exception as exc: - logger.exception(exc) - raise HTTPException(status_code=500, detail=str(exc)) from exc - - -# Endpoint to preload a graph -@router.post("/process/preload/{flow_id}", response_model=PreloadResponse) -async def 
preload_flow( - session: Annotated[Session, Depends(get_session)], - flow_id: str, - session_id: Optional[str] = None, - session_service: SessionService = Depends(get_session_service), - api_key_user: User = Depends(api_key_security), - clear_session: Annotated[bool, Body(embed=True)] = False, # noqa: F821 -): - try: - # Get the flow that matches the flow_id and belongs to the user - # flow = session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() - if clear_session: - session_service.clear_session(session_id) - # Check if the session exists - session_data = await session_service.load_session(session_id) - # Session data is a tuple of (graph, artifacts) - # or (None, None) if the session is empty - if isinstance(session_data, tuple): - graph, artifacts = session_data - is_clear = graph is None and artifacts is None - else: - is_clear = session_data is None - return PreloadResponse(session_id=session_id, is_clear=is_clear) - else: - if session_id is None: - session_id = flow_id - flow = session.exec( - select(Flow) - .where(Flow.id == flow_id) - .where(Flow.user_id == api_key_user.id) - ).first() - if flow is None: - raise ValueError(f"Flow {flow_id} not found") - - if flow.data is None: - raise ValueError(f"Flow {flow_id} has no data") - graph_data = flow.data - session_service.clear_session(session_id) - # Load the graph using SessionService - session_data = await session_service.load_session(session_id, graph_data) - graph, artifacts = session_data if session_data else (None, None) - if not graph: - raise ValueError("Graph not found in the session") - _ = await graph.build() - session_service.update_session(session_id, (graph, artifacts)) - return PreloadResponse(session_id=session_id) - except Exception as exc: - logger.exception(exc) - raise HTTPException(status_code=500, detail=str(exc)) from exc - - @router.post( "/run/{flow_id}", response_model=RunResponse, response_model_exclude_none=True ) @@ -235,7 +63,9 @@ async def 
run_flow_with_caching( ): try: if session_id: - session_data = await session_service.load_session(session_id) + session_data = await session_service.load_session( + session_id, flow_id=flow_id + ) graph, artifacts = session_data if session_data else (None, None) task_result: Any = None if not graph: @@ -264,7 +94,7 @@ async def run_flow_with_caching( if flow.data is None: raise ValueError(f"Flow {flow_id} has no data") graph_data = flow.data - graph_data = process_tweaks(graph_data, tweaks) + graph_data = process_tweaks(graph_data, tweaks or {}) task_result, session_id = await run_graph( graph=graph_data, flow_id=flow_id, @@ -318,94 +148,16 @@ async def process( """ Endpoint to process an input with a given flow_id. """ - - try: - if session_id: - session_data = await session_service.load_session(session_id) - graph, artifacts = session_data if session_data else (None, None) - task_result: Any = None - task_status = None - task_id = None - if not graph: - raise ValueError("Graph not found in the session") - result = await build_graph_and_generate_result( - graph=graph, - inputs=inputs, - artifacts=artifacts, - session_id=session_id, - session_service=session_service, - ) - task_id = str(id(result)) - if isinstance(result, dict) and "result" in result: - task_result = result["result"] - session_id = result["session_id"] - elif hasattr(result, "result") and hasattr(result, "session_id"): - task_result = result.result - - session_id = result.session_id - else: - task_result = result - if task_id: - task_response = TaskResponse(id=task_id, href=f"api/v1/task/{task_id}") - else: - task_response = None - return ProcessResponse( - result=task_result, - status=task_status, - task=task_response, - session_id=session_id, - backend=task_service.backend_name, - ) - - else: - if api_key_user is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid API Key", - ) - - # Get the flow that matches the flow_id and belongs to the user - # flow = 
session.query(Flow).filter(Flow.id == flow_id).filter(Flow.user_id == api_key_user.id).first() - flow = session.exec( - select(Flow) - .where(Flow.id == flow_id) - .where(Flow.user_id == api_key_user.id) - ).first() - if flow is None: - raise ValueError(f"Flow {flow_id} not found") - - if flow.data is None: - raise ValueError(f"Flow {flow_id} has no data") - graph_data = flow.data - return await process_graph_data( - graph_data=graph_data, - inputs=inputs, - tweaks=tweaks, - clear_cache=clear_cache, - session_id=session_id, - task_service=task_service, - sync=sync, - ) - except sa.exc.StatementError as exc: - # StatementError('(builtins.ValueError) badly formed hexadecimal UUID string') - if "badly formed hexadecimal UUID string" in str(exc): - # This means the Flow ID is not a valid UUID which means it can't find the flow - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc - except ValueError as exc: - if f"Flow {flow_id} not found" in str(exc): - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, detail=str(exc) - ) from exc - else: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(exc) - ) from exc - except Exception as e: - # Log stack trace - logger.exception(e) - raise HTTPException(status_code=500, detail=str(e)) from e + # Raise a deprecation warning + logger.warning( + "The /process endpoint is deprecated and will be removed in a future version. " + "Please use /run instead." + ) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="The /process endpoint is deprecated and will be removed in a future version. "
" + "Please use /run instead.", + ) @router.get("/task/{task_id}", response_model=TaskStatusResponse) diff --git a/src/backend/langflow/components/chains/ConversationChain.py b/src/backend/langflow/components/chains/ConversationChain.py index 726056138..f6ddd598a 100644 --- a/src/backend/langflow/components/chains/ConversationChain.py +++ b/src/backend/langflow/components/chains/ConversationChain.py @@ -31,17 +31,18 @@ class ConversationChainComponent(CustomComponent): chain = ConversationChain(llm=llm) else: chain = ConversationChain(llm=llm, memory=memory) - result = chain.invoke(input_value) + result = chain.invoke({chain.input_key: input_value}) # result is an AIMessage which is a subclass of BaseMessage # We need to check if it is a string or a BaseMessage + result_str = "" if hasattr(result, "content") and isinstance(result.content, str): self.status = "is message" - result = result.content + result_str = result.content elif isinstance(result, str): self.status = "is_string" - result = result + result_str = result else: # is dict - result = result.get("response") - self.status = result - return result + result_str = result.get("response") + self.status = result_str + return result_str diff --git a/src/backend/langflow/components/chains/LLMCheckerChain.py b/src/backend/langflow/components/chains/LLMCheckerChain.py index 15a540311..2e846b226 100644 --- a/src/backend/langflow/components/chains/LLMCheckerChain.py +++ b/src/backend/langflow/components/chains/LLMCheckerChain.py @@ -23,7 +23,8 @@ class LLMCheckerChainComponent(CustomComponent): ) -> Text: chain = LLMCheckerChain.from_llm(llm=llm) - response = chain.invoke({chain.input_key: inputs}) - result = response.get(chain.output_key) - self.status = result - return result + response = chain.invoke({chain.input_key: input_value}) + result = response.get(chain.output_key, "") + result_str = str(result) + self.status = result_str + return result_str diff --git 
a/src/backend/langflow/components/chains/LLMMathChain.py b/src/backend/langflow/components/chains/LLMMathChain.py index 7fb253b83..7f2be6f95 100644 --- a/src/backend/langflow/components/chains/LLMMathChain.py +++ b/src/backend/langflow/components/chains/LLMMathChain.py @@ -38,7 +38,8 @@ class LLMMathChainComponent(CustomComponent): output_key=output_key, memory=memory, ) - response = chain.invoke({input_key: inputs}) + response = chain.invoke({input_key: input_value}) result = response.get(output_key) - self.status = result - return result + result_str = str(result) + self.status = result_str + return result_str diff --git a/src/backend/langflow/components/chains/RetrievalQA.py b/src/backend/langflow/components/chains/RetrievalQA.py index 53fa24f15..567c62e93 100644 --- a/src/backend/langflow/components/chains/RetrievalQA.py +++ b/src/backend/langflow/components/chains/RetrievalQA.py @@ -1,7 +1,7 @@ -from typing import Callable, Optional, Union +from typing import Optional from langchain.chains.combine_documents.base import BaseCombineDocumentsChain -from langchain.chains.retrieval_qa.base import BaseRetrievalQA, RetrievalQA +from langchain.chains.retrieval_qa.base import RetrievalQA from langchain_core.documents import Document from langflow import CustomComponent @@ -35,7 +35,7 @@ class RetrievalQAComponent(CustomComponent): input_key: str = "query", output_key: str = "result", return_source_documents: bool = True, - ) -> Union[BaseRetrievalQA, Callable, Text]: + ) -> Text: runnable = RetrievalQA( combine_documents_chain=combine_documents_chain, retriever=retriever, @@ -44,10 +44,10 @@ class RetrievalQAComponent(CustomComponent): output_key=output_key, return_source_documents=return_source_documents, ) - if isinstance(inputs, Document): - inputs = inputs.page_content + if isinstance(input_value, Document): + input_value = input_value.page_content self.status = runnable - result = runnable.invoke({input_key: inputs}) + result = runnable.invoke({input_key: 
input_value}) result = result.content if hasattr(result, "content") else result # Result is a dict with keys "query", "result" and "source_documents" # for now we just return the result @@ -55,7 +55,8 @@ class RetrievalQAComponent(CustomComponent): references_str = "" if return_source_documents: references_str = self.create_references_from_records(records) - result_str = result.get("result") - final_result = "\n".join([result_str, references_str]) + result_str = result.get("result", "") + + final_result = "\n".join([str(result_str), references_str]) self.status = final_result return final_result diff --git a/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py b/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py index 8be64c631..8270c9778 100644 --- a/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py +++ b/src/backend/langflow/components/chains/RetrievalQAWithSourcesChain.py @@ -40,11 +40,11 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent): return_source_documents=return_source_documents, retriever=retriever, ) - if isinstance(inputs, Document): - inputs = inputs.page_content + if isinstance(input_value, Document): + input_value = input_value.page_content self.status = runnable input_key = runnable.input_keys[0] - result = runnable.invoke({input_key: inputs}) + result = runnable.invoke({input_key: input_value}) result = result.content if hasattr(result, "content") else result # Result is a dict with keys "query", "result" and "source_documents" # for now we just return the result @@ -52,7 +52,7 @@ class RetrievalQAWithSourcesChainComponent(CustomComponent): references_str = "" if return_source_documents: references_str = self.create_references_from_records(records) - result_str = result.get("answer") + result_str = str(result.get("answer", "")) final_result = "\n".join([result_str, references_str]) self.status = final_result return final_result diff --git 
a/src/backend/langflow/components/chains/SQLGenerator.py b/src/backend/langflow/components/chains/SQLGenerator.py index 39b8fe394..0ab41b477 100644 --- a/src/backend/langflow/components/chains/SQLGenerator.py +++ b/src/backend/langflow/components/chains/SQLGenerator.py @@ -3,6 +3,7 @@ from typing import Optional from langchain.chains import create_sql_query_chain from langchain_community.utilities.sql_database import SQLDatabase from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import Runnable from langflow import CustomComponent from langflow.field_typing import BaseLanguageModel, Text @@ -39,33 +40,27 @@ class SQLGeneratorComponent(CustomComponent): else: prompt_template = None - if top_k > 0: - kwargs = { - "k": top_k, - } + if top_k < 1: + raise ValueError("Top K must be greater than 0.") + if not prompt_template: - sql_query_chain = create_sql_query_chain(llm=llm, db=db, **kwargs) + sql_query_chain = create_sql_query_chain(llm=llm, db=db, k=top_k) else: - template = ( - prompt_template.template - if hasattr(prompt, "template") - else prompt_template - ) # Check if {question} is in the prompt if ( - "{question}" not in template - or "question" not in template.input_variables + "{question}" not in prompt_template.template + or "question" not in prompt_template.input_variables ): raise ValueError( "Prompt must contain `{question}` to be used with Natural Language to SQL." 
) sql_query_chain = create_sql_query_chain( - llm=llm, db=db, prompt=prompt_template, **kwargs + llm=llm, db=db, prompt=prompt_template, k=top_k ) - query_writer = sql_query_chain | { + query_writer: Runnable = sql_query_chain | { "query": lambda x: x.replace("SQLQuery:", "").strip() } - response = query_writer.invoke({"question": inputs}) + response = query_writer.invoke({"question": input_value}) query = response.get("query") self.status = query return query diff --git a/src/backend/langflow/components/documentloaders/GatherRecords.py b/src/backend/langflow/components/documentloaders/GatherRecords.py index 745d0655b..dd7f86596 100644 --- a/src/backend/langflow/components/documentloaders/GatherRecords.py +++ b/src/backend/langflow/components/documentloaders/GatherRecords.py @@ -76,9 +76,11 @@ class GatherRecordsComponent(CustomComponent): return file_paths - def parse_file_to_record(self, file_path: str, silent_errors: bool) -> Record: + def parse_file_to_record( + self, file_path: str, silent_errors: bool + ) -> Optional[Record]: # Use the partition function to load the file - from unstructured.partition.auto import partition + from unstructured.partition.auto import partition # type: ignore try: elements = partition(file_path) @@ -115,13 +117,14 @@ class GatherRecordsComponent(CustomComponent): def parallel_load_records( self, file_paths: List[str], silent_errors: bool, max_concurrency: int - ) -> List[Record]: + ) -> List[Optional[Record]]: with futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor: loaded_files = executor.map( lambda file_path: self.parse_file_to_record(file_path, silent_errors), file_paths, ) - return loaded_files + # loaded_files is an iterator, so we need to convert it to a list + return list(loaded_files) def build( self, diff --git a/src/backend/langflow/components/io/StoreMessages.py b/src/backend/langflow/components/io/StoreMessages.py index 93af76e4f..3cf698a70 100644 --- 
a/src/backend/langflow/components/io/StoreMessages.py +++ b/src/backend/langflow/components/io/StoreMessages.py @@ -48,11 +48,15 @@ class StoreMessages(CustomComponent): # and the other parameters if not texts and not records: raise ValueError("Either texts or records must be provided.") + if not texts: + texts = [] if not records: records = [] if not session_id or not sender or not sender_name: - raise ValueError("If passing texts, session_id, sender, and sender_name must be provided.") + raise ValueError( + "If passing texts, session_id, sender, and sender_name must be provided." + ) for text in texts: record = Record( text=text, @@ -68,4 +72,4 @@ class StoreMessages(CustomComponent): self.status = records records = add_messages(records) - return records + return records or [] diff --git a/src/backend/langflow/components/io/base/chat.py b/src/backend/langflow/components/io/base/chat.py index 4660b4276..4d60f6bac 100644 --- a/src/backend/langflow/components/io/base/chat.py +++ b/src/backend/langflow/components/io/base/chat.py @@ -35,7 +35,7 @@ class ChatComponent(CustomComponent): def store_message( self, - message: Union[Text, Record], + message: Union[str, Text, Record], session_id: Optional[str] = None, sender: Optional[str] = None, sender_name: Optional[str] = None, diff --git a/src/backend/langflow/components/model_specs/ChatLiteLLMSpecs.py b/src/backend/langflow/components/model_specs/ChatLiteLLMSpecs.py index 1f7f22234..320b2b10b 100644 --- a/src/backend/langflow/components/model_specs/ChatLiteLLMSpecs.py +++ b/src/backend/langflow/components/model_specs/ChatLiteLLMSpecs.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Dict, Optional, Union from langchain_community.chat_models.litellm import ChatLiteLLM, ChatLiteLLMException + from langflow import CustomComponent from langflow.field_typing import BaseLanguageModel @@ -126,7 +127,8 @@ class ChatLiteLLMComponent(CustomComponent): litellm.set_verbose = verbose except ImportError: raise 
ChatLiteLLMException( - "Could not import litellm python package. " "Please install it with `pip install litellm`" + "Could not import litellm python package. " + "Please install it with `pip install litellm`" ) provider_map = { "OpenAI": "openai_api_key", @@ -137,11 +139,17 @@ class ChatLiteLLMComponent(CustomComponent): "OpenRouter": "openrouter_api_key", } # Set the API key based on the provider - kwarg = {provider_map[provider]: api_key} + api_keys = {v: None for v in provider_map.values()} + + if variable_name := provider_map.get(provider): + api_keys[variable_name] = api_key + else: + raise ChatLiteLLMException( + f"Provider {provider} is not supported. Supported providers are: {', '.join(provider_map.keys())}" + ) LLM = ChatLiteLLM( model=model, - client=None, streaming=streaming, temperature=temperature, model_kwargs=model_kwargs if model_kwargs is not None else {}, @@ -150,6 +158,11 @@ class ChatLiteLLMComponent(CustomComponent): n=n, max_tokens=max_tokens, max_retries=max_retries, - **kwarg, + openai_api_key=api_keys["openai_api_key"], + azure_api_key=api_keys["azure_api_key"], + anthropic_api_key=api_keys["anthropic_api_key"], + replicate_api_key=api_keys["replicate_api_key"], + cohere_api_key=api_keys["cohere_api_key"], + openrouter_api_key=api_keys["openrouter_api_key"], ) return LLM diff --git a/src/backend/langflow/components/models/AzureOpenAIModel.py b/src/backend/langflow/components/models/AzureOpenAIModel.py index 392f390c4..a823a7067 100644 --- a/src/backend/langflow/components/models/AzureOpenAIModel.py +++ b/src/backend/langflow/components/models/AzureOpenAIModel.py @@ -4,6 +4,7 @@ from langchain.llms.base import BaseLanguageModel from langchain_openai import AzureChatOpenAI from langflow.components.models.base.model import LCModelComponent +from pydantic.v1 import SecretStr class AzureChatOpenAIComponent(LCModelComponent): @@ -93,13 +94,14 @@ class AzureChatOpenAIComponent(LCModelComponent): max_tokens: Optional[int] = 1000, stream: bool = 
False, ) -> BaseLanguageModel: + secret_api_key = SecretStr(api_key) try: output = AzureChatOpenAI( model=model, azure_endpoint=azure_endpoint, azure_deployment=azure_deployment, api_version=api_version, - api_key=api_key, + api_key=secret_api_key, temperature=temperature, max_tokens=max_tokens, ) diff --git a/src/backend/langflow/components/models/CTransformersModel.py b/src/backend/langflow/components/models/CTransformersModel.py index 31123ad7e..537ceae2f 100644 --- a/src/backend/langflow/components/models/CTransformersModel.py +++ b/src/backend/langflow/components/models/CTransformersModel.py @@ -41,11 +41,15 @@ class CTransformersComponent(LCModelComponent): model_file: str, input_value: str, model_type: str, + stream: bool = False, config: Optional[Dict] = None, - stream: Optional[bool] = False, ) -> Text: output = CTransformers( - model=model, model_file=model_file, model_type=model_type, config=config + client=None, + model=model, + model_file=model_file, + model_type=model_type, + config=config, # noqa ) return self.get_result(output=output, stream=stream, input_value=input_value) diff --git a/src/backend/langflow/components/models/CohereModel.py b/src/backend/langflow/components/models/CohereModel.py index a32fb9b4b..d6625d6bf 100644 --- a/src/backend/langflow/components/models/CohereModel.py +++ b/src/backend/langflow/components/models/CohereModel.py @@ -41,13 +41,11 @@ class CohereComponent(LCModelComponent): self, cohere_api_key: str, input_value: str, - max_tokens: int = 256, temperature: float = 0.75, stream: bool = False, ) -> Text: - output = ChatCohere( + output = ChatCohere( # type: ignore cohere_api_key=cohere_api_key, - max_tokens=max_tokens, temperature=temperature, ) return self.get_result(output=output, stream=stream, input_value=input_value) diff --git a/src/backend/langflow/components/models/HuggingFaceModel.py b/src/backend/langflow/components/models/HuggingFaceModel.py index 3d92272e6..6b437810b 100644 --- 
a/src/backend/langflow/components/models/HuggingFaceModel.py +++ b/src/backend/langflow/components/models/HuggingFaceModel.py @@ -36,17 +36,19 @@ class HuggingFaceEndpointsComponent(LCModelComponent): self, input_value: str, endpoint_url: str, + model: Optional[str] = None, task: str = "text2text-generation", huggingfacehub_api_token: Optional[str] = None, model_kwargs: Optional[dict] = None, stream: bool = False, ) -> Text: try: - llm = HuggingFaceEndpoint( + llm = HuggingFaceEndpoint( # type: ignore endpoint_url=endpoint_url, task=task, huggingfacehub_api_token=huggingfacehub_api_token, - model_kwargs=model_kwargs, + model_kwargs=model_kwargs or {}, + model=model or "", ) except Exception as e: raise ValueError("Could not connect to HuggingFace Endpoints API.") from e diff --git a/src/backend/langflow/components/models/OllamaModel.py b/src/backend/langflow/components/models/OllamaModel.py index 7929c2b43..d7ef12052 100644 --- a/src/backend/langflow/components/models/OllamaModel.py +++ b/src/backend/langflow/components/models/OllamaModel.py @@ -203,7 +203,7 @@ class ChatOllamaComponent(LCModelComponent): timeout: Optional[int] = None, top_k: Optional[int] = None, top_p: Optional[int] = None, - stream: Optional[bool] = False, + stream: bool = False, ) -> Text: if not base_url: base_url = "http://localhost:11434" diff --git a/src/backend/langflow/components/models/OpenAIModel.py b/src/backend/langflow/components/models/OpenAIModel.py index 7a28acee6..f595255ce 100644 --- a/src/backend/langflow/components/models/OpenAIModel.py +++ b/src/backend/langflow/components/models/OpenAIModel.py @@ -1,6 +1,7 @@ from typing import Optional from langchain_openai import ChatOpenAI +from pydantic.v1 import SecretStr from langflow.components.models.base.model import LCModelComponent from langflow.field_typing import NestedDict, Text @@ -73,16 +74,20 @@ class OpenAIModelComponent(LCModelComponent): openai_api_base: Optional[str] = None, openai_api_key: Optional[str] = None, 
temperature: float = 0.7, - stream: Optional[bool] = False, + stream: bool = False, ) -> Text: if not openai_api_base: openai_api_base = "https://api.openai.com/v1" + if openai_api_key: + secret_key = SecretStr(openai_api_key) + else: + secret_key = None output = ChatOpenAI( max_tokens=max_tokens, model_kwargs=model_kwargs, model=model_name, base_url=openai_api_base, - api_key=openai_api_key, + api_key=secret_key, temperature=temperature, ) diff --git a/src/backend/langflow/components/models/VertexAiModel.py b/src/backend/langflow/components/models/VertexAiModel.py index 5a1950f39..aeb535be3 100644 --- a/src/backend/langflow/components/models/VertexAiModel.py +++ b/src/backend/langflow/components/models/VertexAiModel.py @@ -82,7 +82,7 @@ class ChatVertexAIComponent(LCModelComponent): stream: bool = False, ) -> Text: try: - from langchain_google_vertexai import ChatVertexAI + from langchain_google_vertexai import ChatVertexAI # type: ignore except ImportError: raise ImportError( "To use the ChatVertexAI model, you need to install the langchain-google-vertexai package." 
diff --git a/src/backend/langflow/components/prompts/Prompt.py b/src/backend/langflow/components/prompts/Prompt.py index ca8262c2f..737e82504 100644 --- a/src/backend/langflow/components/prompts/Prompt.py +++ b/src/backend/langflow/components/prompts/Prompt.py @@ -1,4 +1,5 @@ from langchain_core.prompts import PromptTemplate + from langflow import CustomComponent from langflow.field_typing import Prompt, TemplateField, Text @@ -19,7 +20,7 @@ class PromptComponent(CustomComponent): template: Prompt, **kwargs, ) -> Text: - prompt_template = PromptTemplate.from_template(template) + prompt_template = PromptTemplate.from_template(str(template)) attributes_to_check = ["text", "page_content"] for key, value in kwargs.items(): diff --git a/src/backend/langflow/components/utilities/ShouldRunNext.py b/src/backend/langflow/components/utilities/ShouldRunNext.py index ac4196dc7..2f595b01b 100644 --- a/src/backend/langflow/components/utilities/ShouldRunNext.py +++ b/src/backend/langflow/components/utilities/ShouldRunNext.py @@ -23,7 +23,7 @@ class ShouldRunNext(CustomComponent): def build(self, template: Prompt, llm: BaseLanguageModel, **kwargs) -> dict: # This is a simple component that always returns True - prompt_template = PromptTemplate.from_template(template) + prompt_template = PromptTemplate.from_template(str(template)) attributes_to_check = ["text", "page_content"] for key, value in kwargs.items(): @@ -41,7 +41,9 @@ class ShouldRunNext(CustomComponent): result = result.get("response") if result.lower() not in ["true", "false"]: - raise ValueError("The prompt should generate a boolean response (True or False).") + raise ValueError( + "The prompt should generate a boolean response (True or False)." 
+ ) # The string should be the words true or false # if not raise an error bool_result = result.lower() == "true" diff --git a/src/backend/langflow/components/vectorstores/Chroma.py b/src/backend/langflow/components/vectorstores/Chroma.py index d9b617e61..96f921b33 100644 --- a/src/backend/langflow/components/vectorstores/Chroma.py +++ b/src/backend/langflow/components/vectorstores/Chroma.py @@ -95,8 +95,8 @@ class ChromaComponent(CustomComponent): # If documents, then we need to create a Chroma instance using .from_documents # Check index_directory and expand it if it is a relative path - - index_directory = self.resolve_path(index_directory) + if index_directory is not None: + index_directory = self.resolve_path(index_directory) if documents is not None and embedding is not None: if len(documents) == 0: diff --git a/src/backend/langflow/components/vectorstores/ChromaSearch.py b/src/backend/langflow/components/vectorstores/ChromaSearch.py index e3f37108c..3a6d283b3 100644 --- a/src/backend/langflow/components/vectorstores/ChromaSearch.py +++ b/src/backend/langflow/components/vectorstores/ChromaSearch.py @@ -100,7 +100,8 @@ class ChromaSearchComponent(LCVectorStoreComponent): chroma_server_grpc_port=chroma_server_grpc_port or None, chroma_server_ssl_enabled=chroma_server_ssl_enabled, ) - index_directory = self.resolve_path(index_directory) + if index_directory: + index_directory = self.resolve_path(index_directory) vector_store = Chroma( embedding_function=embedding, collection_name=collection_name, diff --git a/src/backend/langflow/components/vectorstores/FAISS.py b/src/backend/langflow/components/vectorstores/FAISS.py index 0cecab8e7..d68d8d085 100644 --- a/src/backend/langflow/components/vectorstores/FAISS.py +++ b/src/backend/langflow/components/vectorstores/FAISS.py @@ -36,3 +36,4 @@ class FAISSComponent(CustomComponent): raise ValueError("Folder path is required to save the FAISS index.") path = self.resolve_path(folder_path) 
vector_store.save_local(str(path), index_name) + return vector_store diff --git a/src/backend/langflow/components/vectorstores/MongoDBAtlasVector.py b/src/backend/langflow/components/vectorstores/MongoDBAtlasVector.py index 5d4537408..d0593e740 100644 --- a/src/backend/langflow/components/vectorstores/MongoDBAtlasVector.py +++ b/src/backend/langflow/components/vectorstores/MongoDBAtlasVector.py @@ -27,7 +27,7 @@ class MongoDBAtlasComponent(CustomComponent): def build( self, embedding: Embeddings, - documents: List[Document] = None, + documents: List[Document], collection_name: str = "", db_name: str = "", index_name: str = "", @@ -35,11 +35,22 @@ class MongoDBAtlasComponent(CustomComponent): search_kwargs: Optional[NestedDict] = None, ) -> MongoDBAtlasVectorSearch: search_kwargs = search_kwargs or {} + try: + from pymongo import MongoClient + except ImportError: + raise ImportError( + "Please install pymongo to use MongoDB Atlas Vector Store" + ) + try: + mongo_client: MongoClient = MongoClient(mongodb_atlas_cluster_uri) + collection = mongo_client[db_name][collection_name] + except Exception as e: + raise ValueError(f"Failed to connect to MongoDB Atlas: {e}") if documents: vector_store = MongoDBAtlasVectorSearch.from_documents( documents=documents, embedding=embedding, - collection_name=collection_name, + collection=collection, db_name=db_name, index_name=index_name, mongodb_atlas_cluster_uri=mongodb_atlas_cluster_uri, @@ -48,10 +59,7 @@ class MongoDBAtlasComponent(CustomComponent): else: vector_store = MongoDBAtlasVectorSearch( embedding=embedding, - collection_name=collection_name, - db_name=db_name, + collection=collection, index_name=index_name, - mongodb_atlas_cluster_uri=mongodb_atlas_cluster_uri, - search_kwargs=search_kwargs, ) return vector_store diff --git a/src/backend/langflow/components/vectorstores/MongoDBAtlasVectorSearch.py b/src/backend/langflow/components/vectorstores/MongoDBAtlasVectorSearch.py index 6393c2a7b..ecd6e6157 100644 --- 
a/src/backend/langflow/components/vectorstores/MongoDBAtlasVectorSearch.py +++ b/src/backend/langflow/components/vectorstores/MongoDBAtlasVectorSearch.py @@ -25,7 +25,7 @@ class MongoDBAtlasSearchComponent(MongoDBAtlasComponent, LCVectorStoreComponent) "search_kwargs": {"display_name": "Search Kwargs", "advanced": True}, } - def build( + def build( # type: ignore[override] self, input_value: str, search_type: str, diff --git a/src/backend/langflow/components/vectorstores/PineconeSearch.py b/src/backend/langflow/components/vectorstores/PineconeSearch.py index 7af7f627f..c34f6666d 100644 --- a/src/backend/langflow/components/vectorstores/PineconeSearch.py +++ b/src/backend/langflow/components/vectorstores/PineconeSearch.py @@ -40,7 +40,7 @@ class PineconeSearchComponent(PineconeComponent, LCVectorStoreComponent): }, } - def build( + def build( # type: ignore[override] self, input_value: str, embedding: Embeddings, @@ -51,7 +51,7 @@ class PineconeSearchComponent(PineconeComponent, LCVectorStoreComponent): pinecone_api_key: Optional[str] = None, namespace: Optional[str] = "default", search_type: str = "similarity", - ) -> List[Record]: + ) -> List[Record]: # type: ignore[override] vector_store = super().build( embedding=embedding, pinecone_env=pinecone_env, diff --git a/src/backend/langflow/components/vectorstores/QdrantSearch.py b/src/backend/langflow/components/vectorstores/QdrantSearch.py index 742690a5d..c2e2a8adf 100644 --- a/src/backend/langflow/components/vectorstores/QdrantSearch.py +++ b/src/backend/langflow/components/vectorstores/QdrantSearch.py @@ -44,7 +44,7 @@ class QdrantSearchComponent(QdrantComponent, LCVectorStoreComponent): "url": {"display_name": "URL", "advanced": True}, } - def build( + def build( # type: ignore[override] self, input_value: str, embedding: Embeddings, @@ -65,7 +65,7 @@ class QdrantSearchComponent(QdrantComponent, LCVectorStoreComponent): search_kwargs: Optional[NestedDict] = None, timeout: Optional[int] = None, url: Optional[str] 
= None, - ) -> List[Record]: + ) -> List[Record]: # type: ignore[override] vector_store = super().build( embedding=embedding, collection_name=collection_name, diff --git a/src/backend/langflow/components/vectorstores/RedisSearch.py b/src/backend/langflow/components/vectorstores/RedisSearch.py index 71022de1d..25938ed52 100644 --- a/src/backend/langflow/components/vectorstores/RedisSearch.py +++ b/src/backend/langflow/components/vectorstores/RedisSearch.py @@ -42,7 +42,7 @@ class RedisSearchComponent(RedisComponent, LCVectorStoreComponent): "redis_index_name": {"display_name": "Redis Index", "advanced": False}, } - def build( + def build( # type: ignore[override] self, input_value: str, search_type: str, diff --git a/src/backend/langflow/components/vectorstores/VectaraSearch.py b/src/backend/langflow/components/vectorstores/VectaraSearch.py index da2c083d1..fa6752b8f 100644 --- a/src/backend/langflow/components/vectorstores/VectaraSearch.py +++ b/src/backend/langflow/components/vectorstores/VectaraSearch.py @@ -42,7 +42,7 @@ class VectaraSearchComponent(VectaraComponent, LCVectorStoreComponent): }, } - def build( + def build( # type: ignore[override] self, input_value: str, search_type: str, diff --git a/src/backend/langflow/components/vectorstores/WeaviateSearch.py b/src/backend/langflow/components/vectorstores/WeaviateSearch.py index 3eda5c583..ab8bec222 100644 --- a/src/backend/langflow/components/vectorstores/WeaviateSearch.py +++ b/src/backend/langflow/components/vectorstores/WeaviateSearch.py @@ -55,7 +55,7 @@ class WeaviateSearchVectorStore(WeaviateVectorStoreComponent, LCVectorStoreCompo "code": {"show": False}, } - def build( + def build( # type: ignore[override] self, input_value: str, search_type: str, diff --git a/src/backend/langflow/components/vectorstores/base/model.py b/src/backend/langflow/components/vectorstores/base/model.py index 1cc8b9d88..c0e916313 100644 --- a/src/backend/langflow/components/vectorstores/base/model.py +++ 
b/src/backend/langflow/components/vectorstores/base/model.py @@ -1,9 +1,10 @@ -from typing import List +from typing import List, Union +from langchain_core.documents import Document +from langchain_core.retrievers import BaseRetriever from langchain_core.vectorstores import VectorStore from langflow import CustomComponent -from langflow.field_typing import Text from langflow.schema import Record, docs_to_records @@ -14,7 +15,10 @@ class LCVectorStoreComponent(CustomComponent): beta: bool = True def search_with_vector_store( - self, input_value: Text, search_type: str, vector_store: VectorStore + self, + input_value: str, + search_type: str, + vector_store: Union[VectorStore, BaseRetriever], ) -> List[Record]: """ Search for records in the vector store based on the input value and search type. @@ -31,8 +35,12 @@ class LCVectorStoreComponent(CustomComponent): ValueError: If invalid inputs are provided. """ - docs = [] - if input_value and isinstance(input_value, str): + docs: List[Document] = [] + if ( + input_value + and isinstance(input_value, str) + and hasattr(vector_store, "search") + ): docs = vector_store.search( query=input_value, search_type=search_type.lower() ) diff --git a/src/backend/langflow/components/vectorstores/pgvectorSearch.py b/src/backend/langflow/components/vectorstores/pgvectorSearch.py index 00e291e76..5a75617fb 100644 --- a/src/backend/langflow/components/vectorstores/pgvectorSearch.py +++ b/src/backend/langflow/components/vectorstores/pgvectorSearch.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List from langchain.embeddings.base import Embeddings @@ -40,13 +40,13 @@ class PGVectorSearchComponent(PGVectorComponent, LCVectorStoreComponent): "input_value": {"display_name": "Input"}, } - def build( + def build( # type: ignore[override] self, input_value: str, embedding: Embeddings, + search_type: str, pg_server_url: str, collection_name: str, - search_type: Optional[str] = None, ) -> List[Record]: """ Builds the 
Vector Store or BaseRetriever object. diff --git a/src/backend/langflow/field_typing/constants.py b/src/backend/langflow/field_typing/constants.py index 5977cd9f8..6865ffb17 100644 --- a/src/backend/langflow/field_typing/constants.py +++ b/src/backend/langflow/field_typing/constants.py @@ -22,7 +22,8 @@ class Object: pass -class Text: +# Text = NewType("Text", str) +class Text(str): pass diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 051fc6b3b..0f715690a 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -39,10 +39,10 @@ class Graph: self._runs = 0 self._updates = 0 self.flow_id = flow_id - self._is_input_vertices = [] - self._is_output_vertices = [] - self._has_session_id_vertices = [] - self._sorted_vertices_layers = [] + self._is_input_vertices: List[str] = [] + self._is_output_vertices: List[str] = [] + self._has_session_id_vertices: List[str] = [] + self._sorted_vertices_layers: List[List[str]] = [] self.top_level_vertices = [] for vertex in self._vertices: @@ -73,7 +73,9 @@ class Graph: if getattr(vertex, attribute): getattr(self, f"_{attribute}_vertices").append(vertex.id) - async def _run(self, inputs: Dict[str, str], stream: bool) -> List["ResultData"]: + async def _run( + self, inputs: Dict[str, str], stream: bool + ) -> List[Optional["ResultData"]]: """Runs the graph with the given inputs.""" for vertex_id in self._is_input_vertices: vertex = self.get_vertex(vertex_id) @@ -363,10 +365,10 @@ class Graph: # All vertices that do not have edges are invalid return len(self.get_vertex_edges(vertex.id)) > 0 - def get_vertex(self, vertex_id: str) -> Union[None, Vertex]: + def get_vertex(self, vertex_id: str) -> Vertex: """Returns a vertex by id.""" try: - return self.vertex_map.get(vertex_id) + return self.vertex_map[vertex_id] except KeyError: raise ValueError(f"Vertex {vertex_id} not found") @@ -590,7 +592,7 @@ class Graph: ) return f"Graph:\nNodes: 
{vertex_ids}\nConnections:\n{edges_repr}" - def sort_up_to_vertex(self, vertex_id: str) -> "Graph": + def sort_up_to_vertex(self, vertex_id: str) -> List[Vertex]: """Cuts the graph up to a given vertex and sorts the resulting subgraph.""" # Initial setup visited = set() # To keep track of visited vertices @@ -727,7 +729,9 @@ class Graph: ] return sorted_vertices - def sort_by_avg_build_time(self, vertices_layers: List[str]) -> List[str]: + def sort_by_avg_build_time( + self, vertices_layers: List[List[str]] + ) -> List[List[str]]: """Sorts the vertices in the graph so that vertices with the lowest average build time come first.""" def sort_layer_by_avg_build_time(vertices_ids: List[str]) -> List[str]: diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index 721a7ccc8..d9dd42632 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -7,7 +7,7 @@ from langchain_core.messages import AIMessage from loguru import logger from langflow.graph.schema import INPUT_FIELD_NAME -from langflow.graph.utils import UnbuiltObject, flatten_list +from langflow.graph.utils import UnbuiltObject, flatten_list, serialize_field from langflow.graph.vertex.base import StatefulVertex, StatelessVertex from langflow.interface.utils import extract_input_variables_from_prompt from langflow.schema import Record @@ -483,7 +483,6 @@ class RoutingVertex(StatelessVertex): def dict_to_codeblock(d: dict) -> str: - from langflow.api.utils import serialize_field serialized = {key: serialize_field(val) for key, val in d.items()} json_str = json.dumps(serialized, indent=4) diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index 94cd922eb..b078fda04 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -6,20 +6,17 @@ from loguru import logger from langflow.graph import Graph -async def build_sorted_vertices(data_graph, 
user_id: Optional[Union[str, UUID]] = None) -> Tuple[Graph, Dict]: +async def build_sorted_vertices( + data_graph, flow_id: Optional[Union[str, UUID]] = None +) -> Tuple[Graph, Dict]: """ Build langchain object from data_graph. """ logger.debug("Building langchain object") - graph = Graph.from_payload(data_graph) - sorted_vertices = graph.topological_sort() - artifacts = {} - for vertex in sorted_vertices: - await vertex.build(user_id=user_id) - if vertex.artifacts: - artifacts.update(vertex.artifacts) - return graph, artifacts + graph = Graph.from_payload(data_graph, flow_id=flow_id) + + return graph, {} def get_memory_key(langchain_object): diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index f07552cf6..e3b821b11 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import Optional from urllib.parse import urlencode -import socketio +import socketio # type: ignore from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse @@ -18,7 +18,9 @@ from langflow.utils.logger import configure def get_lifespan(fix_migration=False, socketio_server=None): @asynccontextmanager async def lifespan(app: FastAPI): - initialize_services(fix_migration=fix_migration, socketio_server=socketio_server) + initialize_services( + fix_migration=fix_migration, socketio_server=socketio_server + ) setup_llm_caching() LangfuseInstance.update() yield @@ -31,7 +33,9 @@ def create_app(): """Create the FastAPI app and include the router.""" configure() - socketio_server = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*", logger=True) + socketio_server = socketio.AsyncServer( + async_mode="asgi", cors_allowed_origins="*", logger=True + ) lifespan = get_lifespan(socketio_server=socketio_server) app = FastAPI(lifespan=lifespan) origins = ["*"] @@ -98,7 +102,9 @@ def get_static_files_dir(): return frontend_path / 
"frontend" -def setup_app(static_files_dir: Optional[Path] = None, backend_only: bool = False) -> FastAPI: +def setup_app( + static_files_dir: Optional[Path] = None, backend_only: bool = False +) -> FastAPI: """Setup the FastAPI app.""" # get the directory of the current file if not static_files_dir: diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py index d7cf09a6f..fbb3986ab 100644 --- a/src/backend/langflow/processing/process.py +++ b/src/backend/langflow/processing/process.py @@ -98,22 +98,6 @@ def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]: return list(inputs.values())[0] if len(inputs) == 1 else None -def get_build_result(data_graph, session_id): - # If session_id is provided, load the langchain_object from the session - # using build_sorted_vertices_with_caching.get_result_by_session_id - # if it returns something different than None, return it - # otherwise, build the graph and return the result - if session_id: - logger.debug(f"Loading LangChain object from session {session_id}") - result = build_sorted_vertices(data_graph=data_graph) - if result is not None: - logger.debug("Loaded LangChain object") - return result - - logger.debug("Building langchain object") - return build_sorted_vertices(data_graph) - - def process_inputs( inputs: Optional[Union[dict, List[dict]]] = None, artifacts: Optional[Dict[str, Any]] = None, @@ -233,7 +217,9 @@ async def process_graph_cached( session_id=session_id, data_graph=data_graph ) # Load the graph using SessionService - session = await session_service.load_session(session_id, data_graph) + session = await session_service.load_session( + session_id, data_graph, flow_id=flow_id + ) graph, artifacts = session if session else (None, None) if not graph: raise ValueError("Graph not found in the session") @@ -270,8 +256,8 @@ async def build_graph_and_generate_result( async def run_graph( graph: Union["Graph", dict], flow_id: str, - session_id: str, stream: 
bool, + session_id: Optional[str] = None, inputs: Optional[Union[dict, List[dict]]] = None, artifacts: Optional[Dict[str, Any]] = None, session_service: Optional[SessionService] = None, @@ -282,10 +268,12 @@ async def run_graph( graph = Graph.from_payload(graph, flow_id=flow_id) else: graph_data = graph._graph_data - if not session_id: + if not session_id and session_service is not None: session_id = session_service.generate_key( session_id=flow_id, data_graph=graph_data ) + if inputs is None: + inputs = {} outputs = await graph.run(inputs, stream=stream) if session_id and session_service: diff --git a/src/backend/langflow/schema.py b/src/backend/langflow/schema.py index 053caf3a3..e9f437038 100644 --- a/src/backend/langflow/schema.py +++ b/src/backend/langflow/schema.py @@ -1,4 +1,4 @@ -from typing import Any, Optional +from typing import Any from langchain_core.documents import Document from pydantic import BaseModel @@ -14,7 +14,7 @@ class Record(BaseModel): """ text: str - data: Optional[dict] = None + data: dict = {} @classmethod def from_document(cls, document: Document) -> "Record": diff --git a/src/backend/langflow/services/chat/utils.py b/src/backend/langflow/services/chat/utils.py index 85e6cdcd5..6a3e925b3 100644 --- a/src/backend/langflow/services/chat/utils.py +++ b/src/backend/langflow/services/chat/utils.py @@ -8,7 +8,6 @@ from loguru import logger from langflow.api.v1.schemas import ChatMessage from langflow.interface.utils import try_setting_streaming_options from langflow.processing.base import get_result_and_steps -from langflow.utils.chat import ChatDefinition LANGCHAIN_RUNNABLES = (Chain, Runnable, AgentExecutor) @@ -24,7 +23,9 @@ async def process_graph( if build_result is None: # Raise user facing error - raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.") + raise ValueError( + "There was an error loading the langchain_object. Please, check all the nodes and try again." 
+ ) # Generate result and thought try: @@ -40,20 +41,7 @@ async def process_graph( client_id=client_id, session_id=session_id, ) - elif isinstance(build_result, ChatDefinition): - raw_output = await run_build_result( - build_result, - chat_inputs, - client_id=client_id, - session_id=session_id, - ) - if isinstance(raw_output, dict): - if not build_result.output_key: - raise ValueError("No output key provided to ChatDefinition when returning a dict.") - result = raw_output[build_result.output_key] - else: - result = raw_output - intermediate_steps = [] + else: raise TypeError(f"Unknown type {type(build_result)}") logger.debug("Generated result and intermediate_steps") @@ -64,5 +52,7 @@ async def process_graph( raise e -async def run_build_result(build_result: Any, chat_inputs: ChatMessage, client_id: str, session_id: str): +async def run_build_result( + build_result: Any, chat_inputs: ChatMessage, client_id: str, session_id: str +): return build_result(inputs=chat_inputs.message) diff --git a/src/backend/langflow/services/database/models/base.py b/src/backend/langflow/services/database/models/base.py index 17dd947d9..53ee2c37e 100644 --- a/src/backend/langflow/services/database/models/base.py +++ b/src/backend/langflow/services/database/models/base.py @@ -1,6 +1,4 @@ import orjson -from pydantic import ConfigDict -from sqlmodel import SQLModel def orjson_dumps(v, *, default=None, sort_keys=False, indent_2=True): @@ -17,7 +15,3 @@ def orjson_dumps(v, *, default=None, sort_keys=False, indent_2=True): if default is None: return orjson.dumps(v, option=option).decode() return orjson.dumps(v, default=default, option=option).decode() - - -class SQLModelSerializable(SQLModel): - model_config = ConfigDict(from_attributes=True) diff --git a/src/backend/langflow/services/monitor/schema.py b/src/backend/langflow/services/monitor/schema.py index 32e5e582f..2c1e34cd5 100644 --- a/src/backend/langflow/services/monitor/schema.py +++ b/src/backend/langflow/services/monitor/schema.py 
@@ -10,7 +10,9 @@ if TYPE_CHECKING: class TransactionModel(BaseModel): id: Optional[int] = Field(default=None, alias="id") - timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp") + timestamp: Optional[datetime] = Field( + default_factory=datetime.now, alias="timestamp" + ) source: str target: str target_args: dict @@ -51,8 +53,12 @@ class MessageModel(BaseModel): @classmethod def from_record(cls, record: "Record"): # first check if the record has all the required fields - if "sender" not in record.data and "sender_name" not in record.data: - raise ValueError("The record does not have the required fields 'sender' and 'sender_name' in the data.") + if not record.data or ( + "sender" not in record.data and "sender_name" not in record.data + ): + raise ValueError( + "The record does not have the required fields 'sender' and 'sender_name' in the data." + ) return cls( sender=record.data["sender"], sender_name=record.data["sender_name"], diff --git a/src/backend/langflow/services/monitor/service.py b/src/backend/langflow/services/monitor/service.py index 0aa14a081..25ff65fd8 100644 --- a/src/backend/langflow/services/monitor/service.py +++ b/src/backend/langflow/services/monitor/service.py @@ -1,8 +1,12 @@ from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Type, Union import duckdb +from loguru import logger +from platformdirs import user_cache_dir +from pydantic import BaseModel + from langflow.services.base import Service from langflow.services.monitor.schema import ( MessageModel, @@ -13,8 +17,6 @@ from langflow.services.monitor.utils import ( add_row_to_table, drop_and_create_table_if_schema_mismatch, ) -from loguru import logger -from platformdirs import user_cache_dir if TYPE_CHECKING: from langflow.services.settings.manager import SettingsService @@ -43,7 +45,9 @@ class MonitorService(Service): def ensure_tables_exist(self): for 
table_name, model in self.table_map.items(): - drop_and_create_table_if_schema_mismatch(str(self.db_path), table_name, model) + drop_and_create_table_if_schema_mismatch( + str(self.db_path), table_name, model + ) def add_row( self, @@ -52,7 +56,7 @@ class MonitorService(Service): ): # Make sure the model passed matches the table - model = self.table_map.get(table_name) + model: Type[BaseModel] = self.table_map.get(table_name) if model is None: raise ValueError(f"Unknown table name: {table_name}") diff --git a/src/backend/langflow/services/session/service.py b/src/backend/langflow/services/session/service.py index 059d82bec..914ca7a3a 100644 --- a/src/backend/langflow/services/session/service.py +++ b/src/backend/langflow/services/session/service.py @@ -14,7 +14,9 @@ class SessionService(Service): def __init__(self, cache_service): self.cache_service: "BaseCacheService" = cache_service - async def load_session(self, key, data_graph: Optional[dict] = None): + async def load_session( + self, key, data_graph: Optional[dict] = None, flow_id: Optional[str] = None + ): # Check if the data is cached if key in self.cache_service: return self.cache_service.get(key) @@ -24,7 +26,7 @@ class SessionService(Service): if data_graph is None: return (None, None) # If not cached, build the graph and cache it - graph, artifacts = await build_sorted_vertices(data_graph) + graph, artifacts = await build_sorted_vertices(data_graph, flow_id) self.cache_service.set(key, (graph, artifacts)) diff --git a/src/backend/langflow/services/settings/manager.py b/src/backend/langflow/services/settings/manager.py index e9e535911..b4812058f 100644 --- a/src/backend/langflow/services/settings/manager.py +++ b/src/backend/langflow/services/settings/manager.py @@ -1,9 +1,11 @@ +import os + +import yaml +from loguru import logger + from langflow.services.base import Service from langflow.services.settings.auth import AuthSettings from langflow.services.settings.base import Settings -from loguru import 
logger -import os -import yaml class SettingsService(Service): @@ -28,9 +30,11 @@ class SettingsService(Service): settings_dict = {k.upper(): v for k, v in settings_dict.items()} for key in settings_dict: - if key not in Settings.__fields__.keys(): + if key not in Settings.model_fields().keys(): raise KeyError(f"Key {key} not found in settings") - logger.debug(f"Loading {len(settings_dict[key])} {key} from {file_path}") + logger.debug( + f"Loading {len(settings_dict[key])} {key} from {file_path}" + ) settings = Settings(**settings_dict) if not settings.CONFIG_DIR: diff --git a/src/backend/langflow/services/socket/service.py b/src/backend/langflow/services/socket/service.py index a84c9ca86..b3ae2b08a 100644 --- a/src/backend/langflow/services/socket/service.py +++ b/src/backend/langflow/services/socket/service.py @@ -25,7 +25,7 @@ class SocketIOService(Service): self.sio.on("message")(self.message) self.sio.on("get_vertices")(self.on_get_vertices) self.sio.on("build_vertex")(self.on_build_vertex) - self.sessions = {} + self.sessions = {} # type: dict[str, dict] async def emit_error(self, sid, error): await self.sio.emit("error", to=sid, data=error) diff --git a/src/backend/langflow/services/storage/service.py b/src/backend/langflow/services/storage/service.py index 830557dbf..fea4a714e 100644 --- a/src/backend/langflow/services/storage/service.py +++ b/src/backend/langflow/services/storage/service.py @@ -11,32 +11,34 @@ if TYPE_CHECKING: class StorageService(Service): name = "storage_service" - def __init__(self, session_service: "SessionService", settings_service: "SettingsService"): + def __init__( + self, session_service: "SessionService", settings_service: "SettingsService" + ): self.settings_service = settings_service self.session_service = session_service self.set_ready() def build_full_path(self, flow_id: str, file_name: str) -> str: - pass + raise NotImplementedError def set_ready(self): self.ready = True @abstractmethod async def save_file(self, flow_id: 
str, file_name: str, data) -> None: - pass + raise NotImplementedError @abstractmethod async def get_file(self, flow_id: str, file_name: str) -> bytes: - pass + raise NotImplementedError @abstractmethod async def list_files(self, flow_id: str) -> list[str]: - pass + raise NotImplementedError @abstractmethod async def delete_file(self, flow_id: str, file_name: str) -> bool: - pass + raise NotImplementedError def teardown(self): - pass + raise NotImplementedError diff --git a/src/backend/langflow/template/frontend_node/chains.py b/src/backend/langflow/template/frontend_node/chains.py index a40b29577..9e369d5d0 100644 --- a/src/backend/langflow/template/frontend_node/chains.py +++ b/src/backend/langflow/template/frontend_node/chains.py @@ -59,7 +59,7 @@ class ChainFrontendNode(FrontendNode): field.required = False field.advanced = False - if "key" in field.name: + if "key" in str(field.name): field.password = False field.show = False if field.name in ["input_key", "output_key"]: @@ -216,7 +216,9 @@ class MidJourneyPromptChainNode(FrontendNode): ), ], ) - description: str = "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts." + description: str = ( + "MidJourneyPromptChain is a chain you can use to generate new MidJourney prompts." 
+ ) base_classes: list[str] = [ "LLMChain", "BaseCustomChain", diff --git a/src/backend/langflow/worker.py b/src/backend/langflow/worker.py index b3872bcd8..1f8b871dc 100644 --- a/src/backend/langflow/worker.py +++ b/src/backend/langflow/worker.py @@ -2,12 +2,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from asgiref.sync import async_to_sync from celery.exceptions import SoftTimeLimitExceeded # type: ignore + from langflow.core.celery_app import celery_app -from langflow.processing.process import Result, generate_result, process_inputs -from langflow.services.deps import get_session_service -from langflow.services.manager import initialize_session_service -from loguru import logger -from rich import print if TYPE_CHECKING: from langflow.graph.vertex.base import Vertex @@ -28,7 +24,9 @@ def build_vertex(self, vertex: "Vertex") -> "Vertex": async_to_sync(vertex.build)() return vertex except SoftTimeLimitExceeded as e: - raise self.retry(exc=SoftTimeLimitExceeded("Task took too long"), countdown=2) from e + raise self.retry( + exc=SoftTimeLimitExceeded("Task took too long"), countdown=2 + ) from e @celery_app.task(acks_late=True) @@ -38,38 +36,4 @@ def process_graph_cached_task( clear_cache=False, session_id=None, ) -> Dict[str, Any]: - try: - initialize_session_service() - session_service = get_session_service() - - if clear_cache: - session_service.clear_session(session_id) - - if session_id is None: - session_id = session_service.generate_key(session_id=session_id, data_graph=data_graph) - - # Use async_to_sync to handle the asynchronous part of the session service - session_data = async_to_sync(session_service.load_session, force_new_loop=True)(session_id, data_graph) - logger.warning(f"session_data: {session_data}") - graph, artifacts = session_data if session_data else (None, None) - - if not graph: - raise ValueError("Graph not found in the session") - - # Use async_to_sync for the asynchronous build method - built_object = 
async_to_sync(graph.build, force_new_loop=True)() - - logger.debug(f"Built object: {built_object}") - - processed_inputs = process_inputs(inputs, artifacts or {}) - result = async_to_sync(generate_result, force_new_loop=True)(built_object, processed_inputs) - - # Update the session with the new data - session_service.update_session(session_id, (graph, artifacts)) - result_object = Result(result=result, session_id=session_id).model_dump() - print(f"Result object: {result_object}") - return result_object - except Exception as e: - logger.error(f"Error in process_graph_cached_task: {e}") - # Handle the exception as needed, maybe re-raise or return an error message - raise + raise NotImplementedError("This task is not implemented yet")