refactor(api): change predict endpoint to use async/await and add response model

feat(api): add predict request and response schemas
refactor(interface): rename get_result_and_steps to get_result_and_thought and remove async prefix
fix(interface): use get_result_and_thought instead of async_get_result_and_steps in process_graph_cached
This commit is contained in:
Gabriel Almeida 2023-05-09 13:19:19 -03:00
commit 0bc96208e4
3 changed files with 46 additions and 36 deletions

View file

@ -5,6 +5,12 @@ from fastapi import APIRouter, HTTPException
from langflow.interface.run import process_graph_cached
from langflow.interface.types import build_langchain_types_dict
from langflow.api.schemas import (
ExportedFlow,
GraphData,
PredictRequest,
PredictResponse,
)
# build router
router = APIRouter()
@ -16,10 +22,14 @@ def get_all():
return build_langchain_types_dict()
@router.post("/predict")
def get_load(data: Dict[str, Any]):
@router.post("/predict", response_model=PredictResponse)
async def get_load(predict_request: PredictRequest):
try:
return process_graph_cached(data)
exported_flow: ExportedFlow = predict_request.exported_flow
graph_data: GraphData = exported_flow.data
data = graph_data.dict()
response = process_graph_cached(data, predict_request.message)
return PredictResponse(result=response.get("result", ""))
except Exception as e:
# Log stack trace
logger.exception(e)

View file

@ -1,8 +1,37 @@
from typing import Any, Union
from typing import Any, Union, Dict, List
from pydantic import BaseModel, validator
class GraphData(BaseModel):
    """Data inside the exported flow.

    Raw graph structure from an exported flow: node and edge entries are
    arbitrary JSON-like mappings and are not validated further here.
    """
    nodes: List[Dict[str, Any]]  # graph nodes, one dict per node
    edges: List[Dict[str, Any]]  # graph edges, one dict per edge
class ExportedFlow(BaseModel):
    """Exported flow from LangFlow.

    Top-level flow metadata plus the graph payload (`data`) that is
    converted to a dict and handed to the graph-processing pipeline.
    """
    description: str  # user-provided description of the flow
    name: str  # display name of the flow
    id: str  # flow identifier (format not validated here)
    data: GraphData  # the nodes/edges graph to build and run
class PredictRequest(BaseModel):
    """Predict request schema.

    Body of the predict endpoint: the user message plus the full exported
    flow whose graph is processed to produce an answer.
    """
    message: str  # message passed to the loaded langchain object
    exported_flow: ExportedFlow  # flow whose graph data is executed
class PredictResponse(BaseModel):
    """Predict response schema.

    Returned by the predict endpoint; exposes only the final result
    string (intermediate "thought" output is not included).
    """
    result: str  # final answer produced by running the graph
class ChatMessage(BaseModel):
"""Chat message schema."""

View file

@ -100,13 +100,12 @@ def process_graph(data_graph: Dict[str, Any]):
return {"result": str(result), "thought": thought.strip()}
def process_graph_cached(data_graph: Dict[str, Any]):
def process_graph_cached(data_graph: Dict[str, Any], message: str):
"""
Process graph by extracting input variables and replacing ZeroShotPrompt
with PromptTemplate,then run the graph and return the result and thought.
"""
# Load langchain object
message = data_graph.pop("message", "")
is_first_message = len(data_graph.get("chatHistory", [])) == 0
langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
logger.debug("Loaded langchain object")
@ -119,7 +118,7 @@ def process_graph_cached(data_graph: Dict[str, Any]):
# Generate result and thought
logger.debug("Generating result and thought")
result, thought = get_result_and_steps(langchain_object, message)
result, thought = get_result_and_thought(langchain_object, message)
logger.debug("Generated result and thought")
return {"result": str(result), "thought": thought.strip()}
@ -241,7 +240,7 @@ def get_result_and_steps(langchain_object, message: str):
return result, thought
def async_get_result_and_steps(langchain_object, message: str):
def get_result_and_thought(langchain_object, message: str):
"""Get result and thought from extracted json"""
try:
if hasattr(langchain_object, "verbose"):
@ -296,34 +295,6 @@ def async_get_result_and_steps(langchain_object, message: str):
return result, thought
def get_result_and_thought(extracted_json: Dict[str, Any], message: str):
    """Run the langchain object described by *extracted_json* on *message*.

    Builds the langchain object from the extracted config, invokes it while
    capturing anything it prints, and returns ``(result, thought)``.  The
    "thought" is the formatted intermediate steps when available, otherwise
    whatever was written to stdout during the call.  Any failure is reported
    best-effort as an ``"Error: ..."`` result rather than raised.
    """
    try:
        langchain_object = loading.load_langchain_type_from_config(
            config=extracted_json
        )
        # Capture stdout: verbose chains print their reasoning there.
        with io.StringIO() as captured, contextlib.redirect_stdout(captured):
            output = langchain_object(message)
            if isinstance(output, dict):
                steps = output.get("intermediate_steps", [])
                result = output.get(langchain_object.output_keys[0])
            else:
                steps = []
                result = output
            # Prefer structured steps; fall back to raw printed output.
            if steps:
                thought = format_intermediate_steps(steps)
            else:
                thought = captured.getvalue()
    except Exception as e:
        # Best-effort: surface the failure in the result instead of raising.
        result = f"Error: {str(e)}"
        thought = ""
    return result, thought
def format_intermediate_steps(intermediate_steps):
formatted_chain = "> Entering new AgentExecutor chain...\n"
for step in intermediate_steps: