refactor(api): change predict endpoint to use async/await and add response model
feat(api): add predict request and response schemas refactor(interface): rename get_result_and_steps to get_result_and_thought and remove async prefix fix(interface): use get_result_and_thought instead of async_get_result_and_steps in process_graph_cached
This commit is contained in:
parent
8810a0392a
commit
0bc96208e4
3 changed files with 46 additions and 36 deletions
|
|
@ -5,6 +5,12 @@ from fastapi import APIRouter, HTTPException
|
|||
|
||||
from langflow.interface.run import process_graph_cached
|
||||
from langflow.interface.types import build_langchain_types_dict
|
||||
from langflow.api.schemas import (
|
||||
ExportedFlow,
|
||||
GraphData,
|
||||
PredictRequest,
|
||||
PredictResponse,
|
||||
)
|
||||
|
||||
# build router
|
||||
router = APIRouter()
|
||||
|
|
@ -16,10 +22,14 @@ def get_all():
|
|||
return build_langchain_types_dict()
|
||||
|
||||
|
||||
@router.post("/predict")
|
||||
def get_load(data: Dict[str, Any]):
|
||||
@router.post("/predict", response_model=PredictResponse)
|
||||
async def get_load(predict_request: PredictRequest):
|
||||
try:
|
||||
return process_graph_cached(data)
|
||||
exported_flow: ExportedFlow = predict_request.exported_flow
|
||||
graph_data: GraphData = exported_flow.data
|
||||
data = graph_data.dict()
|
||||
response = process_graph_cached(data, predict_request.message)
|
||||
return PredictResponse(result=response.get("result", ""))
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
|
|||
|
|
@ -1,8 +1,37 @@
|
|||
from typing import Any, Union
|
||||
from typing import Any, Union, Dict, List
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
|
||||
class GraphData(BaseModel):
|
||||
"""Data inside the exported flow."""
|
||||
|
||||
nodes: List[Dict[str, Any]]
|
||||
edges: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class ExportedFlow(BaseModel):
|
||||
"""Exported flow from LangFlow."""
|
||||
|
||||
description: str
|
||||
name: str
|
||||
id: str
|
||||
data: GraphData
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
"""Predict request schema."""
|
||||
|
||||
message: str
|
||||
exported_flow: ExportedFlow
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
"""Predict response schema."""
|
||||
|
||||
result: str
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""Chat message schema."""
|
||||
|
||||
|
|
|
|||
|
|
@ -100,13 +100,12 @@ def process_graph(data_graph: Dict[str, Any]):
|
|||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
||||
def process_graph_cached(data_graph: Dict[str, Any]):
|
||||
def process_graph_cached(data_graph: Dict[str, Any], message: str):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
with PromptTemplate,then run the graph and return the result and thought.
|
||||
"""
|
||||
# Load langchain object
|
||||
message = data_graph.pop("message", "")
|
||||
is_first_message = len(data_graph.get("chatHistory", [])) == 0
|
||||
langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
|
@ -119,7 +118,7 @@ def process_graph_cached(data_graph: Dict[str, Any]):
|
|||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, thought = get_result_and_steps(langchain_object, message)
|
||||
result, thought = get_result_and_thought(langchain_object, message)
|
||||
logger.debug("Generated result and thought")
|
||||
return {"result": str(result), "thought": thought.strip()}
|
||||
|
||||
|
|
@ -241,7 +240,7 @@ def get_result_and_steps(langchain_object, message: str):
|
|||
return result, thought
|
||||
|
||||
|
||||
def async_get_result_and_steps(langchain_object, message: str):
|
||||
def get_result_and_thought(langchain_object, message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
|
|
@ -296,34 +295,6 @@ def async_get_result_and_steps(langchain_object, message: str):
|
|||
return result, thought
|
||||
|
||||
|
||||
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
langchain_object = loading.load_langchain_type_from_config(
|
||||
config=extracted_json
|
||||
)
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
output = langchain_object(message)
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
result = (
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
|
||||
if intermediate_steps:
|
||||
thought = format_intermediate_steps(intermediate_steps)
|
||||
else:
|
||||
thought = output_buffer.getvalue()
|
||||
|
||||
except Exception as e:
|
||||
result = f"Error: {str(e)}"
|
||||
thought = ""
|
||||
return result, thought
|
||||
|
||||
|
||||
def format_intermediate_steps(intermediate_steps):
|
||||
formatted_chain = "> Entering new AgentExecutor chain...\n"
|
||||
for step in intermediate_steps:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue