diff --git a/pyproject.toml b/pyproject.toml index cde7ddca5..87843198b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,15 @@ types-pillow = "^9.5.0.2" [tool.poetry.extras] deploy = ["langchain-serve"] +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-ra" +testpaths = ["tests", "integration"] +console_output_style = "progress" +filterwarnings = ["ignore::DeprecationWarning"] +log_cli = true + + [tool.ruff] line-length = 120 diff --git a/src/backend/langflow/__init__.py b/src/backend/langflow/__init__.py index 35fe814d2..17b1d940c 100644 --- a/src/backend/langflow/__init__.py +++ b/src/backend/langflow/__init__.py @@ -1,4 +1,4 @@ from langflow.cache import cache_manager -from langflow.interface.loading import load_flow_from_json +from langflow.processing.process import load_flow_from_json __all__ = ["load_flow_from_json", "cache_manager"] diff --git a/src/backend/langflow/api/__init__.py b/src/backend/langflow/api/__init__.py index e69de29bb..f887c47e1 100644 --- a/src/backend/langflow/api/__init__.py +++ b/src/backend/langflow/api/__init__.py @@ -0,0 +1,3 @@ +from langflow.api.router import router + +__all__ = ["router"] diff --git a/src/backend/langflow/api/router.py b/src/backend/langflow/api/router.py new file mode 100644 index 000000000..23b5aa1c5 --- /dev/null +++ b/src/backend/langflow/api/router.py @@ -0,0 +1,8 @@ +# Router for base api +from fastapi import APIRouter +from langflow.api.v1 import chat_router, endpoints_router, validate_router + +router = APIRouter(prefix="/api/v1", tags=["api"]) +router.include_router(chat_router) +router.include_router(endpoints_router) +router.include_router(validate_router) diff --git a/src/backend/langflow/api/v1/__init__.py b/src/backend/langflow/api/v1/__init__.py new file mode 100644 index 000000000..d835b4535 --- /dev/null +++ b/src/backend/langflow/api/v1/__init__.py @@ -0,0 +1,5 @@ +from langflow.api.v1.endpoints import router as endpoints_router +from langflow.api.v1.validate import router as validate_router +from langflow.api.v1.chat import router as chat_router + +__all__ = ["chat_router", "endpoints_router", "validate_router"] diff --git a/src/backend/langflow/api/base.py b/src/backend/langflow/api/v1/base.py similarity index 96% rename from src/backend/langflow/api/base.py rename to src/backend/langflow/api/v1/base.py index 8cddc52e4..6941bedf3 100644 --- a/src/backend/langflow/api/base.py +++ b/src/backend/langflow/api/v1/base.py @@ -1,6 +1,6 @@ from pydantic import BaseModel, validator -from langflow.graph.utils import extract_input_variables_from_prompt +from langflow.interface.utils import extract_input_variables_from_prompt class CacheResponse(BaseModel): diff --git a/src/backend/langflow/api/callback.py b/src/backend/langflow/api/v1/callback.py similarity index 95% rename from src/backend/langflow/api/callback.py rename to src/backend/langflow/api/v1/callback.py index d63e107c4..b58393d7b 100644 --- a/src/backend/langflow/api/callback.py +++ b/src/backend/langflow/api/v1/callback.py @@ -3,7 +3,7 @@ from typing import Any from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler -from langflow.api.schemas import ChatResponse +from langflow.api.v1.schemas import ChatResponse # https://github.com/hwchase17/chat-langchain/blob/master/callback.py diff --git a/src/backend/langflow/api/chat.py b/src/backend/langflow/api/v1/chat.py similarity index 93% rename from src/backend/langflow/api/chat.py rename to src/backend/langflow/api/v1/chat.py index 4afa6c22f..7df4c65ed 100644 --- a/src/backend/langflow/api/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -6,7 +6,7 @@ from fastapi import ( status, ) -from langflow.api.chat_manager import ChatManager +from langflow.chat.manager import ChatManager from langflow.utils.logger import logger router = APIRouter() diff --git a/src/backend/langflow/api/endpoints.py b/src/backend/langflow/api/v1/endpoints.py similarity index 87% rename from src/backend/langflow/api/endpoints.py rename to src/backend/langflow/api/v1/endpoints.py index 021a81ca8..1e9b0deb1 100644 --- a/src/backend/langflow/api/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -3,13 +3,13 @@ from importlib.metadata import version from fastapi import APIRouter, HTTPException -from langflow.api.schemas import ( +from langflow.api.v1.schemas import ( ExportedFlow, GraphData, PredictRequest, PredictResponse, ) -from langflow.interface.run import process_graph_cached + from langflow.interface.types import build_langchain_types_dict # build router @@ -25,6 +25,8 @@ def get_all(): @router.post("/predict", response_model=PredictResponse) async def get_load(predict_request: PredictRequest): try: + from langflow.processing.process import process_graph_cached + exported_flow: ExportedFlow = predict_request.exported_flow graph_data: GraphData = exported_flow.data data = graph_data.dict() @@ -40,8 +42,3 @@ async def get_load(predict_request: PredictRequest): @router.get("/version") def get_version(): return {"version": version("langflow")} - - -@router.get("/health") -def get_health(): - return {"status": "OK"} diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/v1/schemas.py similarity index 100% rename from src/backend/langflow/api/schemas.py rename to src/backend/langflow/api/v1/schemas.py diff --git a/src/backend/langflow/api/validate.py b/src/backend/langflow/api/v1/validate.py similarity index 93% rename from src/backend/langflow/api/validate.py rename to src/backend/langflow/api/v1/validate.py index e90e554f0..009cb9a30 100644 --- a/src/backend/langflow/api/validate.py +++ b/src/backend/langflow/api/v1/validate.py @@ -2,7 +2,7 @@ import json from fastapi import APIRouter, HTTPException -from langflow.api.base import ( +from langflow.api.v1.base import ( Code, CodeValidationResponse, Prompt, @@ -10,7 +10,7 @@ from langflow.api.base import ( validate_prompt, ) from langflow.graph.vertex.types import VectorStoreVertex -from langflow.interface.run import build_graph +from langflow.graph import Graph from langflow.utils.logger import logger from langflow.utils.validate import validate_code @@ -44,7 +44,7 @@ def post_validate_prompt(prompt: Prompt): def post_validate_node(node_id: str, data: dict): try: # build graph - graph = build_graph(data) + graph = Graph.from_payload(data) # validate node node = graph.get_node(node_id) if node is None: diff --git a/src/backend/langflow/chat/__init__.py b/src/backend/langflow/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/chat/manager.py similarity index 85% rename from src/backend/langflow/api/chat_manager.py rename to src/backend/langflow/chat/manager.py index 8b1c7a621..d24057b68 100644 --- a/src/backend/langflow/api/chat_manager.py +++ b/src/backend/langflow/chat/manager.py @@ -1,21 +1,18 @@ -import asyncio -import json from collections import defaultdict -from typing import Dict, List - from fastapi import WebSocket, status - -from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse +from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse from langflow.cache import cache_manager from langflow.cache.manager import Subject -from langflow.interface.run import ( - get_result_and_steps, - load_or_build_langchain_object, -) -from langflow.interface.utils import pil_to_base64, try_setting_streaming_options +from langflow.chat.utils import process_graph +from langflow.interface.utils import pil_to_base64 from langflow.utils.logger import logger +import asyncio +import json +from typing import Dict, List + + class ChatHistory(Subject): def __init__(self): super().__init__() @@ -191,33 +188,3 @@ class ChatManager: except Exception as e: logger.exception(e) self.disconnect(client_id) - - -async def process_graph( - graph_data: Dict, - is_first_message: bool, - chat_message: ChatMessage, - websocket: WebSocket, -): - langchain_object = load_or_build_langchain_object(graph_data, is_first_message) - langchain_object = try_setting_streaming_options(langchain_object, websocket) - logger.debug("Loaded langchain object") - - if langchain_object is None: - # Raise user facing error - raise ValueError( - "There was an error loading the langchain_object. Please, check all the nodes and try again." - ) - - # Generate result and thought - try: - logger.debug("Generating result and thought") - result, intermediate_steps = await get_result_and_steps( - langchain_object, chat_message.message or "", websocket=websocket - ) - logger.debug("Generated result and intermediate_steps") - return result, intermediate_steps - except Exception as e: - # Log stack trace - logger.exception(e) - raise e diff --git a/src/backend/langflow/chat/utils.py b/src/backend/langflow/chat/utils.py new file mode 100644 index 000000000..410a442be --- /dev/null +++ b/src/backend/langflow/chat/utils.py @@ -0,0 +1,41 @@ +from fastapi import WebSocket +from langflow.api.v1.schemas import ChatMessage +from langflow.processing.process import ( + load_or_build_langchain_object, +) +from langflow.processing.base import get_result_and_steps +from langflow.interface.utils import try_setting_streaming_options +from langflow.utils.logger import logger + + +from typing import Dict + + +async def process_graph( + graph_data: Dict, + is_first_message: bool, + chat_message: ChatMessage, + websocket: WebSocket, +): + langchain_object = load_or_build_langchain_object(graph_data, is_first_message) + langchain_object = try_setting_streaming_options(langchain_object, websocket) + logger.debug("Loaded langchain object") + + if langchain_object is None: + # Raise user facing error + raise ValueError( + "There was an error loading the langchain_object. Please, check all the nodes and try again." + ) + + # Generate result and thought + try: + logger.debug("Generating result and thought") + result, intermediate_steps = await get_result_and_steps( + langchain_object, chat_message.message or "", websocket=websocket + ) + logger.debug("Generated result and intermediate_steps") + return result, intermediate_steps + except Exception as e: + # Log stack trace + logger.exception(e) + raise e diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py index 020f539ec..5fd00d09b 100644 --- a/src/backend/langflow/graph/graph/base.py +++ b/src/backend/langflow/graph/graph/base.py @@ -24,6 +24,27 @@ class Graph: self._edges = edges self._build_graph() + @classmethod + @classmethod + def from_payload(cls, payload: Dict) -> "Graph": + """ + Creates a graph from a payload. + + Args: + payload (Dict): The payload to create the graph from. + + Returns: + Graph: The created graph. + """ + if "data" in payload: + payload = payload["data"] + try: + nodes = payload["nodes"] + edges = payload["edges"] + return cls(nodes, edges) + except KeyError as exc: + raise ValueError("Invalid payload") from exc + def _build_graph(self) -> None: """Builds the graph from the nodes and edges.""" self.nodes = self._build_vertices() diff --git a/src/backend/langflow/graph/graph/utils.py b/src/backend/langflow/graph/graph/utils.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/langflow/graph/utils.py b/src/backend/langflow/graph/utils.py index e22b27cf5..b78b2f961 100644 --- a/src/backend/langflow/graph/utils.py +++ b/src/backend/langflow/graph/utils.py @@ -1,6 +1,7 @@ -import re from typing import Any, Union +from langflow.interface.utils import extract_input_variables_from_prompt + def validate_prompt(prompt: str): """Validate prompt.""" @@ -15,11 +16,6 @@ def fix_prompt(prompt: str): return prompt + " {input}" -def extract_input_variables_from_prompt(prompt: str) -> list[str]: - """Extract input variables from prompt.""" - return re.findall(r"{(.*?)}", prompt) - - def flatten_list(list_of_lists: list[Union[list, Any]]) -> list: """Flatten list of lists.""" new_list = [] diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index b81e72439..4eb20f416 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -1,7 +1,8 @@ from typing import Any, Dict, List, Optional, Union from langflow.graph.vertex.base import Vertex -from langflow.graph.utils import extract_input_variables_from_prompt, flatten_list +from langflow.graph.utils import flatten_list +from langflow.interface.utils import extract_input_variables_from_prompt class AgentVertex(Vertex): diff --git a/src/backend/langflow/interface/chains/custom.py b/src/backend/langflow/interface/chains/custom.py index cb76a53c8..ba4ba8b62 100644 --- a/src/backend/langflow/interface/chains/custom.py +++ b/src/backend/langflow/interface/chains/custom.py @@ -5,7 +5,7 @@ from langchain.memory.buffer import ConversationBufferMemory from langchain.schema import BaseMemory from pydantic import Field, root_validator -from langflow.graph.utils import extract_input_variables_from_prompt +from langflow.interface.utils import extract_input_variables_from_prompt DEFAULT_SUFFIX = """" Current conversation: diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py index 16a7b186c..eb4623f5a 100644 --- a/src/backend/langflow/interface/loading.py +++ b/src/backend/langflow/interface/loading.py @@ -12,7 +12,6 @@ from langchain.agents.load_tools import ( _LLM_TOOLS, ) from langchain.agents.loading import load_agent_from_config -from langflow.graph import Graph from langchain.agents.tools import Tool from langchain.base_language import BaseLanguageModel from langchain.callbacks.base import BaseCallbackManager @@ -22,7 +21,6 @@ from pydantic import ValidationError from langflow.interface.agents.custom import CUSTOM_AGENTS from langflow.interface.importing.utils import get_function, import_by_type -from langflow.interface.run import fix_memory_inputs from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.types import get_type_list from langflow.interface.utils import load_file_into_dict @@ -163,37 +161,6 @@ def instantiate_utility(node_type, class_object, params): return class_object(**params) -def load_flow_from_json(path: str, build=True): - """Load flow from json file""" - # This is done to avoid circular imports - - with open(path, "r", encoding="utf-8") as f: - flow_graph = json.load(f) - data_graph = flow_graph["data"] - nodes = data_graph["nodes"] - # Substitute ZeroShotPrompt with PromptTemplate - # nodes = replace_zero_shot_prompt_with_prompt_template(nodes) - # Add input variables - # nodes = payload.extract_input_variables(nodes) - - # Nodes, edges and root node - edges = data_graph["edges"] - graph = Graph(nodes, edges) - if build: - langchain_object = graph.build() - if hasattr(langchain_object, "verbose"): - langchain_object.verbose = True - - if hasattr(langchain_object, "return_intermediate_steps"): - # https://github.com/hwchase17/langchain/issues/2068 - # Deactivating until we have a frontend solution - # to display intermediate steps - langchain_object.return_intermediate_steps = False - fix_memory_inputs(langchain_object) - return langchain_object - return graph - - def replace_zero_shot_prompt_with_prompt_template(nodes): """Replace ZeroShotPrompt with PromptTemplate""" for node in nodes: diff --git a/src/backend/langflow/interface/prompts/custom.py b/src/backend/langflow/interface/prompts/custom.py index b1dbef370..286210271 100644 --- a/src/backend/langflow/interface/prompts/custom.py +++ b/src/backend/langflow/interface/prompts/custom.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Type from langchain.prompts import PromptTemplate from pydantic import root_validator -from langflow.graph.utils import extract_input_variables_from_prompt +from langflow.interface.utils import extract_input_variables_from_prompt # Steps to create a BaseCustomPrompt: # 1. Create a prompt template that endes with: diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py index c2483416f..89f71fd8b 100644 --- a/src/backend/langflow/interface/run.py +++ b/src/backend/langflow/interface/run.py @@ -1,10 +1,3 @@ -import contextlib -import io -from typing import Any, Dict, List, Tuple - -from langchain.schema import AgentAction - -from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler # type: ignore from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict from langflow.graph import Graph from langflow.utils.logger import logger @@ -24,15 +17,6 @@ def load_langchain_object(data_graph, is_first_message=False): return computed_hash, langchain_object -def load_or_build_langchain_object(data_graph, is_first_message=False): - """ - Load langchain object from cache if it exists, otherwise build it. - """ - if is_first_message: - build_langchain_object_with_caching.clear_cache() - return build_langchain_object_with_caching(data_graph) - - @memoize_dict(maxsize=10) def build_langchain_object_with_caching(data_graph): """ @@ -40,16 +24,10 @@ def build_langchain_object_with_caching(data_graph): """ logger.debug("Building langchain object") - graph = build_graph(data_graph) + graph = Graph.from_payload(data_graph) return graph.build() -def build_graph(data_graph): - nodes = data_graph["nodes"] - edges = data_graph["edges"] - return Graph(nodes, edges) - - def build_langchain_object(data_graph): """ Build langchain object from data_graph. @@ -66,29 +44,6 @@ def build_langchain_object(data_graph): return graph.build() -def process_graph_cached(data_graph: Dict[str, Any], message: str): - """ - Process graph by extracting input variables and replacing ZeroShotPrompt - with PromptTemplate,then run the graph and return the result and thought. - """ - # Load langchain object - is_first_message = len(data_graph.get("chatHistory", [])) == 0 - langchain_object = load_or_build_langchain_object(data_graph, is_first_message) - logger.debug("Loaded langchain object") - - if langchain_object is None: - # Raise user facing error - raise ValueError( - "There was an error loading the langchain_object. Please, check all the nodes and try again." - ) - - # Generate result and thought - logger.debug("Generating result and thought") - result, thought = get_result_and_thought(langchain_object, message) - logger.debug("Generated result and thought") - return {"result": str(result), "thought": thought.strip()} - - def get_memory_key(langchain_object): """ Given a LangChain object, this function retrieves the current memory key from the object's memory attribute. @@ -124,147 +79,3 @@ def update_memory_keys(langchain_object, possible_new_mem_key): langchain_object.memory.input_key = input_key langchain_object.memory.output_key = output_key langchain_object.memory.memory_key = possible_new_mem_key - - -def fix_memory_inputs(langchain_object): - """ - Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the - object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the - get_memory_key function and updates the memory keys using the update_memory_keys function. - """ - if hasattr(langchain_object, "memory") and langchain_object.memory is not None: - try: - if langchain_object.memory.memory_key in langchain_object.input_variables: - return - except AttributeError: - input_variables = ( - langchain_object.prompt.input_variables - if hasattr(langchain_object, "prompt") - else langchain_object.input_keys - ) - if langchain_object.memory.memory_key in input_variables: - return - - possible_new_mem_key = get_memory_key(langchain_object) - if possible_new_mem_key is not None: - update_memory_keys(langchain_object, possible_new_mem_key) - - -async def get_result_and_steps(langchain_object, message: str, **kwargs): - """Get result and thought from extracted json""" - - try: - if hasattr(langchain_object, "verbose"): - langchain_object.verbose = True - chat_input = None - memory_key = "" - if hasattr(langchain_object, "memory") and langchain_object.memory is not None: - memory_key = langchain_object.memory.memory_key - - if hasattr(langchain_object, "input_keys"): - for key in langchain_object.input_keys: - if key not in [memory_key, "chat_history"]: - chat_input = {key: message} - else: - chat_input = message # type: ignore - - if hasattr(langchain_object, "return_intermediate_steps"): - # https://github.com/hwchase17/langchain/issues/2068 - # Deactivating until we have a frontend solution - # to display intermediate steps - langchain_object.return_intermediate_steps = True - - fix_memory_inputs(langchain_object) - try: - async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)] - output = await langchain_object.acall(chat_input, callbacks=async_callbacks) - except Exception as exc: - # make the error message more informative - logger.debug(f"Error: {str(exc)}") - sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)] - output = langchain_object(chat_input, callbacks=sync_callbacks) - - intermediate_steps = ( - output.get("intermediate_steps", []) if isinstance(output, dict) else [] - ) - - result = ( - output.get(langchain_object.output_keys[0]) - if isinstance(output, dict) - else output - ) - thought = format_actions(intermediate_steps) if intermediate_steps else "" - except Exception as exc: - raise ValueError(f"Error: {str(exc)}") from exc - return result, thought - - -def get_result_and_thought(langchain_object, message: str): - """Get result and thought from extracted json""" - try: - if hasattr(langchain_object, "verbose"): - langchain_object.verbose = True - chat_input = None - memory_key = "" - if hasattr(langchain_object, "memory") and langchain_object.memory is not None: - memory_key = langchain_object.memory.memory_key - - if hasattr(langchain_object, "input_keys"): - for key in langchain_object.input_keys: - if key not in [memory_key, "chat_history"]: - chat_input = {key: message} - else: - chat_input = message # type: ignore - - if hasattr(langchain_object, "return_intermediate_steps"): - # https://github.com/hwchase17/langchain/issues/2068 - # Deactivating until we have a frontend solution - # to display intermediate steps - langchain_object.return_intermediate_steps = False - - fix_memory_inputs(langchain_object) - - with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): - try: - # if hasattr(langchain_object, "acall"): - # output = await langchain_object.acall(chat_input) - # else: - output = langchain_object(chat_input) - except ValueError as exc: - # make the error message more informative - logger.debug(f"Error: {str(exc)}") - output = langchain_object.run(chat_input) - - intermediate_steps = ( - output.get("intermediate_steps", []) if isinstance(output, dict) else [] - ) - - result = ( - output.get(langchain_object.output_keys[0]) - if isinstance(output, dict) - else output - ) - if intermediate_steps: - thought = format_actions(intermediate_steps) - else: - thought = output_buffer.getvalue() - - except Exception as exc: - raise ValueError(f"Error: {str(exc)}") from exc - return result, thought - - -def format_actions(actions: List[Tuple[AgentAction, str]]) -> str: - """Format a list of (AgentAction, answer) tuples into a string.""" - output = [] - for action, answer in actions: - log = action.log - tool = action.tool - tool_input = action.tool_input - output.append(f"Log: {log}") - if "Action" not in log and "Action Input" not in log: - output.append(f"Tool: {tool}") - output.append(f"Tool Input: {tool_input}") - output.append(f"Answer: {answer}") - output.append("") # Add a blank line - return "\n".join(output) diff --git a/src/backend/langflow/interface/utils.py b/src/backend/langflow/interface/utils.py index 2b7c5acd1..32c605654 100644 --- a/src/backend/langflow/interface/utils.py +++ b/src/backend/langflow/interface/utils.py @@ -2,6 +2,7 @@ import base64 import json import os from io import BytesIO +import re import yaml from langchain.base_language import BaseLanguageModel @@ -48,3 +49,8 @@ def try_setting_streaming_options(langchain_object, websocket): llm.streaming = True return langchain_object + + +def extract_input_variables_from_prompt(prompt: str) -> list[str]: + """Extract input variables from prompt.""" + return re.findall(r"{(.*?)}", prompt) diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index 56cc32e46..de39d8750 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -1,9 +1,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from langflow.api.chat import router as chat_router -from langflow.api.endpoints import router as endpoints_router -from langflow.api.validate import router as validate_router +from langflow.api import router def create_app(): @@ -14,6 +12,10 @@ def create_app(): "*", ] + @app.get("/health") + def get_health(): + return {"status": "OK"} + app.add_middleware( CORSMiddleware, allow_origins=origins, @@ -22,9 +24,7 @@ def create_app(): allow_headers=["*"], ) - app.include_router(endpoints_router) - app.include_router(validate_router) - app.include_router(chat_router) + app.include_router(router) return app diff --git a/src/backend/langflow/processing/__init__.py b/src/backend/langflow/processing/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/langflow/processing/base.py b/src/backend/langflow/processing/base.py new file mode 100644 index 000000000..97b0d5be0 --- /dev/null +++ b/src/backend/langflow/processing/base.py @@ -0,0 +1,55 @@ +from langflow.api.v1.callback import ( + AsyncStreamingLLMCallbackHandler, + StreamingLLMCallbackHandler, +) +from langflow.processing.process import fix_memory_inputs, format_actions +from langflow.utils.logger import logger + + +async def get_result_and_steps(langchain_object, message: str, **kwargs): + """Get result and thought from extracted json""" + + try: + if hasattr(langchain_object, "verbose"): + langchain_object.verbose = True + chat_input = None + memory_key = "" + if hasattr(langchain_object, "memory") and langchain_object.memory is not None: + memory_key = langchain_object.memory.memory_key + + if hasattr(langchain_object, "input_keys"): + for key in langchain_object.input_keys: + if key not in [memory_key, "chat_history"]: + chat_input = {key: message} + else: + chat_input = message # type: ignore + + if hasattr(langchain_object, "return_intermediate_steps"): + # https://github.com/hwchase17/langchain/issues/2068 + # Deactivating until we have a frontend solution + # to display intermediate steps + langchain_object.return_intermediate_steps = True + + fix_memory_inputs(langchain_object) + try: + async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)] + output = await langchain_object.acall(chat_input, callbacks=async_callbacks) + except Exception as exc: + # make the error message more informative + logger.debug(f"Error: {str(exc)}") + sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)] + output = langchain_object(chat_input, callbacks=sync_callbacks) + + intermediate_steps = ( + output.get("intermediate_steps", []) if isinstance(output, dict) else [] + ) + + result = ( + output.get(langchain_object.output_keys[0]) + if isinstance(output, dict) + else output + ) + thought = format_actions(intermediate_steps) if intermediate_steps else "" + except Exception as exc: + raise ValueError(f"Error: {str(exc)}") from exc + return result, thought diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py new file mode 100644 index 000000000..3b8852e00 --- /dev/null +++ b/src/backend/langflow/processing/process.py @@ -0,0 +1,172 @@ +import contextlib +import io +from langchain.schema import AgentAction +import json +from langflow.interface.run import ( + build_langchain_object_with_caching, + get_memory_key, + update_memory_keys, +) +from langflow.utils.logger import logger +from langflow.graph import Graph + + +from typing import Any, Dict, List, Tuple + + +def fix_memory_inputs(langchain_object): + """ + Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the + object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the + get_memory_key function and updates the memory keys using the update_memory_keys function. + """ + if hasattr(langchain_object, "memory") and langchain_object.memory is not None: + try: + if langchain_object.memory.memory_key in langchain_object.input_variables: + return + except AttributeError: + input_variables = ( + langchain_object.prompt.input_variables + if hasattr(langchain_object, "prompt") + else langchain_object.input_keys + ) + if langchain_object.memory.memory_key in input_variables: + return + + possible_new_mem_key = get_memory_key(langchain_object) + if possible_new_mem_key is not None: + update_memory_keys(langchain_object, possible_new_mem_key) + + +def format_actions(actions: List[Tuple[AgentAction, str]]) -> str: + """Format a list of (AgentAction, answer) tuples into a string.""" + output = [] + for action, answer in actions: + log = action.log + tool = action.tool + tool_input = action.tool_input + output.append(f"Log: {log}") + if "Action" not in log and "Action Input" not in log: + output.append(f"Tool: {tool}") + output.append(f"Tool Input: {tool_input}") + output.append(f"Answer: {answer}") + output.append("") # Add a blank line + return "\n".join(output) + + +def get_result_and_thought(langchain_object, message: str): + """Get result and thought from extracted json""" + try: + if hasattr(langchain_object, "verbose"): + langchain_object.verbose = True + chat_input = None + memory_key = "" + if hasattr(langchain_object, "memory") and langchain_object.memory is not None: + memory_key = langchain_object.memory.memory_key + + if hasattr(langchain_object, "input_keys"): + for key in langchain_object.input_keys: + if key not in [memory_key, "chat_history"]: + chat_input = {key: message} + else: + chat_input = message # type: ignore + + if hasattr(langchain_object, "return_intermediate_steps"): + # https://github.com/hwchase17/langchain/issues/2068 + # Deactivating until we have a frontend solution + # to display intermediate steps + langchain_object.return_intermediate_steps = False + + fix_memory_inputs(langchain_object) + + with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer): + try: + # if hasattr(langchain_object, "acall"): + # output = await langchain_object.acall(chat_input) + # else: + output = langchain_object(chat_input) + except ValueError as exc: + # make the error message more informative + logger.debug(f"Error: {str(exc)}") + output = langchain_object.run(chat_input) + + intermediate_steps = ( + output.get("intermediate_steps", []) if isinstance(output, dict) else [] + ) + + result = ( + output.get(langchain_object.output_keys[0]) + if isinstance(output, dict) + else output + ) + if intermediate_steps: + thought = format_actions(intermediate_steps) + else: + thought = output_buffer.getvalue() + + except Exception as exc: + raise ValueError(f"Error: {str(exc)}") from exc + return result, thought + + +def load_or_build_langchain_object(data_graph, is_first_message=False): + """ + Load langchain object from cache if it exists, otherwise build it. + """ + if is_first_message: + build_langchain_object_with_caching.clear_cache() + return build_langchain_object_with_caching(data_graph) + + +def process_graph_cached(data_graph: Dict[str, Any], message: str): + """ + Process graph by extracting input variables and replacing ZeroShotPrompt + with PromptTemplate,then run the graph and return the result and thought. + """ + # Load langchain object + is_first_message = len(data_graph.get("chatHistory", [])) == 0 + langchain_object = load_or_build_langchain_object(data_graph, is_first_message) + logger.debug("Loaded langchain object") + + if langchain_object is None: + # Raise user facing error + raise ValueError( + "There was an error loading the langchain_object. Please, check all the nodes and try again." + ) + + # Generate result and thought + logger.debug("Generating result and thought") + result, thought = get_result_and_thought(langchain_object, message) + logger.debug("Generated result and thought") + return {"result": str(result), "thought": thought.strip()} + + +def load_flow_from_json(path: str, build=True): + """Load flow from json file""" + # This is done to avoid circular imports + + with open(path, "r", encoding="utf-8") as f: + flow_graph = json.load(f) + data_graph = flow_graph["data"] + nodes = data_graph["nodes"] + # Substitute ZeroShotPrompt with PromptTemplate + # nodes = replace_zero_shot_prompt_with_prompt_template(nodes) + # Add input variables + # nodes = payload.extract_input_variables(nodes) + + # Nodes, edges and root node + edges = data_graph["edges"] + graph = Graph(nodes, edges) + if build: + langchain_object = graph.build() + if hasattr(langchain_object, "verbose"): + langchain_object.verbose = True + + if hasattr(langchain_object, "return_intermediate_steps"): + # https://github.com/hwchase17/langchain/issues/2068 + # Deactivating until we have a frontend solution + # to display intermediate steps + langchain_object.return_intermediate_steps = False + fix_memory_inputs(langchain_object) + return langchain_object + return graph diff --git a/src/frontend/src/controllers/API/index.ts b/src/frontend/src/controllers/API/index.ts index f6f46404b..0cffd04bf 100644 --- a/src/frontend/src/controllers/API/index.ts +++ b/src/frontend/src/controllers/API/index.ts @@ -14,13 +14,13 @@ export async function sendAll(data: sendAllProps) { export async function checkCode( code: string ): Promise> { - return await axios.post("/validate/code", { code }); + return await axios.post("api/v1/validate/code", { code }); } export async function checkPrompt( template: string ): Promise> { - return await axios.post("/validate/prompt", { template }); + return await axios.post("api/v1/validate/prompt", { template }); } export async function getExamples(): Promise { diff --git a/src/frontend/src/modals/chatModal/index.tsx b/src/frontend/src/modals/chatModal/index.tsx index cf2b52aac..39bb72994 100644 --- a/src/frontend/src/modals/chatModal/index.tsx +++ b/src/frontend/src/modals/chatModal/index.tsx @@ -182,10 +182,10 @@ export default function ChatModal({ try { const urlWs = process.env.NODE_ENV === "development" - ? `ws://localhost:7860/chat/${id.current}` + ? `ws://localhost:7860/api/v1/chat/${id.current}` : `${window.location.protocol === "https:" ? "wss" : "ws"}://${ window.location.host - }/chat/${id.current}`; + }api/v1/chat/${id.current}`; const newWs = new WebSocket(urlWs); newWs.onopen = () => { console.log("WebSocket connection established!"); diff --git a/src/frontend/vite.config.ts b/src/frontend/vite.config.ts index 172b37733..d4fa2248b 100644 --- a/src/frontend/vite.config.ts +++ b/src/frontend/vite.config.ts @@ -11,7 +11,7 @@ const apiRoutes = [ ]; // Use environment variable to determine the target. -const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860"; +const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860/api/v1"; const proxyTargets = apiRoutes.reduce((proxyObj, route) => { proxyObj[route] = { diff --git a/tests/test_agents_template.py b/tests/test_agents_template.py index 7aa8de176..8e181711f 100644 --- a/tests/test_agents_template.py +++ b/tests/test_agents_template.py @@ -5,7 +5,7 @@ from langflow.settings import settings # check that all agents are in settings.agents # are in json_response["agents"] def test_agents_settings(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() agents = json_response["agents"] @@ -13,7 +13,7 @@ def test_agents_settings(client: TestClient): def test_zero_shot_agent(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() agents = json_response["agents"] @@ -52,7 +52,7 @@ def test_zero_shot_agent(client: TestClient): def test_json_agent(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() agents = json_response["agents"] @@ -87,7 +87,7 @@ def test_json_agent(client: TestClient): def test_csv_agent(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() agents = json_response["agents"] @@ -126,7 +126,7 @@ def test_csv_agent(client: TestClient): def test_initialize_agent(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() agents = json_response["agents"] diff --git a/tests/test_cache.py b/tests/test_cache.py index 3d3e951fc..3214e7d15 100644 --- a/tests/test_cache.py +++ b/tests/test_cache.py @@ -1,10 +1,10 @@ import json +from langflow.graph import Graph +from langflow.processing.process import load_or_build_langchain_object import pytest from langflow.interface.run import ( - build_graph, build_langchain_object_with_caching, - load_or_build_langchain_object, ) @@ -62,7 +62,7 @@ def test_build_langchain_object_with_caching(basic_data_graph): # Test build_graph def test_build_graph(basic_data_graph): - graph = build_graph(basic_data_graph) + graph = Graph.from_payload(basic_data_graph) assert graph is not None assert len(graph.nodes) == len(basic_data_graph["nodes"]) assert len(graph.edges) == len(basic_data_graph["edges"]) diff --git a/tests/test_chains_template.py b/tests/test_chains_template.py index c958cf64d..0c7af56ad 100644 --- a/tests/test_chains_template.py +++ b/tests/test_chains_template.py @@ -3,7 +3,7 @@ from langflow.settings import settings def test_chains_settings(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() chains = json_response["chains"] @@ -12,7 +12,7 @@ def test_chains_settings(client: TestClient): # Test the ConversationChain object def test_conversation_chain(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() chains = json_response["chains"] @@ -94,7 +94,7 @@ def test_conversation_chain(client: TestClient): def test_llm_chain(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() chains = json_response["chains"] @@ -152,7 +152,7 @@ def test_llm_chain(client: TestClient): def test_llm_checker_chain(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() chains = json_response["chains"] @@ -228,7 +228,7 @@ def test_llm_checker_chain(client: TestClient): def test_llm_math_chain(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() chains = json_response["chains"] @@ -306,7 +306,7 @@ def test_llm_math_chain(client: TestClient): def test_series_character_chain(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() chains = json_response["chains"] @@ -368,7 +368,7 @@ def test_series_character_chain(client: TestClient): def test_mid_journey_prompt_chain(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() chains = json_response["chains"] @@ -407,7 +407,7 @@ def test_mid_journey_prompt_chain(client: TestClient): def test_time_travel_guide_chain(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() chains = json_response["chains"] diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index 83f6c62b1..9e07dfb24 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -4,7 +4,7 @@ from langflow.interface.tools.constants import CUSTOM_TOOLS def test_get_all(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() # We need to test the custom nodes @@ -21,7 +21,7 @@ import math def square(x): return x ** 2 """ - response1 = client.post("/validate/code", json={"code": code1}) + response1 = client.post("api/v1/validate/code", json={"code": code1}) assert response1.status_code == 200 assert response1.json() == {"imports": {"errors": []}, "function": {"errors": []}} @@ -32,7 +32,7 @@ import non_existent_module def square(x): return x ** 2 """ - response2 = client.post("/validate/code", json={"code": code2}) + response2 = client.post("api/v1/validate/code", json={"code": code2}) assert response2.status_code == 200 assert response2.json() == { "imports": {"errors": ["No module named 'non_existent_module'"]}, @@ -46,7 +46,7 @@ import math def square(x) return x ** 2 """ - response3 = client.post("/validate/code", json={"code": code3}) + response3 = client.post("api/v1/validate/code", json={"code": code3}) assert response3.status_code == 200 assert response3.json() == { "imports": {"errors": []}, @@ -54,11 +54,11 @@ def square(x) } # Test case with invalid JSON payload - response4 = client.post("/validate/code", json={"invalid_key": code1}) + response4 = client.post("api/v1/validate/code", json={"invalid_key": code1}) assert response4.status_code == 422 # Test case with an empty code string - response5 = client.post("/validate/code", json={"code": ""}) + response5 = client.post("api/v1/validate/code", json={"code": ""}) assert response5.status_code == 200 assert response5.json() == {"imports": {"errors": []}, "function": {"errors": []}} @@ -69,7 +69,7 @@ import math def square(x) return x ** 2 """ - response6 = client.post("/validate/code", json={"code": code6}) + response6 = client.post("api/v1/validate/code", json={"code": code6}) assert response6.status_code == 200 assert response6.json() == { "imports": {"errors": []}, @@ -95,13 +95,13 @@ INVALID_PROMPT = "This is an invalid prompt without any input variable." def test_valid_prompt(client: TestClient): - response = client.post("/validate/prompt", json={"template": VALID_PROMPT}) + response = client.post("api/v1/validate/prompt", json={"template": VALID_PROMPT}) assert response.status_code == 200 assert response.json() == {"input_variables": ["product"]} def test_invalid_prompt(client: TestClient): - response = client.post("/validate/prompt", json={"template": INVALID_PROMPT}) + response = client.post("api/v1/validate/prompt", json={"template": INVALID_PROMPT}) assert response.status_code == 200 assert response.json() == {"input_variables": []} @@ -116,7 +116,7 @@ def test_invalid_prompt(client: TestClient): ], ) def test_various_prompts(client, prompt, expected_input_variables): - response = client.post("/validate/prompt", json={"template": prompt}) + response = client.post("api/v1/validate/prompt", json={"template": prompt}) assert response.status_code == 200 assert response.json() == { "input_variables": expected_input_variables, diff --git a/tests/test_graph.py b/tests/test_graph.py index 8c6560d54..69a926cc3 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -16,7 +16,7 @@ from langflow.graph.vertex.types import ( ToolVertex, WrapperVertex, ) -from langflow.interface.run import get_result_and_thought +from langflow.processing.process import get_result_and_thought from langflow.utils.payload import get_root_node # Test cases for the graph module diff --git a/tests/test_llms_template.py b/tests/test_llms_template.py index f54b452f1..db550393e 100644 --- a/tests/test_llms_template.py +++ b/tests/test_llms_template.py @@ -3,7 +3,7 @@ from langflow.settings import settings def test_llms_settings(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() llms = json_response["llms"] @@ -11,7 +11,7 @@ def test_llms_settings(client: TestClient): # def test_hugging_face_hub(client: TestClient): -# response = client.get("/all") +# response = client.get("api/v1/all") # assert response.status_code == 200 # json_response = response.json() # language_models = json_response["llms"] @@ -103,7 +103,7 @@ def test_llms_settings(client: TestClient): def test_openai(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() language_models = json_response["llms"] @@ -333,7 +333,7 @@ def test_openai(client: TestClient): def test_chat_open_ai(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() language_models = json_response["llms"] diff --git a/tests/test_loading.py b/tests/test_loading.py index 872314699..885eb7a82 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -2,7 +2,7 @@ import json import pytest from langchain.chains.base import Chain -from langflow import load_flow_from_json +from langflow.processing.process import load_flow_from_json from langflow.graph import Graph from langflow.utils.payload import get_root_node diff --git a/tests/test_prompts_template.py b/tests/test_prompts_template.py index 83da2f14d..a8562898c 100644 --- a/tests/test_prompts_template.py +++ b/tests/test_prompts_template.py @@ -3,7 +3,7 @@ from langflow.settings import settings def test_prompts_settings(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() prompts = json_response["prompts"] @@ -11,7 +11,7 @@ def test_prompts_settings(client: TestClient): def test_prompt_template(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() prompts = json_response["prompts"] @@ -89,7 +89,7 @@ def test_prompt_template(client: TestClient): def test_few_shot_prompt_template(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() prompts = json_response["prompts"] @@ -168,7 +168,7 @@ def test_few_shot_prompt_template(client: TestClient): def test_zero_shot_prompt(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() prompts = json_response["prompts"] diff --git a/tests/test_vectorstore_template.py b/tests/test_vectorstore_template.py index 5b1d7e5bc..0aa823786 100644 --- a/tests/test_vectorstore_template.py +++ b/tests/test_vectorstore_template.py @@ -5,7 +5,7 @@ from langflow.settings import settings # check that all agents are in settings.agents # are in json_response["agents"] def test_vectorstores_settings(client: TestClient): - response = client.get("/all") + response = client.get("api/v1/all") assert response.status_code == 200 json_response = response.json() vectorstores = json_response["vectorstores"] diff --git a/tests/test_websocket.py b/tests/test_websocket.py index 5b60d0fed..611faff79 100644 --- a/tests/test_websocket.py +++ b/tests/test_websocket.py @@ -5,17 +5,17 @@ from fastapi.testclient import TestClient def test_websocket_connection(client: TestClient): - with client.websocket_connect("/chat/test_client") as websocket: + with client.websocket_connect("api/v1/chat/test_client") as websocket: assert websocket.scope["client"] == ["testclient", 50000] - assert websocket.scope["path"] == "/chat/test_client" + assert websocket.scope["path"] == "/api/v1/chat/test_client" def test_chat_history(client: TestClient): # Mock the process_graph function to return a specific value - with patch("langflow.api.chat_manager.process_graph") as mock_process_graph: + with patch("langflow.chat.manager.process_graph") as mock_process_graph: mock_process_graph.return_value = ("Hello, I'm a mock response!", "") - with client.websocket_connect("/chat/test_client") as websocket: + with client.websocket_connect("api/v1/chat/test_client") as websocket: # First message should be the history history = websocket.receive_json() assert history == [] # Empty history