Refactor process.py and schema.py
This commit is contained in:
parent
bb3257ed80
commit
dd9347a186
2 changed files with 7 additions and 76 deletions
|
|
@ -1,20 +1,15 @@
|
|||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Dict, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.schema import AgentAction, Document
|
||||
from langchain_community.vectorstores import VectorStore
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables.base import Runnable
|
||||
from langchain.schema import AgentAction
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.schema import INPUT_FIELD_NAME, RunOutputs
|
||||
from langflow.graph.schema import RunOutputs
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.interface.run import get_memory_key, update_memory_keys
|
||||
from langflow.schema.schema import INPUT_FIELD_NAME
|
||||
from langflow.services.session.service import SessionService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -124,73 +119,6 @@ def update_inputs_dict(inputs: dict, artifacts: Dict[str, Any]) -> dict:
|
|||
return inputs
|
||||
|
||||
|
||||
async def process_runnable(runnable: Runnable, inputs: Union[dict, List[dict]]):
|
||||
if isinstance(inputs, List) and hasattr(runnable, "abatch"):
|
||||
result = await runnable.abatch(inputs)
|
||||
elif isinstance(inputs, dict) and hasattr(runnable, "ainvoke"):
|
||||
result = await runnable.ainvoke(inputs)
|
||||
else:
|
||||
raise ValueError(f"Runnable {runnable} does not support inputs of type {type(inputs)}")
|
||||
# Check if the result is a list of AIMessages
|
||||
if isinstance(result, list) and all(isinstance(r, AIMessage) for r in result):
|
||||
result = [r.content for r in result]
|
||||
elif isinstance(result, AIMessage):
|
||||
result = result.content
|
||||
return result
|
||||
|
||||
|
||||
async def process_inputs_dict(built_object: Union[Chain, VectorStore, Runnable], inputs: dict):
|
||||
if isinstance(built_object, Chain):
|
||||
if inputs is None:
|
||||
raise ValueError("Inputs must be provided for a Chain")
|
||||
logger.debug("Generating result and thought")
|
||||
result = get_result_and_thought(built_object, inputs)
|
||||
|
||||
logger.debug("Generated result and thought")
|
||||
elif isinstance(built_object, VectorStore) and "query" in inputs:
|
||||
if isinstance(inputs, dict) and "search_type" not in inputs:
|
||||
inputs["search_type"] = "similarity"
|
||||
logger.info("search_type not provided, using default value: similarity")
|
||||
result = built_object.search(**inputs)
|
||||
elif isinstance(built_object, Document):
|
||||
result = built_object.dict()
|
||||
elif isinstance(built_object, Runnable):
|
||||
result = await process_runnable(built_object, inputs)
|
||||
if isinstance(result, list):
|
||||
result = [r.content if hasattr(r, "content") else r for r in result]
|
||||
elif hasattr(result, "content"):
|
||||
result = result.content
|
||||
else:
|
||||
result = result
|
||||
elif hasattr(built_object, "run") and isinstance(built_object, CustomComponent):
|
||||
result = built_object.run(inputs)
|
||||
else:
|
||||
result = None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def process_inputs_list(built_object: Runnable, inputs: List[dict]):
|
||||
return await process_runnable(built_object, inputs)
|
||||
|
||||
|
||||
async def generate_result(built_object: Union[Chain, VectorStore, Runnable], inputs: Union[dict, List[dict]]):
|
||||
if isinstance(inputs, dict):
|
||||
result = await process_inputs_dict(built_object, inputs)
|
||||
elif isinstance(inputs, List) and isinstance(built_object, Runnable):
|
||||
result = await process_inputs_list(built_object, inputs)
|
||||
else:
|
||||
raise ValueError(f"Invalid inputs type: {type(inputs)}")
|
||||
|
||||
if result is None:
|
||||
logger.warning(f"Unknown built_object type: {type(built_object)}")
|
||||
if isinstance(built_object, Coroutine):
|
||||
result = asyncio.run(built_object)
|
||||
result = built_object
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Result(BaseModel):
|
||||
result: Any
|
||||
session_id: str
|
||||
|
|
|
|||
|
|
@ -120,3 +120,6 @@ class Record(BaseModel):
|
|||
# check which attributes the Record has by checking the keys in the data dictionary
|
||||
def __dir__(self):
|
||||
return super().__dir__() + list(self.data.keys())
|
||||
|
||||
|
||||
INPUT_FIELD_NAME = "input_value"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue