Update type hints and refactor result handling in process.py (#1234)
This pull request updates the type hints for the `inputs` parameter in the `process_graph_data` and `process` functions. It also refactors the result handling in the `generate_result` function. Additionally, it updates the version to `0.6.3a5` in `pyproject.toml`.
This commit is contained in:
commit
bb9aed50e0
3 changed files with 23 additions and 13 deletions
|
|
@@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "langflow"
|
||||
version = "0.6.3a4"
|
||||
version = "0.6.3a5"
|
||||
description = "A Python package with a built-in web application"
|
||||
authors = ["Logspace <contact@logspace.ai>"]
|
||||
maintainers = [
|
||||
|
|
|
|||
|
|
@@ -1,5 +1,5 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Annotated, Optional, Union
|
||||
from typing import Annotated, List, Optional, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, status
|
||||
|
|
@@ -42,7 +42,7 @@ router = APIRouter(tags=["Base"])
|
|||
|
||||
async def process_graph_data(
|
||||
graph_data: dict,
|
||||
inputs: Optional[dict] = None,
|
||||
inputs: Optional[Union[List[dict], dict]] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
clear_cache: bool = False,
|
||||
session_id: Optional[str] = None,
|
||||
|
|
@@ -160,7 +160,7 @@ async def process_json(
|
|||
async def process(
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
flow_id: str,
|
||||
inputs: Optional[dict] = None,
|
||||
inputs: Optional[Union[List[dict], dict]] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session_id: Annotated[Union[None, str], Body(embed=True)] = None, # noqa: F821
|
||||
|
|
|
|||
|
|
@@ -6,11 +6,12 @@ from langchain.chains.base import Chain
|
|||
from langchain.schema import AgentAction, Document
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.components.custom_components import CustomComponent
|
||||
from langflow.interface.run import build_sorted_vertices, get_memory_key, update_memory_keys
|
||||
from langflow.services.deps import get_session_service
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def fix_memory_inputs(langchain_object):
|
||||
|
|
@@ -118,7 +119,7 @@ def process_inputs(inputs: Optional[dict], artifacts: Dict[str, Any]) -> dict:
|
|||
return inputs
|
||||
|
||||
|
||||
async def generate_result(langchain_object: Union[Chain, VectorStore], inputs: dict):
|
||||
async def generate_result(langchain_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
|
||||
if isinstance(langchain_object, Chain):
|
||||
if inputs is None:
|
||||
raise ValueError("Inputs must be provided for a Chain")
|
||||
|
|
@@ -131,12 +132,21 @@ async def generate_result(langchain_object: Union[Chain, VectorStore], inputs: d
|
|||
elif isinstance(langchain_object, Document):
|
||||
result = langchain_object.dict()
|
||||
elif isinstance(langchain_object, Runnable):
|
||||
if isinstance(inputs, List):
|
||||
call_func = langchain_object.abatch
|
||||
elif isinstance(inputs, dict):
|
||||
call_func = langchain_object.ainvoke
|
||||
result = await call_func(inputs)
|
||||
result = result.content if hasattr(result, "content") else result
|
||||
# Define call_method as a coroutine function
|
||||
# by default
|
||||
if isinstance(inputs, List) and hasattr(langchain_object, "abatch"):
|
||||
call_method = langchain_object.abatch
|
||||
elif isinstance(inputs, dict) and hasattr(langchain_object, "ainvoke"):
|
||||
call_method = langchain_object.ainvoke
|
||||
else:
|
||||
raise ValueError("Inputs must be provided for a Runnable")
|
||||
result = await call_method(inputs)
|
||||
if isinstance(result, list):
|
||||
result = [r.content if hasattr(r, "content") else r for r in result]
|
||||
elif hasattr(result, "content"):
|
||||
result = result.content
|
||||
else:
|
||||
result = result
|
||||
elif hasattr(langchain_object, "run") and isinstance(langchain_object, CustomComponent):
|
||||
result = langchain_object.run(inputs)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue