API Refactor and structure changes (#449)
Commit 0eb60c0bf0
40 changed files with 398 additions and 339 deletions
pyproject.toml
@@ -78,6 +78,15 @@ types-pillow = "^9.5.0.2"
 [tool.poetry.extras]
 deploy = ["langchain-serve"]
 
+[tool.pytest.ini_options]
+minversion = "6.0"
+addopts = "-ra"
+testpaths = ["tests", "integration"]
+console_output_style = "progress"
+filterwarnings = ["ignore::DeprecationWarning"]
+log_cli = true
+
+
 [tool.ruff]
 line-length = 120
src/backend/langflow/__init__.py
@@ -1,4 +1,4 @@
 from langflow.cache import cache_manager
-from langflow.interface.loading import load_flow_from_json
+from langflow.processing.process import load_flow_from_json
 
 __all__ = ["load_flow_from_json", "cache_manager"]
src/backend/langflow/api/__init__.py
@@ -0,0 +1,3 @@
+from langflow.api.router import router
+
+__all__ = ["router"]

src/backend/langflow/api/router.py (new file, 8 lines)
@@ -0,0 +1,8 @@
+# Router for base api
+from fastapi import APIRouter
+from langflow.api.v1 import chat_router, endpoints_router, validate_router
+
+router = APIRouter(prefix="/api/v1", tags=["api"])
+router.include_router(chat_router)
+router.include_router(endpoints_router)
+router.include_router(validate_router)
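For orientation, here is a minimal, self-contained sketch of the versioned-router pattern this new file introduces; the stub sub-routers below stand in for the real chat/endpoints/validate modules.

```python
from fastapi import APIRouter, FastAPI

# Stubs standing in for langflow.api.v1.chat / .endpoints / .validate
chat_router = APIRouter()
endpoints_router = APIRouter()

@endpoints_router.get("/all")
def get_all_stub():
    return {}

# Same pattern as router.py above: one prefix, many sub-routers
router = APIRouter(prefix="/api/v1", tags=["api"])
router.include_router(chat_router)
router.include_router(endpoints_router)

app = FastAPI()
app.include_router(router)
# A route registered as "/all" on endpoints_router is now served at /api/v1/all.
```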
src/backend/langflow/api/v1/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+from langflow.api.v1.endpoints import router as endpoints_router
+from langflow.api.v1.validate import router as validate_router
+from langflow.api.v1.chat import router as chat_router
+
+__all__ = ["chat_router", "endpoints_router", "validate_router"]
src/backend/langflow/api/v1/base.py
@@ -1,6 +1,6 @@
 from pydantic import BaseModel, validator
 
-from langflow.graph.utils import extract_input_variables_from_prompt
+from langflow.interface.utils import extract_input_variables_from_prompt
 
 
 class CacheResponse(BaseModel):

src/backend/langflow/api/v1/callback.py
@@ -3,7 +3,7 @@ from typing import Any
 from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
 
-from langflow.api.schemas import ChatResponse
+from langflow.api.v1.schemas import ChatResponse
 
 
 # https://github.com/hwchase17/chat-langchain/blob/master/callback.py
src/backend/langflow/api/v1/chat.py
@@ -6,7 +6,7 @@ from fastapi import (
     status,
 )
 
-from langflow.api.chat_manager import ChatManager
+from langflow.chat.manager import ChatManager
 from langflow.utils.logger import logger
 
 router = APIRouter()
src/backend/langflow/api/v1/endpoints.py
@@ -3,13 +3,13 @@ from importlib.metadata import version
 from fastapi import APIRouter, HTTPException
 
-from langflow.api.schemas import (
+from langflow.api.v1.schemas import (
     ExportedFlow,
     GraphData,
     PredictRequest,
     PredictResponse,
 )
-from langflow.interface.run import process_graph_cached
+
 from langflow.interface.types import build_langchain_types_dict
 
 # build router

@@ -25,6 +25,8 @@ def get_all():
 @router.post("/predict", response_model=PredictResponse)
 async def get_load(predict_request: PredictRequest):
     try:
+        from langflow.processing.process import process_graph_cached
+
         exported_flow: ExportedFlow = predict_request.exported_flow
         graph_data: GraphData = exported_flow.data
         data = graph_data.dict()

@@ -40,8 +42,3 @@ async def get_load(predict_request: PredictRequest):
 @router.get("/version")
 def get_version():
     return {"version": version("langflow")}
-
-
-@router.get("/health")
-def get_health():
-    return {"status": "OK"}
src/backend/langflow/api/v1/validate.py
@@ -2,7 +2,7 @@ import json
 from fastapi import APIRouter, HTTPException
 
-from langflow.api.base import (
+from langflow.api.v1.base import (
     Code,
     CodeValidationResponse,
     Prompt,

@@ -10,7 +10,7 @@ from langflow.api.base import (
     validate_prompt,
 )
 from langflow.graph.vertex.types import VectorStoreVertex
-from langflow.interface.run import build_graph
+from langflow.graph import Graph
 from langflow.utils.logger import logger
 from langflow.utils.validate import validate_code

@@ -44,7 +44,7 @@ def post_validate_prompt(prompt: Prompt):
 def post_validate_node(node_id: str, data: dict):
     try:
         # build graph
-        graph = build_graph(data)
+        graph = Graph.from_payload(data)
         # validate node
         node = graph.get_node(node_id)
         if node is None:

src/backend/langflow/chat/__init__.py (new, empty file)
src/backend/langflow/chat/manager.py (previously langflow/api/chat_manager.py)
@@ -1,21 +1,18 @@
 import asyncio
 import json
 from collections import defaultdict
 from typing import Dict, List
 
 from fastapi import WebSocket, status
 
-from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse
+from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse
 from langflow.cache import cache_manager
 from langflow.cache.manager import Subject
-from langflow.interface.run import (
-    get_result_and_steps,
-    load_or_build_langchain_object,
-)
-from langflow.interface.utils import pil_to_base64, try_setting_streaming_options
+from langflow.chat.utils import process_graph
+from langflow.interface.utils import pil_to_base64
 from langflow.utils.logger import logger
 
 
 class ChatHistory(Subject):
     def __init__(self):
         super().__init__()

@@ -191,33 +188,3 @@ class ChatManager:
         except Exception as e:
             logger.exception(e)
             self.disconnect(client_id)
-
-
-async def process_graph(
-    graph_data: Dict,
-    is_first_message: bool,
-    chat_message: ChatMessage,
-    websocket: WebSocket,
-):
-    langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
-    langchain_object = try_setting_streaming_options(langchain_object, websocket)
-    logger.debug("Loaded langchain object")
-
-    if langchain_object is None:
-        # Raise user facing error
-        raise ValueError(
-            "There was an error loading the langchain_object. Please, check all the nodes and try again."
-        )
-
-    # Generate result and thought
-    try:
-        logger.debug("Generating result and thought")
-        result, intermediate_steps = await get_result_and_steps(
-            langchain_object, chat_message.message or "", websocket=websocket
-        )
-        logger.debug("Generated result and intermediate_steps")
-        return result, intermediate_steps
-    except Exception as e:
-        # Log stack trace
-        logger.exception(e)
-        raise e

src/backend/langflow/chat/utils.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+from fastapi import WebSocket
+from langflow.api.v1.schemas import ChatMessage
+from langflow.processing.process import (
+    load_or_build_langchain_object,
+)
+from langflow.processing.base import get_result_and_steps
+from langflow.interface.utils import try_setting_streaming_options
+from langflow.utils.logger import logger
+
+
+from typing import Dict
+
+
+async def process_graph(
+    graph_data: Dict,
+    is_first_message: bool,
+    chat_message: ChatMessage,
+    websocket: WebSocket,
+):
+    langchain_object = load_or_build_langchain_object(graph_data, is_first_message)
+    langchain_object = try_setting_streaming_options(langchain_object, websocket)
+    logger.debug("Loaded langchain object")
+
+    if langchain_object is None:
+        # Raise user facing error
+        raise ValueError(
+            "There was an error loading the langchain_object. Please, check all the nodes and try again."
+        )
+
+    # Generate result and thought
+    try:
+        logger.debug("Generating result and thought")
+        result, intermediate_steps = await get_result_and_steps(
+            langchain_object, chat_message.message or "", websocket=websocket
+        )
+        logger.debug("Generated result and intermediate_steps")
+        return result, intermediate_steps
+    except Exception as e:
+        # Log stack trace
+        logger.exception(e)
+        raise e
src/backend/langflow/graph/graph/base.py
@@ -24,6 +24,27 @@ class Graph:
         self._edges = edges
         self._build_graph()
 
+    @classmethod
+    def from_payload(cls, payload: Dict) -> "Graph":
+        """
+        Creates a graph from a payload.
+
+        Args:
+            payload (Dict): The payload to create the graph from.
+
+        Returns:
+            Graph: The created graph.
+        """
+        if "data" in payload:
+            payload = payload["data"]
+        try:
+            nodes = payload["nodes"]
+            edges = payload["edges"]
+            return cls(nodes, edges)
+        except KeyError as exc:
+            raise ValueError("Invalid payload") from exc
+
     def _build_graph(self) -> None:
        """Builds the graph from the nodes and edges."""
         self.nodes = self._build_vertices()
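A quick usage sketch for the new classmethod. The empty node/edge lists are placeholders (real payloads come from exported flows, and an empty graph is assumed to build without error):

```python
from langflow.graph import Graph

exported_flow = {"data": {"nodes": [], "edges": []}}  # full export shape
bare_payload = {"nodes": [], "edges": []}             # already-unwrapped shape

g1 = Graph.from_payload(exported_flow)  # unwraps the "data" key itself
g2 = Graph.from_payload(bare_payload)   # accepted as-is

try:
    Graph.from_payload({"wrong": "shape"})
except ValueError as err:
    print(err)  # -> Invalid payload
```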
src/backend/langflow/graph/graph/utils.py (new, empty file)
src/backend/langflow/graph/utils.py
@@ -1,6 +1,7 @@
 import re
 from typing import Any, Union
 
+from langflow.interface.utils import extract_input_variables_from_prompt
+
 
 def validate_prompt(prompt: str):
     """Validate prompt."""

@@ -15,11 +16,6 @@ def fix_prompt(prompt: str):
     return prompt + " {input}"
 
 
-def extract_input_variables_from_prompt(prompt: str) -> list[str]:
-    """Extract input variables from prompt."""
-    return re.findall(r"{(.*?)}", prompt)
-
-
 def flatten_list(list_of_lists: list[Union[list, Any]]) -> list:
     """Flatten list of lists."""
     new_list = []
src/backend/langflow/graph/vertex/types.py
@@ -1,7 +1,8 @@
 from typing import Any, Dict, List, Optional, Union
 
 from langflow.graph.vertex.base import Vertex
-from langflow.graph.utils import extract_input_variables_from_prompt, flatten_list
+from langflow.graph.utils import flatten_list
+from langflow.interface.utils import extract_input_variables_from_prompt
 
 
 class AgentVertex(Vertex):
@@ -5,7 +5,7 @@ from langchain.memory.buffer import ConversationBufferMemory
 from langchain.schema import BaseMemory
 from pydantic import Field, root_validator
 
-from langflow.graph.utils import extract_input_variables_from_prompt
+from langflow.interface.utils import extract_input_variables_from_prompt
 
 DEFAULT_SUFFIX = """
 Current conversation:
src/backend/langflow/interface/loading.py
@@ -12,7 +12,6 @@ from langchain.agents.load_tools import (
     _LLM_TOOLS,
 )
 from langchain.agents.loading import load_agent_from_config
-from langflow.graph import Graph
 from langchain.agents.tools import Tool
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager

@@ -22,7 +21,6 @@ from pydantic import ValidationError
 from langflow.interface.agents.custom import CUSTOM_AGENTS
 from langflow.interface.importing.utils import get_function, import_by_type
-from langflow.interface.run import fix_memory_inputs
 from langflow.interface.toolkits.base import toolkits_creator
 from langflow.interface.types import get_type_list
 from langflow.interface.utils import load_file_into_dict

@@ -163,37 +161,6 @@ def instantiate_utility(node_type, class_object, params):
     return class_object(**params)
 
 
-def load_flow_from_json(path: str, build=True):
-    """Load flow from json file"""
-    # This is done to avoid circular imports
-
-    with open(path, "r", encoding="utf-8") as f:
-        flow_graph = json.load(f)
-    data_graph = flow_graph["data"]
-    nodes = data_graph["nodes"]
-    # Substitute ZeroShotPrompt with PromptTemplate
-    # nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
-    # Add input variables
-    # nodes = payload.extract_input_variables(nodes)
-
-    # Nodes, edges and root node
-    edges = data_graph["edges"]
-    graph = Graph(nodes, edges)
-    if build:
-        langchain_object = graph.build()
-        if hasattr(langchain_object, "verbose"):
-            langchain_object.verbose = True
-
-        if hasattr(langchain_object, "return_intermediate_steps"):
-            # https://github.com/hwchase17/langchain/issues/2068
-            # Deactivating until we have a frontend solution
-            # to display intermediate steps
-            langchain_object.return_intermediate_steps = False
-        fix_memory_inputs(langchain_object)
-        return langchain_object
-    return graph
-
-
 def replace_zero_shot_prompt_with_prompt_template(nodes):
     """Replace ZeroShotPrompt with PromptTemplate"""
     for node in nodes:
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Type
 from langchain.prompts import PromptTemplate
 from pydantic import root_validator
 
-from langflow.graph.utils import extract_input_variables_from_prompt
+from langflow.interface.utils import extract_input_variables_from_prompt
 
 # Steps to create a BaseCustomPrompt:
 # 1. Create a prompt template that endes with:
src/backend/langflow/interface/run.py
@@ -1,10 +1,3 @@
-import contextlib
-import io
-from typing import Any, Dict, List, Tuple
-
-from langchain.schema import AgentAction
-
-from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler  # type: ignore
 from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
 from langflow.graph import Graph
 from langflow.utils.logger import logger

@@ -24,15 +17,6 @@ def load_langchain_object(data_graph, is_first_message=False):
     return computed_hash, langchain_object
 
 
-def load_or_build_langchain_object(data_graph, is_first_message=False):
-    """
-    Load langchain object from cache if it exists, otherwise build it.
-    """
-    if is_first_message:
-        build_langchain_object_with_caching.clear_cache()
-    return build_langchain_object_with_caching(data_graph)
-
-
 @memoize_dict(maxsize=10)
 def build_langchain_object_with_caching(data_graph):
     """
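The caching seam here is memoize_dict from langflow.cache.base plus the clear_cache() call above. A hedged sketch of how such a dict-keyed memoizer could look (not the actual implementation):

```python
import functools
import hashlib
import json

def memoize_dict(maxsize: int = 10):
    """Sketch of a dict-keyed memoizer in the spirit of
    langflow.cache.base.memoize_dict; details are assumptions."""
    def decorator(func):
        cache: dict = {}

        @functools.wraps(func)
        def wrapper(data_graph: dict):
            # Hash the payload to get a hashable, stable cache key.
            key = hashlib.sha256(
                json.dumps(data_graph, sort_keys=True).encode()
            ).hexdigest()
            if key not in cache:
                if len(cache) >= maxsize:
                    cache.pop(next(iter(cache)))  # evict oldest entry
                cache[key] = func(data_graph)
            return cache[key]

        wrapper.clear_cache = cache.clear  # cleared on the first chat message
        return wrapper
    return decorator
```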
@@ -40,16 +24,10 @@ def build_langchain_object_with_caching(data_graph):
     """
 
     logger.debug("Building langchain object")
-    graph = build_graph(data_graph)
+    graph = Graph.from_payload(data_graph)
     return graph.build()
 
 
-def build_graph(data_graph):
-    nodes = data_graph["nodes"]
-    edges = data_graph["edges"]
-    return Graph(nodes, edges)
-
-
 def build_langchain_object(data_graph):
     """
     Build langchain object from data_graph.

@@ -66,29 +44,6 @@ def build_langchain_object(data_graph):
     return graph.build()
 
 
-def process_graph_cached(data_graph: Dict[str, Any], message: str):
-    """
-    Process graph by extracting input variables and replacing ZeroShotPrompt
-    with PromptTemplate,then run the graph and return the result and thought.
-    """
-    # Load langchain object
-    is_first_message = len(data_graph.get("chatHistory", [])) == 0
-    langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
-    logger.debug("Loaded langchain object")
-
-    if langchain_object is None:
-        # Raise user facing error
-        raise ValueError(
-            "There was an error loading the langchain_object. Please, check all the nodes and try again."
-        )
-
-    # Generate result and thought
-    logger.debug("Generating result and thought")
-    result, thought = get_result_and_thought(langchain_object, message)
-    logger.debug("Generated result and thought")
-    return {"result": str(result), "thought": thought.strip()}
-
-
 def get_memory_key(langchain_object):
     """
     Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.

@@ -124,147 +79,3 @@ def update_memory_keys(langchain_object, possible_new_mem_key):
     langchain_object.memory.input_key = input_key
     langchain_object.memory.output_key = output_key
     langchain_object.memory.memory_key = possible_new_mem_key
-
-
-def fix_memory_inputs(langchain_object):
-    """
-    Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the
-    object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
-    get_memory_key function and updates the memory keys using the update_memory_keys function.
-    """
-    if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
-        try:
-            if langchain_object.memory.memory_key in langchain_object.input_variables:
-                return
-        except AttributeError:
-            input_variables = (
-                langchain_object.prompt.input_variables
-                if hasattr(langchain_object, "prompt")
-                else langchain_object.input_keys
-            )
-            if langchain_object.memory.memory_key in input_variables:
-                return
-
-        possible_new_mem_key = get_memory_key(langchain_object)
-        if possible_new_mem_key is not None:
-            update_memory_keys(langchain_object, possible_new_mem_key)
-
-
-async def get_result_and_steps(langchain_object, message: str, **kwargs):
-    """Get result and thought from extracted json"""
-
-    try:
-        if hasattr(langchain_object, "verbose"):
-            langchain_object.verbose = True
-        chat_input = None
-        memory_key = ""
-        if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
-            memory_key = langchain_object.memory.memory_key
-
-        if hasattr(langchain_object, "input_keys"):
-            for key in langchain_object.input_keys:
-                if key not in [memory_key, "chat_history"]:
-                    chat_input = {key: message}
-        else:
-            chat_input = message  # type: ignore
-
-        if hasattr(langchain_object, "return_intermediate_steps"):
-            # https://github.com/hwchase17/langchain/issues/2068
-            # Deactivating until we have a frontend solution
-            # to display intermediate steps
-            langchain_object.return_intermediate_steps = True
-
-        fix_memory_inputs(langchain_object)
-        try:
-            async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
-            output = await langchain_object.acall(chat_input, callbacks=async_callbacks)
-        except Exception as exc:
-            # make the error message more informative
-            logger.debug(f"Error: {str(exc)}")
-            sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
-            output = langchain_object(chat_input, callbacks=sync_callbacks)
-
-        intermediate_steps = (
-            output.get("intermediate_steps", []) if isinstance(output, dict) else []
-        )
-
-        result = (
-            output.get(langchain_object.output_keys[0])
-            if isinstance(output, dict)
-            else output
-        )
-        thought = format_actions(intermediate_steps) if intermediate_steps else ""
-    except Exception as exc:
-        raise ValueError(f"Error: {str(exc)}") from exc
-    return result, thought
-
-
-def get_result_and_thought(langchain_object, message: str):
-    """Get result and thought from extracted json"""
-    try:
-        if hasattr(langchain_object, "verbose"):
-            langchain_object.verbose = True
-        chat_input = None
-        memory_key = ""
-        if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
-            memory_key = langchain_object.memory.memory_key
-
-        if hasattr(langchain_object, "input_keys"):
-            for key in langchain_object.input_keys:
-                if key not in [memory_key, "chat_history"]:
-                    chat_input = {key: message}
-        else:
-            chat_input = message  # type: ignore
-
-        if hasattr(langchain_object, "return_intermediate_steps"):
-            # https://github.com/hwchase17/langchain/issues/2068
-            # Deactivating until we have a frontend solution
-            # to display intermediate steps
-            langchain_object.return_intermediate_steps = False
-
-        fix_memory_inputs(langchain_object)
-
-        with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
-            try:
-                # if hasattr(langchain_object, "acall"):
-                #     output = await langchain_object.acall(chat_input)
-                # else:
-                output = langchain_object(chat_input)
-            except ValueError as exc:
-                # make the error message more informative
-                logger.debug(f"Error: {str(exc)}")
-                output = langchain_object.run(chat_input)
-
-        intermediate_steps = (
-            output.get("intermediate_steps", []) if isinstance(output, dict) else []
-        )
-
-        result = (
-            output.get(langchain_object.output_keys[0])
-            if isinstance(output, dict)
-            else output
-        )
-        if intermediate_steps:
-            thought = format_actions(intermediate_steps)
-        else:
-            thought = output_buffer.getvalue()
-
-    except Exception as exc:
-        raise ValueError(f"Error: {str(exc)}") from exc
-    return result, thought
-
-
-def format_actions(actions: List[Tuple[AgentAction, str]]) -> str:
-    """Format a list of (AgentAction, answer) tuples into a string."""
-    output = []
-    for action, answer in actions:
-        log = action.log
-        tool = action.tool
-        tool_input = action.tool_input
-        output.append(f"Log: {log}")
-        if "Action" not in log and "Action Input" not in log:
-            output.append(f"Tool: {tool}")
-            output.append(f"Tool Input: {tool_input}")
-        output.append(f"Answer: {answer}")
-        output.append("")  # Add a blank line
-    return "\n".join(output)
src/backend/langflow/interface/utils.py
@@ -2,6 +2,7 @@ import base64
 import json
 import os
 from io import BytesIO
+import re
 
 import yaml
 from langchain.base_language import BaseLanguageModel

@@ -48,3 +49,8 @@ def try_setting_streaming_options(langchain_object, websocket):
             llm.streaming = True
 
     return langchain_object
+
+
+def extract_input_variables_from_prompt(prompt: str) -> list[str]:
+    """Extract input variables from prompt."""
+    return re.findall(r"{(.*?)}", prompt)
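The relocated helper is a one-line regex scan. For example, with a sample prompt containing one {product} variable (matching the ["product"] expectation in the tests further down):

```python
import re

def extract_input_variables_from_prompt(prompt: str) -> list[str]:
    """Extract input variables from prompt."""
    return re.findall(r"{(.*?)}", prompt)

print(extract_input_variables_from_prompt(
    "What is a good name for a company that makes {product}?"
))  # -> ['product']
```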
src/backend/langflow/main.py
@@ -1,9 +1,7 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 
-from langflow.api.chat import router as chat_router
-from langflow.api.endpoints import router as endpoints_router
-from langflow.api.validate import router as validate_router
+from langflow.api import router
 
 
 def create_app():

@@ -14,6 +12,10 @@ def create_app():
         "*",
     ]
 
+    @app.get("/health")
+    def get_health():
+        return {"status": "OK"}
+
     app.add_middleware(
         CORSMiddleware,
         allow_origins=origins,

@@ -22,9 +24,7 @@ def create_app():
         allow_headers=["*"],
     )
 
-    app.include_router(endpoints_router)
-    app.include_router(validate_router)
-    app.include_router(chat_router)
+    app.include_router(router)
 
     return app
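The net effect: /health now lives on the app itself, while everything else sits under /api/v1. A hedged smoke test, assuming create_app is importable from langflow.main as shown in this diff:

```python
from fastapi.testclient import TestClient

from langflow.main import create_app  # import path assumed from this diff

client = TestClient(create_app())
assert client.get("/health").json() == {"status": "OK"}
assert client.get("/api/v1/all").status_code == 200
```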
src/backend/langflow/processing/__init__.py (new, empty file)

src/backend/langflow/processing/base.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from langflow.api.v1.callback import (
+    AsyncStreamingLLMCallbackHandler,
+    StreamingLLMCallbackHandler,
+)
+from langflow.processing.process import fix_memory_inputs, format_actions
+from langflow.utils.logger import logger
+
+
+async def get_result_and_steps(langchain_object, message: str, **kwargs):
+    """Get result and thought from extracted json"""
+
+    try:
+        if hasattr(langchain_object, "verbose"):
+            langchain_object.verbose = True
+        chat_input = None
+        memory_key = ""
+        if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
+            memory_key = langchain_object.memory.memory_key
+
+        if hasattr(langchain_object, "input_keys"):
+            for key in langchain_object.input_keys:
+                if key not in [memory_key, "chat_history"]:
+                    chat_input = {key: message}
+        else:
+            chat_input = message  # type: ignore
+
+        if hasattr(langchain_object, "return_intermediate_steps"):
+            # https://github.com/hwchase17/langchain/issues/2068
+            # Deactivating until we have a frontend solution
+            # to display intermediate steps
+            langchain_object.return_intermediate_steps = True
+
+        fix_memory_inputs(langchain_object)
+        try:
+            async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
+            output = await langchain_object.acall(chat_input, callbacks=async_callbacks)
+        except Exception as exc:
+            # make the error message more informative
+            logger.debug(f"Error: {str(exc)}")
+            sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
+            output = langchain_object(chat_input, callbacks=sync_callbacks)
+
+        intermediate_steps = (
+            output.get("intermediate_steps", []) if isinstance(output, dict) else []
+        )
+
+        result = (
+            output.get(langchain_object.output_keys[0])
+            if isinstance(output, dict)
+            else output
+        )
+        thought = format_actions(intermediate_steps) if intermediate_steps else ""
+    except Exception as exc:
+        raise ValueError(f"Error: {str(exc)}") from exc
+    return result, thought
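Note the call strategy above: try the async acall with streaming callbacks first, and fall back to a synchronous call if it raises. A stripped-down, runnable sketch of that shape, where run_async and run_sync are hypothetical stand-ins for the two call paths:

```python
import asyncio

async def run_async(chat_input):
    # Hypothetical streaming path; fails here to demonstrate the fallback.
    raise RuntimeError("async streaming unavailable")

def run_sync(chat_input):
    # Hypothetical non-streaming fallback path.
    return {"output": f"echo: {chat_input}"}

async def run_with_fallback(chat_input):
    try:
        return await run_async(chat_input)
    except Exception as exc:
        print(f"Error: {exc}")  # make the error message more informative
        return run_sync(chat_input)

print(asyncio.run(run_with_fallback("hi")))
```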
src/backend/langflow/processing/process.py (new file, 172 lines)
@@ -0,0 +1,172 @@
+import contextlib
+import io
+from langchain.schema import AgentAction
+import json
+from langflow.interface.run import (
+    build_langchain_object_with_caching,
+    get_memory_key,
+    update_memory_keys,
+)
+from langflow.utils.logger import logger
+from langflow.graph import Graph
+
+
+from typing import Any, Dict, List, Tuple
+
+
+def fix_memory_inputs(langchain_object):
+    """
+    Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the
+    object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
+    get_memory_key function and updates the memory keys using the update_memory_keys function.
+    """
+    if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
+        try:
+            if langchain_object.memory.memory_key in langchain_object.input_variables:
+                return
+        except AttributeError:
+            input_variables = (
+                langchain_object.prompt.input_variables
+                if hasattr(langchain_object, "prompt")
+                else langchain_object.input_keys
+            )
+            if langchain_object.memory.memory_key in input_variables:
+                return
+
+        possible_new_mem_key = get_memory_key(langchain_object)
+        if possible_new_mem_key is not None:
+            update_memory_keys(langchain_object, possible_new_mem_key)
+
+
+def format_actions(actions: List[Tuple[AgentAction, str]]) -> str:
+    """Format a list of (AgentAction, answer) tuples into a string."""
+    output = []
+    for action, answer in actions:
+        log = action.log
+        tool = action.tool
+        tool_input = action.tool_input
+        output.append(f"Log: {log}")
+        if "Action" not in log and "Action Input" not in log:
+            output.append(f"Tool: {tool}")
+            output.append(f"Tool Input: {tool_input}")
+        output.append(f"Answer: {answer}")
+        output.append("")  # Add a blank line
+    return "\n".join(output)
+
+
+def get_result_and_thought(langchain_object, message: str):
+    """Get result and thought from extracted json"""
+    try:
+        if hasattr(langchain_object, "verbose"):
+            langchain_object.verbose = True
+        chat_input = None
+        memory_key = ""
+        if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
+            memory_key = langchain_object.memory.memory_key
+
+        if hasattr(langchain_object, "input_keys"):
+            for key in langchain_object.input_keys:
+                if key not in [memory_key, "chat_history"]:
+                    chat_input = {key: message}
+        else:
+            chat_input = message  # type: ignore
+
+        if hasattr(langchain_object, "return_intermediate_steps"):
+            # https://github.com/hwchase17/langchain/issues/2068
+            # Deactivating until we have a frontend solution
+            # to display intermediate steps
+            langchain_object.return_intermediate_steps = False
+
+        fix_memory_inputs(langchain_object)
+
+        with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
+            try:
+                # if hasattr(langchain_object, "acall"):
+                #     output = await langchain_object.acall(chat_input)
+                # else:
+                output = langchain_object(chat_input)
+            except ValueError as exc:
+                # make the error message more informative
+                logger.debug(f"Error: {str(exc)}")
+                output = langchain_object.run(chat_input)
+
+        intermediate_steps = (
+            output.get("intermediate_steps", []) if isinstance(output, dict) else []
+        )
+
+        result = (
+            output.get(langchain_object.output_keys[0])
+            if isinstance(output, dict)
+            else output
+        )
+        if intermediate_steps:
+            thought = format_actions(intermediate_steps)
+        else:
+            thought = output_buffer.getvalue()
+
+    except Exception as exc:
+        raise ValueError(f"Error: {str(exc)}") from exc
+    return result, thought
+
+
+def load_or_build_langchain_object(data_graph, is_first_message=False):
+    """
+    Load langchain object from cache if it exists, otherwise build it.
+    """
+    if is_first_message:
+        build_langchain_object_with_caching.clear_cache()
+    return build_langchain_object_with_caching(data_graph)
+
+
+def process_graph_cached(data_graph: Dict[str, Any], message: str):
+    """
+    Process graph by extracting input variables and replacing ZeroShotPrompt
+    with PromptTemplate,then run the graph and return the result and thought.
+    """
+    # Load langchain object
+    is_first_message = len(data_graph.get("chatHistory", [])) == 0
+    langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
+    logger.debug("Loaded langchain object")
+
+    if langchain_object is None:
+        # Raise user facing error
+        raise ValueError(
+            "There was an error loading the langchain_object. Please, check all the nodes and try again."
+        )
+
+    # Generate result and thought
+    logger.debug("Generating result and thought")
+    result, thought = get_result_and_thought(langchain_object, message)
+    logger.debug("Generated result and thought")
+    return {"result": str(result), "thought": thought.strip()}
+
+
+def load_flow_from_json(path: str, build=True):
+    """Load flow from json file"""
+    # This is done to avoid circular imports
+
+    with open(path, "r", encoding="utf-8") as f:
+        flow_graph = json.load(f)
+    data_graph = flow_graph["data"]
+    nodes = data_graph["nodes"]
+    # Substitute ZeroShotPrompt with PromptTemplate
+    # nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
+    # Add input variables
+    # nodes = payload.extract_input_variables(nodes)
+
+    # Nodes, edges and root node
+    edges = data_graph["edges"]
+    graph = Graph(nodes, edges)
+    if build:
+        langchain_object = graph.build()
+        if hasattr(langchain_object, "verbose"):
+            langchain_object.verbose = True
+
+        if hasattr(langchain_object, "return_intermediate_steps"):
+            # https://github.com/hwchase17/langchain/issues/2068
+            # Deactivating until we have a frontend solution
+            # to display intermediate steps
+            langchain_object.return_intermediate_steps = False
+        fix_memory_inputs(langchain_object)
+        return langchain_object
+    return graph
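A usage sketch for the relocated public helpers; "flow.json" is a placeholder path to an exported flow, and data_graph is a placeholder dict:

```python
from langflow.processing.process import load_flow_from_json, process_graph_cached

# Build the langchain object directly...
chain = load_flow_from_json("flow.json")

# ...or stop at the Graph and build later.
graph = load_flow_from_json("flow.json", build=False)

# process_graph_cached takes the exported flow's "data" dict plus a message
# and returns {"result": ..., "thought": ...}:
# result = process_graph_cached(data_graph, "Hello")
```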
@@ -14,13 +14,13 @@ export async function sendAll(data: sendAllProps) {
 export async function checkCode(
   code: string
 ): Promise<AxiosResponse<errorsTypeAPI>> {
-  return await axios.post("/validate/code", { code });
+  return await axios.post("api/v1/validate/code", { code });
 }
 
 export async function checkPrompt(
   template: string
 ): Promise<AxiosResponse<PromptTypeAPI>> {
-  return await axios.post("/validate/prompt", { template });
+  return await axios.post("api/v1/validate/prompt", { template });
 }
 
 export async function getExamples(): Promise<FlowType[]> {
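The same re-prefixed endpoints from the backend side, as a hedged requests sketch against a local dev server (port 7860, as elsewhere in this diff); expected responses mirror the tests below:

```python
import requests

BASE = "http://127.0.0.1:7860/api/v1"

resp = requests.post(f"{BASE}/validate/prompt", json={"template": "Tell me about {topic}"})
print(resp.json())  # e.g. {"input_variables": ["topic"]}

resp = requests.post(f"{BASE}/validate/code", json={"code": "def square(x):\n    return x ** 2"})
print(resp.json())  # e.g. {"imports": {"errors": []}, "function": {"errors": []}}
```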
@@ -182,10 +182,10 @@ export default function ChatModal({
     try {
       const urlWs =
         process.env.NODE_ENV === "development"
-          ? `ws://localhost:7860/chat/${id.current}`
+          ? `ws://localhost:7860/api/v1/chat/${id.current}`
           : `${window.location.protocol === "https:" ? "wss" : "ws"}://${
               window.location.host
-            }/chat/${id.current}`;
+            }api/v1/chat/${id.current}`;
       const newWs = new WebSocket(urlWs);
       newWs.onopen = () => {
         console.log("WebSocket connection established!");
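And the matching websocket path from a Python client, assuming the third-party websockets package is available; the "test_client" id mirrors the tests at the end of this diff, which show the server sending chat history first:

```python
import asyncio

import websockets  # third-party package, assumed installed

async def connect(client_id: str):
    url = f"ws://localhost:7860/api/v1/chat/{client_id}"
    async with websockets.connect(url) as ws:
        history = await ws.recv()  # server sends chat history first
        print(history)

asyncio.run(connect("test_client"))
```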
@@ -11,7 +11,7 @@ const apiRoutes = [
 ];
 
 // Use environment variable to determine the target.
-const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860";
+const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860/api/v1";
 
 const proxyTargets = apiRoutes.reduce((proxyObj, route) => {
   proxyObj[route] = {
@@ -5,7 +5,7 @@ from langflow.settings import settings
 # check that all agents are in settings.agents
 # are in json_response["agents"]
 def test_agents_settings(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     agents = json_response["agents"]

@@ -13,7 +13,7 @@ def test_agents_settings(client: TestClient):
 
 
 def test_zero_shot_agent(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     agents = json_response["agents"]

@@ -52,7 +52,7 @@ def test_zero_shot_agent(client: TestClient):
 
 
 def test_json_agent(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     agents = json_response["agents"]

@@ -87,7 +87,7 @@ def test_json_agent(client: TestClient):
 
 
 def test_csv_agent(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     agents = json_response["agents"]

@@ -126,7 +126,7 @@ def test_csv_agent(client: TestClient):
 
 
 def test_initialize_agent(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     agents = json_response["agents"]
@@ -1,10 +1,10 @@
 import json
+from langflow.graph import Graph
+from langflow.processing.process import load_or_build_langchain_object
 
 import pytest
 from langflow.interface.run import (
-    build_graph,
     build_langchain_object_with_caching,
-    load_or_build_langchain_object,
 )
 
 

@@ -62,7 +62,7 @@ def test_build_langchain_object_with_caching(basic_data_graph):
 
 # Test build_graph
 def test_build_graph(basic_data_graph):
-    graph = build_graph(basic_data_graph)
+    graph = Graph.from_payload(basic_data_graph)
     assert graph is not None
     assert len(graph.nodes) == len(basic_data_graph["nodes"])
     assert len(graph.edges) == len(basic_data_graph["edges"])
@@ -3,7 +3,7 @@ from langflow.settings import settings
 
 
 def test_chains_settings(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]

@@ -12,7 +12,7 @@ def test_chains_settings(client: TestClient):
 
 # Test the ConversationChain object
 def test_conversation_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]

@@ -94,7 +94,7 @@ def test_conversation_chain(client: TestClient):
 
 
 def test_llm_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]

@@ -152,7 +152,7 @@ def test_llm_chain(client: TestClient):
 
 
 def test_llm_checker_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]

@@ -228,7 +228,7 @@ def test_llm_checker_chain(client: TestClient):
 
 
 def test_llm_math_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]

@@ -306,7 +306,7 @@ def test_llm_math_chain(client: TestClient):
 
 
 def test_series_character_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]

@@ -368,7 +368,7 @@ def test_series_character_chain(client: TestClient):
 
 
 def test_mid_journey_prompt_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]

@@ -407,7 +407,7 @@ def test_mid_journey_prompt_chain(client: TestClient):
 
 
 def test_time_travel_guide_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]
@@ -4,7 +4,7 @@ from langflow.interface.tools.constants import CUSTOM_TOOLS
 
 
 def test_get_all(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     # We need to test the custom nodes

@@ -21,7 +21,7 @@ import math
 def square(x):
     return x ** 2
 """
-    response1 = client.post("/validate/code", json={"code": code1})
+    response1 = client.post("api/v1/validate/code", json={"code": code1})
     assert response1.status_code == 200
     assert response1.json() == {"imports": {"errors": []}, "function": {"errors": []}}
 

@@ -32,7 +32,7 @@ import non_existent_module
 def square(x):
     return x ** 2
 """
-    response2 = client.post("/validate/code", json={"code": code2})
+    response2 = client.post("api/v1/validate/code", json={"code": code2})
     assert response2.status_code == 200
     assert response2.json() == {
         "imports": {"errors": ["No module named 'non_existent_module'"]},

@@ -46,7 +46,7 @@ import math
 def square(x)
     return x ** 2
 """
-    response3 = client.post("/validate/code", json={"code": code3})
+    response3 = client.post("api/v1/validate/code", json={"code": code3})
     assert response3.status_code == 200
     assert response3.json() == {
         "imports": {"errors": []},

@@ -54,11 +54,11 @@ def square(x)
     }
 
     # Test case with invalid JSON payload
-    response4 = client.post("/validate/code", json={"invalid_key": code1})
+    response4 = client.post("api/v1/validate/code", json={"invalid_key": code1})
     assert response4.status_code == 422
 
     # Test case with an empty code string
-    response5 = client.post("/validate/code", json={"code": ""})
+    response5 = client.post("api/v1/validate/code", json={"code": ""})
     assert response5.status_code == 200
     assert response5.json() == {"imports": {"errors": []}, "function": {"errors": []}}
 

@@ -69,7 +69,7 @@ import math
 def square(x)
     return x ** 2
 """
-    response6 = client.post("/validate/code", json={"code": code6})
+    response6 = client.post("api/v1/validate/code", json={"code": code6})
     assert response6.status_code == 200
     assert response6.json() == {
         "imports": {"errors": []},

@@ -95,13 +95,13 @@ INVALID_PROMPT = "This is an invalid prompt without any input variable."
 
 
 def test_valid_prompt(client: TestClient):
-    response = client.post("/validate/prompt", json={"template": VALID_PROMPT})
+    response = client.post("api/v1/validate/prompt", json={"template": VALID_PROMPT})
     assert response.status_code == 200
     assert response.json() == {"input_variables": ["product"]}
 
 
 def test_invalid_prompt(client: TestClient):
-    response = client.post("/validate/prompt", json={"template": INVALID_PROMPT})
+    response = client.post("api/v1/validate/prompt", json={"template": INVALID_PROMPT})
     assert response.status_code == 200
     assert response.json() == {"input_variables": []}
 

@@ -116,7 +116,7 @@ def test_invalid_prompt(client: TestClient):
     ],
 )
 def test_various_prompts(client, prompt, expected_input_variables):
-    response = client.post("/validate/prompt", json={"template": prompt})
+    response = client.post("api/v1/validate/prompt", json={"template": prompt})
     assert response.status_code == 200
     assert response.json() == {
         "input_variables": expected_input_variables,
@@ -16,7 +16,7 @@ from langflow.graph.vertex.types import (
     ToolVertex,
     WrapperVertex,
 )
-from langflow.interface.run import get_result_and_thought
+from langflow.processing.process import get_result_and_thought
 from langflow.utils.payload import get_root_node
 
 # Test cases for the graph module
@@ -3,7 +3,7 @@ from langflow.settings import settings
 
 
 def test_llms_settings(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     llms = json_response["llms"]

@@ -11,7 +11,7 @@ def test_llms_settings(client: TestClient):
 
 
 # def test_hugging_face_hub(client: TestClient):
-#     response = client.get("/all")
+#     response = client.get("api/v1/all")
 #     assert response.status_code == 200
 #     json_response = response.json()
 #     language_models = json_response["llms"]

@@ -103,7 +103,7 @@ def test_llms_settings(client: TestClient):
 
 
 def test_openai(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     language_models = json_response["llms"]

@@ -333,7 +333,7 @@ def test_openai(client: TestClient):
 
 
 def test_chat_open_ai(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     language_models = json_response["llms"]
@@ -2,7 +2,7 @@ import json
 
 import pytest
 from langchain.chains.base import Chain
-from langflow import load_flow_from_json
+from langflow.processing.process import load_flow_from_json
 from langflow.graph import Graph
 from langflow.utils.payload import get_root_node
 
@@ -3,7 +3,7 @@ from langflow.settings import settings
 
 
 def test_prompts_settings(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     prompts = json_response["prompts"]

@@ -11,7 +11,7 @@ def test_prompts_settings(client: TestClient):
 
 
 def test_prompt_template(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     prompts = json_response["prompts"]

@@ -89,7 +89,7 @@ def test_prompt_template(client: TestClient):
 
 
 def test_few_shot_prompt_template(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     prompts = json_response["prompts"]

@@ -168,7 +168,7 @@ def test_few_shot_prompt_template(client: TestClient):
 
 
 def test_zero_shot_prompt(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     prompts = json_response["prompts"]
@@ -5,7 +5,7 @@ from langflow.settings import settings
 # check that all agents are in settings.agents
 # are in json_response["agents"]
 def test_vectorstores_settings(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     vectorstores = json_response["vectorstores"]
@@ -5,17 +5,17 @@ from fastapi.testclient import TestClient
 
 
 def test_websocket_connection(client: TestClient):
-    with client.websocket_connect("/chat/test_client") as websocket:
+    with client.websocket_connect("api/v1/chat/test_client") as websocket:
         assert websocket.scope["client"] == ["testclient", 50000]
-        assert websocket.scope["path"] == "/chat/test_client"
+        assert websocket.scope["path"] == "/api/v1/chat/test_client"
 
 
 def test_chat_history(client: TestClient):
     # Mock the process_graph function to return a specific value
-    with patch("langflow.api.chat_manager.process_graph") as mock_process_graph:
+    with patch("langflow.chat.manager.process_graph") as mock_process_graph:
         mock_process_graph.return_value = ("Hello, I'm a mock response!", "")
 
-        with client.websocket_connect("/chat/test_client") as websocket:
+        with client.websocket_connect("api/v1/chat/test_client") as websocket:
             # First message should be the history
             history = websocket.receive_json()
             assert history == []  # Empty history