From 79c677fb090a975a15f42fdea21432206640dddf Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Jun 2023 09:58:30 -0300 Subject: [PATCH 01/12] =?UTF-8?q?=F0=9F=9A=80=20feat(pyproject.toml):=20ad?= =?UTF-8?q?d=20pytest=20configuration=20options=20Added=20pytest=20configu?= =?UTF-8?q?ration=20options=20to=20the=20pyproject.toml=20file.=20The=20mi?= =?UTF-8?q?nimum=20version=20of=20pytest=20is=20set=20to=206.0,=20the=20'-?= =?UTF-8?q?ra'=20option=20is=20added=20to=20addopts=20to=20show=20all=20te?= =?UTF-8?q?st=20results,=20testpaths=20are=20set=20to=20include=20both=20'?= =?UTF-8?q?tests'=20and=20'integration'=20directories,=20console=20output?= =?UTF-8?q?=20style=20is=20set=20to=20'progress',=20and=20DeprecationWarni?= =?UTF-8?q?ng=20is=20ignored.=20log=5Fcli=20is=20set=20to=20true=20to=20en?= =?UTF-8?q?able=20logging=20of=20pytest=20output=20to=20the=20console.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index cde7ddca5..87843198b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,6 +78,15 @@ types-pillow = "^9.5.0.2" [tool.poetry.extras] deploy = ["langchain-serve"] +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "-ra" +testpaths = ["tests", "integration"] +console_output_style = "progress" +filterwarnings = ["ignore::DeprecationWarning"] +log_cli = true + + [tool.ruff] line-length = 120 From 3342e03a2cb75bcbcc36060b5204bee89ba773af Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Jun 2023 09:59:37 -0300 Subject: [PATCH 02/12] =?UTF-8?q?=F0=9F=94=80=20refactor(langflow):=20move?= =?UTF-8?q?=20routers=20to=20a=20single=20file=20and=20add=20health=20chec?= =?UTF-8?q?k=20endpoint=20The=20routers=20for=20the=20langflow=20API=20hav?= =?UTF-8?q?e=20been=20moved=20to=20a=20single=20file=20for=20better=20orga?= =?UTF-8?q?nization=20and=20maintainability.=20The=20routers=20have=20been?= =?UTF-8?q?=20imported=20and=20included=20in=20the=20main.py=20file=20usin?= =?UTF-8?q?g=20the=20new=20file.=20A=20new=20health=20check=20endpoint=20h?= =?UTF-8?q?as=20been=20added=20to=20the=20API=20to=20check=20the=20status?= =?UTF-8?q?=20of=20the=20application.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/__init__.py | 2 +- src/backend/langflow/main.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/backend/langflow/__init__.py b/src/backend/langflow/__init__.py index 35fe814d2..17b1d940c 100644 --- a/src/backend/langflow/__init__.py +++ b/src/backend/langflow/__init__.py @@ -1,4 +1,4 @@ from langflow.cache import cache_manager -from langflow.interface.loading import load_flow_from_json +from langflow.processing.process import load_flow_from_json __all__ = ["load_flow_from_json", "cache_manager"] diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index 56cc32e46..de39d8750 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -1,9 +1,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware -from langflow.api.chat import router as chat_router -from langflow.api.endpoints import router as endpoints_router -from langflow.api.validate import router as validate_router +from langflow.api import router def create_app(): @@ -14,6 +12,10 @@ def create_app(): "*", ] + @app.get("/health") + def get_health(): + return {"status": 
"OK"} + app.add_middleware( CORSMiddleware, allow_origins=origins, @@ -22,9 +24,7 @@ def create_app(): allow_headers=["*"], ) - app.include_router(endpoints_router) - app.include_router(validate_router) - app.include_router(chat_router) + app.include_router(router) return app From ac42e8a66c1ce0c31c732cdcbb35980aecb96c69 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Jun 2023 10:00:15 -0300 Subject: [PATCH 03/12] chore: remove refactored files --- src/backend/langflow/api/base.py | 84 --------- src/backend/langflow/api/callback.py | 32 ---- src/backend/langflow/api/chat.py | 26 --- src/backend/langflow/api/chat_manager.py | 223 ----------------------- src/backend/langflow/api/endpoints.py | 47 ----- src/backend/langflow/api/schemas.py | 70 ------- src/backend/langflow/api/validate.py | 57 ------ 7 files changed, 539 deletions(-) delete mode 100644 src/backend/langflow/api/base.py delete mode 100644 src/backend/langflow/api/callback.py delete mode 100644 src/backend/langflow/api/chat.py delete mode 100644 src/backend/langflow/api/chat_manager.py delete mode 100644 src/backend/langflow/api/endpoints.py delete mode 100644 src/backend/langflow/api/schemas.py delete mode 100644 src/backend/langflow/api/validate.py diff --git a/src/backend/langflow/api/base.py b/src/backend/langflow/api/base.py deleted file mode 100644 index 8cddc52e4..000000000 --- a/src/backend/langflow/api/base.py +++ /dev/null @@ -1,84 +0,0 @@ -from pydantic import BaseModel, validator - -from langflow.graph.utils import extract_input_variables_from_prompt - - -class CacheResponse(BaseModel): - data: dict - - -class Code(BaseModel): - code: str - - -class Prompt(BaseModel): - template: str - - -# Build ValidationResponse class for {"imports": {"errors": []}, "function": {"errors": []}} -class CodeValidationResponse(BaseModel): - imports: dict - function: dict - - @validator("imports") - def validate_imports(cls, v): - return v or {"errors": []} - - @validator("function") - def validate_function(cls, v): - return v or {"errors": []} - - -class PromptValidationResponse(BaseModel): - input_variables: list - - -INVALID_CHARACTERS = { - " ", - ",", - ".", - ":", - ";", - "!", - "?", - "/", - "\\", - "(", - ")", - "[", - "]", - "{", - "}", -} - - -def validate_prompt(template: str): - input_variables = extract_input_variables_from_prompt(template) - - # Check if there are invalid characters in the input_variables - input_variables = check_input_variables(input_variables) - - return PromptValidationResponse(input_variables=input_variables) - - -def check_input_variables(input_variables: list): - invalid_chars = [] - fixed_variables = [] - for variable in input_variables: - new_var = variable - for char in INVALID_CHARACTERS: - if char in variable: - invalid_chars.append(char) - new_var = new_var.replace(char, "") - fixed_variables.append(new_var) - if new_var != variable: - input_variables.remove(variable) - input_variables.append(new_var) - # If any of the input_variables is not in the fixed_variables, then it means that - # there are invalid characters in the input_variables - if any(var not in fixed_variables for var in input_variables): - raise ValueError( - f"Invalid input variables: {input_variables}. Please, use something like {fixed_variables} instead." 
- ) - - return input_variables diff --git a/src/backend/langflow/api/callback.py b/src/backend/langflow/api/callback.py deleted file mode 100644 index d63e107c4..000000000 --- a/src/backend/langflow/api/callback.py +++ /dev/null @@ -1,32 +0,0 @@ -import asyncio -from typing import Any - -from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler - -from langflow.api.schemas import ChatResponse - - -# https://github.com/hwchase17/chat-langchain/blob/master/callback.py -class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler): - """Callback handler for streaming LLM responses.""" - - def __init__(self, websocket): - self.websocket = websocket - - async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - resp = ChatResponse(message=token, type="stream", intermediate_steps="") - await self.websocket.send_json(resp.dict()) - - -class StreamingLLMCallbackHandler(BaseCallbackHandler): - """Callback handler for streaming LLM responses.""" - - def __init__(self, websocket): - self.websocket = websocket - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - resp = ChatResponse(message=token, type="stream", intermediate_steps="") - - loop = asyncio.get_event_loop() - coroutine = self.websocket.send_json(resp.dict()) - asyncio.run_coroutine_threadsafe(coroutine, loop) diff --git a/src/backend/langflow/api/chat.py b/src/backend/langflow/api/chat.py deleted file mode 100644 index 4afa6c22f..000000000 --- a/src/backend/langflow/api/chat.py +++ /dev/null @@ -1,26 +0,0 @@ -from fastapi import ( - APIRouter, - WebSocket, - WebSocketDisconnect, - WebSocketException, - status, -) - -from langflow.api.chat_manager import ChatManager -from langflow.utils.logger import logger - -router = APIRouter() -chat_manager = ChatManager() - - -@router.websocket("/chat/{client_id}") -async def websocket_endpoint(client_id: str, websocket: WebSocket): - """Websocket endpoint for chat.""" - try: - await chat_manager.handle_websocket(client_id, websocket) - except WebSocketException as exc: - logger.error(exc) - await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc)) - except WebSocketDisconnect as exc: - logger.error(exc) - await websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc)) diff --git a/src/backend/langflow/api/chat_manager.py b/src/backend/langflow/api/chat_manager.py deleted file mode 100644 index 8b1c7a621..000000000 --- a/src/backend/langflow/api/chat_manager.py +++ /dev/null @@ -1,223 +0,0 @@ -import asyncio -import json -from collections import defaultdict -from typing import Dict, List - -from fastapi import WebSocket, status - -from langflow.api.schemas import ChatMessage, ChatResponse, FileResponse -from langflow.cache import cache_manager -from langflow.cache.manager import Subject -from langflow.interface.run import ( - get_result_and_steps, - load_or_build_langchain_object, -) -from langflow.interface.utils import pil_to_base64, try_setting_streaming_options -from langflow.utils.logger import logger - - -class ChatHistory(Subject): - def __init__(self): - super().__init__() - self.history: Dict[str, List[ChatMessage]] = defaultdict(list) - - def add_message(self, client_id: str, message: ChatMessage): - """Add a message to the chat history.""" - - self.history[client_id].append(message) - - if not isinstance(message, FileResponse): - self.notify() - - def get_history(self, client_id: str, filter_messages=True) -> List[ChatMessage]: - """Get the chat history for a client.""" - if history := self.history.get(client_id, 
[]): - if filter_messages: - return [msg for msg in history if msg.type not in ["start", "stream"]] - return history - else: - return [] - - def empty_history(self, client_id: str): - """Empty the chat history for a client.""" - self.history[client_id] = [] - - -class ChatManager: - def __init__(self): - self.active_connections: Dict[str, WebSocket] = {} - self.chat_history = ChatHistory() - self.cache_manager = cache_manager - self.cache_manager.attach(self.update) - - def on_chat_history_update(self): - """Send the last chat message to the client.""" - client_id = self.cache_manager.current_client_id - if client_id in self.active_connections: - chat_response = self.chat_history.get_history( - client_id, filter_messages=False - )[-1] - if chat_response.is_bot: - # Process FileResponse - if isinstance(chat_response, FileResponse): - # If data_type is pandas, convert to csv - if chat_response.data_type == "pandas": - chat_response.data = chat_response.data.to_csv() - elif chat_response.data_type == "image": - # Base64 encode the image - chat_response.data = pil_to_base64(chat_response.data) - # get event loop - loop = asyncio.get_event_loop() - - coroutine = self.send_json(client_id, chat_response) - asyncio.run_coroutine_threadsafe(coroutine, loop) - - def update(self): - if self.cache_manager.current_client_id in self.active_connections: - self.last_cached_object_dict = self.cache_manager.get_last() - # Add a new ChatResponse with the data - chat_response = FileResponse( - message=None, - type="file", - data=self.last_cached_object_dict["obj"], - data_type=self.last_cached_object_dict["type"], - ) - - self.chat_history.add_message( - self.cache_manager.current_client_id, chat_response - ) - - async def connect(self, client_id: str, websocket: WebSocket): - await websocket.accept() - self.active_connections[client_id] = websocket - - def disconnect(self, client_id: str): - self.active_connections.pop(client_id, None) - - async def send_message(self, client_id: str, message: str): - websocket = self.active_connections[client_id] - await websocket.send_text(message) - - async def send_json(self, client_id: str, message: ChatMessage): - websocket = self.active_connections[client_id] - await websocket.send_json(message.dict()) - - async def process_message(self, client_id: str, payload: Dict): - # Process the graph data and chat message - chat_message = payload.pop("message", "") - chat_message = ChatMessage(message=chat_message) - self.chat_history.add_message(client_id, chat_message) - - graph_data = payload - start_resp = ChatResponse(message=None, type="start", intermediate_steps="") - await self.send_json(client_id, start_resp) - - is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1 - # Generate result and thought - try: - logger.debug("Generating result and thought") - - result, intermediate_steps = await process_graph( - graph_data=graph_data, - is_first_message=is_first_message, - chat_message=chat_message, - websocket=self.active_connections[client_id], - ) - except Exception as e: - # Log stack trace - logger.exception(e) - self.chat_history.empty_history(client_id) - raise e - # Send a response back to the frontend, if needed - intermediate_steps = intermediate_steps or "" - history = self.chat_history.get_history(client_id, filter_messages=False) - file_responses = [] - if history: - # Iterate backwards through the history - for msg in reversed(history): - if isinstance(msg, FileResponse): - if msg.data_type == "image": - # Base64 encode the image - 
msg.data = pil_to_base64(msg.data) - file_responses.append(msg) - if msg.type == "start": - break - - response = ChatResponse( - message=result, - intermediate_steps=intermediate_steps.strip(), - type="end", - files=file_responses, - ) - await self.send_json(client_id, response) - self.chat_history.add_message(client_id, response) - - async def handle_websocket(self, client_id: str, websocket: WebSocket): - await self.connect(client_id, websocket) - - try: - chat_history = self.chat_history.get_history(client_id) - # iterate and make BaseModel into dict - chat_history = [chat.dict() for chat in chat_history] - await websocket.send_json(chat_history) - - while True: - json_payload = await websocket.receive_json() - try: - payload = json.loads(json_payload) - except TypeError: - payload = json_payload - if "clear_history" in payload: - self.chat_history.history[client_id] = [] - continue - - with self.cache_manager.set_client_id(client_id): - await self.process_message(client_id, payload) - - except Exception as e: - # Handle any exceptions that might occur - logger.exception(e) - # send a message to the client - await self.active_connections[client_id].close( - code=status.WS_1011_INTERNAL_ERROR, reason=str(e)[:120] - ) - self.disconnect(client_id) - finally: - try: - connection = self.active_connections.get(client_id) - if connection: - await connection.close(code=1000, reason="Client disconnected") - self.disconnect(client_id) - except Exception as e: - logger.exception(e) - self.disconnect(client_id) - - -async def process_graph( - graph_data: Dict, - is_first_message: bool, - chat_message: ChatMessage, - websocket: WebSocket, -): - langchain_object = load_or_build_langchain_object(graph_data, is_first_message) - langchain_object = try_setting_streaming_options(langchain_object, websocket) - logger.debug("Loaded langchain object") - - if langchain_object is None: - # Raise user facing error - raise ValueError( - "There was an error loading the langchain_object. Please, check all the nodes and try again." 
- ) - - # Generate result and thought - try: - logger.debug("Generating result and thought") - result, intermediate_steps = await get_result_and_steps( - langchain_object, chat_message.message or "", websocket=websocket - ) - logger.debug("Generated result and intermediate_steps") - return result, intermediate_steps - except Exception as e: - # Log stack trace - logger.exception(e) - raise e diff --git a/src/backend/langflow/api/endpoints.py b/src/backend/langflow/api/endpoints.py deleted file mode 100644 index 021a81ca8..000000000 --- a/src/backend/langflow/api/endpoints.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging -from importlib.metadata import version - -from fastapi import APIRouter, HTTPException - -from langflow.api.schemas import ( - ExportedFlow, - GraphData, - PredictRequest, - PredictResponse, -) -from langflow.interface.run import process_graph_cached -from langflow.interface.types import build_langchain_types_dict - -# build router -router = APIRouter() -logger = logging.getLogger(__name__) - - -@router.get("/all") -def get_all(): - return build_langchain_types_dict() - - -@router.post("/predict", response_model=PredictResponse) -async def get_load(predict_request: PredictRequest): - try: - exported_flow: ExportedFlow = predict_request.exported_flow - graph_data: GraphData = exported_flow.data - data = graph_data.dict() - response = process_graph_cached(data, predict_request.message) - return PredictResponse(result=response.get("result", "")) - except Exception as e: - # Log stack trace - logger.exception(e) - raise HTTPException(status_code=500, detail=str(e)) from e - - -# get endpoint to return version of langflow -@router.get("/version") -def get_version(): - return {"version": version("langflow")} - - -@router.get("/health") -def get_health(): - return {"status": "OK"} diff --git a/src/backend/langflow/api/schemas.py b/src/backend/langflow/api/schemas.py deleted file mode 100644 index f73b0642d..000000000 --- a/src/backend/langflow/api/schemas.py +++ /dev/null @@ -1,70 +0,0 @@ -from typing import Any, Dict, List, Union - -from pydantic import BaseModel, validator - - -class GraphData(BaseModel): - """Data inside the exported flow.""" - - nodes: List[Dict[str, Any]] - edges: List[Dict[str, Any]] - - -class ExportedFlow(BaseModel): - """Exported flow from LangFlow.""" - - description: str - name: str - id: str - data: GraphData - - -class PredictRequest(BaseModel): - """Predict request schema.""" - - message: str - exported_flow: ExportedFlow - - -class PredictResponse(BaseModel): - """Predict response schema.""" - - result: str - - -class ChatMessage(BaseModel): - """Chat message schema.""" - - is_bot: bool = False - message: Union[str, None] = None - type: str = "human" - - -class ChatResponse(ChatMessage): - """Chat response schema.""" - - intermediate_steps: str - type: str - is_bot: bool = True - files: list = [] - - @validator("type") - def validate_message_type(cls, v): - if v not in ["start", "stream", "end", "error", "info", "file"]: - raise ValueError("type must be start, stream, end, error, info, or file") - return v - - -class FileResponse(ChatMessage): - """File response schema.""" - - data: Any - data_type: str - type: str = "file" - is_bot: bool = True - - @validator("data_type") - def validate_data_type(cls, v): - if v not in ["image", "csv"]: - raise ValueError("data_type must be image or csv") - return v diff --git a/src/backend/langflow/api/validate.py b/src/backend/langflow/api/validate.py deleted file mode 100644 index e90e554f0..000000000 --- 
a/src/backend/langflow/api/validate.py +++ /dev/null @@ -1,57 +0,0 @@ -import json - -from fastapi import APIRouter, HTTPException - -from langflow.api.base import ( - Code, - CodeValidationResponse, - Prompt, - PromptValidationResponse, - validate_prompt, -) -from langflow.graph.vertex.types import VectorStoreVertex -from langflow.interface.run import build_graph -from langflow.utils.logger import logger -from langflow.utils.validate import validate_code - -# build router -router = APIRouter(prefix="/validate", tags=["validate"]) - - -@router.post("/code", status_code=200, response_model=CodeValidationResponse) -def post_validate_code(code: Code): - try: - errors = validate_code(code.code) - return CodeValidationResponse( - imports=errors.get("imports", {}), - function=errors.get("function", {}), - ) - except Exception as e: - return HTTPException(status_code=500, detail=str(e)) - - -@router.post("/prompt", status_code=200, response_model=PromptValidationResponse) -def post_validate_prompt(prompt: Prompt): - try: - return validate_prompt(prompt.template) - except Exception as e: - logger.exception(e) - raise HTTPException(status_code=500, detail=str(e)) from e - - -# validate node -@router.post("/node/{node_id}", status_code=200) -def post_validate_node(node_id: str, data: dict): - try: - # build graph - graph = build_graph(data) - # validate node - node = graph.get_node(node_id) - if node is None: - raise ValueError(f"Node {node_id} not found") - if not isinstance(node, VectorStoreVertex): - node.build() - return json.dumps({"valid": True, "params": str(node._built_object_repr())}) - except Exception as e: - logger.exception(e) - return json.dumps({"valid": False, "params": str(e)}) From bdbb4a81279f535538cc55128b358674c5ec7a45 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Jun 2023 10:00:38 -0300 Subject: [PATCH 04/12] =?UTF-8?q?=F0=9F=9A=80=20feat(api):=20add=20version?= =?UTF-8?q?ing=20to=20the=20API=20and=20restructure=20the=20router=20The?= =?UTF-8?q?=20API=20now=20has=20versioning,=20with=20the=20prefix=20"/api/?= =?UTF-8?q?v1".=20The=20router=20has=20been=20restructured=20to=20include?= =?UTF-8?q?=20the=20chat,=20endpoints,=20and=20validate=20routers.=20This?= =?UTF-8?q?=20improves=20the=20organization=20of=20the=20code=20and=20make?= =?UTF-8?q?s=20it=20easier=20to=20add=20new=20routers=20in=20the=20future.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/__init__.py | 3 +++ src/backend/langflow/api/router.py | 8 ++++++++ src/backend/langflow/api/v1/__init__.py | 5 +++++ 3 files changed, 16 insertions(+) create mode 100644 src/backend/langflow/api/router.py create mode 100644 src/backend/langflow/api/v1/__init__.py diff --git a/src/backend/langflow/api/__init__.py b/src/backend/langflow/api/__init__.py index e69de29bb..f887c47e1 100644 --- a/src/backend/langflow/api/__init__.py +++ b/src/backend/langflow/api/__init__.py @@ -0,0 +1,3 @@ +from langflow.api.router import router + +__all__ = ["router"] diff --git a/src/backend/langflow/api/router.py b/src/backend/langflow/api/router.py new file mode 100644 index 000000000..23b5aa1c5 --- /dev/null +++ b/src/backend/langflow/api/router.py @@ -0,0 +1,8 @@ +# Router for base api +from fastapi import APIRouter +from langflow.api.v1 import chat_router, endpoints_router, validate_router + +router = APIRouter(prefix="/api/v1", tags=["api"]) +router.include_router(chat_router) +router.include_router(endpoints_router) 
+router.include_router(validate_router) diff --git a/src/backend/langflow/api/v1/__init__.py b/src/backend/langflow/api/v1/__init__.py new file mode 100644 index 000000000..d835b4535 --- /dev/null +++ b/src/backend/langflow/api/v1/__init__.py @@ -0,0 +1,5 @@ +from langflow.api.v1.endpoints import router as endpoints_router +from langflow.api.v1.validate import router as validate_router +from langflow.api.v1.chat import router as chat_router + +__all__ = ["chat_router", "endpoints_router", "validate_router"] From 3e5878ddc282557315f532605357ff782af13d64 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Jun 2023 10:01:18 -0300 Subject: [PATCH 05/12] =?UTF-8?q?=F0=9F=8E=89=20feat(langflow):=20add=20ne?= =?UTF-8?q?w=20files=20base.py=20and=20callback.py=20The=20base.py=20file?= =?UTF-8?q?=20contains=20the=20following=20classes=20and=20functions:=20-?= =?UTF-8?q?=20CacheResponse:=20a=20pydantic=20BaseModel=20that=20represent?= =?UTF-8?q?s=20a=20response=20containing=20a=20dictionary=20of=20data=20-?= =?UTF-8?q?=20Code:=20a=20pydantic=20BaseModel=20that=20represents=20a=20c?= =?UTF-8?q?ode=20string=20-=20Prompt:=20a=20pydantic=20BaseModel=20that=20?= =?UTF-8?q?represents=20a=20prompt=20template=20string=20-=20CodeValidatio?= =?UTF-8?q?nResponse:=20a=20pydantic=20BaseModel=20that=20represents=20a?= =?UTF-8?q?=20response=20containing=20the=20validation=20results=20of=20co?= =?UTF-8?q?de=20-=20PromptValidationResponse:=20a=20pydantic=20BaseModel?= =?UTF-8?q?=20that=20represents=20a=20response=20containing=20the=20valida?= =?UTF-8?q?tion=20results=20of=20a=20prompt=20-=20validate=5Fprompt:=20a?= =?UTF-8?q?=20function=20that=20validates=20a=20prompt=20template=20string?= =?UTF-8?q?=20and=20returns=20a=20PromptValidationResponse=20object=20-=20?= =?UTF-8?q?check=5Finput=5Fvariables:=20a=20function=20that=20checks=20if?= =?UTF-8?q?=20input=20variables=20contain=20invalid=20characters=20and=20r?= =?UTF-8?q?eturns=20a=20list=20of=20fixed=20input=20variables?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The callback.py file contains the following classes: - AsyncStreamingLLMCallbackHandler: an AsyncCallbackHandler that handles streaming LLM responses asynchronously - StreamingLLMCallbackHandler: a BaseCallbackHandler that handles streaming LLM responses These files were added to provide support for Langflow's backend API. 
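As a quick illustration of the prompt-validation path, here is a minimal usage sketch. It is not part of the diff, and it assumes the langflow package from this branch is importable (with the whole series applied, so that langflow.interface.utils already provides extract_input_variables_from_prompt):

    from langflow.api.v1.base import validate_prompt

    # Input variables are the names found between curly braces.
    response = validate_prompt("Answer as {persona}: {question}")
    print(response.input_variables)  # ['persona', 'question']

Variables containing any character from INVALID_CHARACTERS (spaces, punctuation, brackets) are cleaned up by check_input_variables before the response is returned.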
--- src/backend/langflow/api/v1/base.py | 84 +++++++++++++++++++++++++ src/backend/langflow/api/v1/callback.py | 32 ++++++++++ 2 files changed, 116 insertions(+) create mode 100644 src/backend/langflow/api/v1/base.py create mode 100644 src/backend/langflow/api/v1/callback.py diff --git a/src/backend/langflow/api/v1/base.py b/src/backend/langflow/api/v1/base.py new file mode 100644 index 000000000..6941bedf3 --- /dev/null +++ b/src/backend/langflow/api/v1/base.py @@ -0,0 +1,84 @@ +from pydantic import BaseModel, validator + +from langflow.interface.utils import extract_input_variables_from_prompt + + +class CacheResponse(BaseModel): + data: dict + + +class Code(BaseModel): + code: str + + +class Prompt(BaseModel): + template: str + + +# Build ValidationResponse class for {"imports": {"errors": []}, "function": {"errors": []}} +class CodeValidationResponse(BaseModel): + imports: dict + function: dict + + @validator("imports") + def validate_imports(cls, v): + return v or {"errors": []} + + @validator("function") + def validate_function(cls, v): + return v or {"errors": []} + + +class PromptValidationResponse(BaseModel): + input_variables: list + + +INVALID_CHARACTERS = { + " ", + ",", + ".", + ":", + ";", + "!", + "?", + "/", + "\\", + "(", + ")", + "[", + "]", + "{", + "}", +} + + +def validate_prompt(template: str): + input_variables = extract_input_variables_from_prompt(template) + + # Check if there are invalid characters in the input_variables + input_variables = check_input_variables(input_variables) + + return PromptValidationResponse(input_variables=input_variables) + + +def check_input_variables(input_variables: list): + invalid_chars = [] + fixed_variables = [] + for variable in input_variables: + new_var = variable + for char in INVALID_CHARACTERS: + if char in variable: + invalid_chars.append(char) + new_var = new_var.replace(char, "") + fixed_variables.append(new_var) + if new_var != variable: + input_variables.remove(variable) + input_variables.append(new_var) + # If any of the input_variables is not in the fixed_variables, then it means that + # there are invalid characters in the input_variables + if any(var not in fixed_variables for var in input_variables): + raise ValueError( + f"Invalid input variables: {input_variables}. Please, use something like {fixed_variables} instead." 
+ ) + + return input_variables diff --git a/src/backend/langflow/api/v1/callback.py b/src/backend/langflow/api/v1/callback.py new file mode 100644 index 000000000..b58393d7b --- /dev/null +++ b/src/backend/langflow/api/v1/callback.py @@ -0,0 +1,32 @@ +import asyncio +from typing import Any + +from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler + +from langflow.api.v1.schemas import ChatResponse + + +# https://github.com/hwchase17/chat-langchain/blob/master/callback.py +class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler): + """Callback handler for streaming LLM responses.""" + + def __init__(self, websocket): + self.websocket = websocket + + async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + resp = ChatResponse(message=token, type="stream", intermediate_steps="") + await self.websocket.send_json(resp.dict()) + + +class StreamingLLMCallbackHandler(BaseCallbackHandler): + """Callback handler for streaming LLM responses.""" + + def __init__(self, websocket): + self.websocket = websocket + + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: + resp = ChatResponse(message=token, type="stream", intermediate_steps="") + + loop = asyncio.get_event_loop() + coroutine = self.websocket.send_json(resp.dict()) + asyncio.run_coroutine_threadsafe(coroutine, loop) From 2bfe93e0b8e44ff82785889e7589928d2cb8799b Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Jun 2023 10:01:44 -0300 Subject: [PATCH 06/12] =?UTF-8?q?=F0=9F=9A=80=20feat(langflow):=20add=20ne?= =?UTF-8?q?w=20API=20endpoints=20for=20chat,=20validation,=20and=20version?= =?UTF-8?q?=20This=20commit=20adds=20new=20API=20endpoints=20for=20chat,?= =?UTF-8?q?=20validation,=20and=20version.=20The=20chat=20endpoint=20is=20?= =?UTF-8?q?a=20websocket=20endpoint=20for=20chat.=20The=20validation=20end?= =?UTF-8?q?point=20has=20three=20sub-endpoints=20for=20validating=20code,?= =?UTF-8?q?=20prompt,=20and=20node.=20The=20version=20endpoint=20returns?= =?UTF-8?q?=20the=20version=20of=20LangFlow.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/api/v1/chat.py | 26 +++++++++ src/backend/langflow/api/v1/endpoints.py | 44 +++++++++++++++ src/backend/langflow/api/v1/schemas.py | 70 ++++++++++++++++++++++++ src/backend/langflow/api/v1/validate.py | 57 +++++++++++++++++++ 4 files changed, 197 insertions(+) create mode 100644 src/backend/langflow/api/v1/chat.py create mode 100644 src/backend/langflow/api/v1/endpoints.py create mode 100644 src/backend/langflow/api/v1/schemas.py create mode 100644 src/backend/langflow/api/v1/validate.py diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py new file mode 100644 index 000000000..7df4c65ed --- /dev/null +++ b/src/backend/langflow/api/v1/chat.py @@ -0,0 +1,26 @@ +from fastapi import ( + APIRouter, + WebSocket, + WebSocketDisconnect, + WebSocketException, + status, +) + +from langflow.chat.manager import ChatManager +from langflow.utils.logger import logger + +router = APIRouter() +chat_manager = ChatManager() + + +@router.websocket("/chat/{client_id}") +async def websocket_endpoint(client_id: str, websocket: WebSocket): + """Websocket endpoint for chat.""" + try: + await chat_manager.handle_websocket(client_id, websocket) + except WebSocketException as exc: + logger.error(exc) + await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc)) + except WebSocketDisconnect as exc: + logger.error(exc) + await 
websocket.close(code=status.WS_1000_NORMAL_CLOSURE, reason=str(exc)) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py new file mode 100644 index 000000000..1e9b0deb1 --- /dev/null +++ b/src/backend/langflow/api/v1/endpoints.py @@ -0,0 +1,44 @@ +import logging +from importlib.metadata import version + +from fastapi import APIRouter, HTTPException + +from langflow.api.v1.schemas import ( + ExportedFlow, + GraphData, + PredictRequest, + PredictResponse, +) + +from langflow.interface.types import build_langchain_types_dict + +# build router +router = APIRouter() +logger = logging.getLogger(__name__) + + +@router.get("/all") +def get_all(): + return build_langchain_types_dict() + + +@router.post("/predict", response_model=PredictResponse) +async def get_load(predict_request: PredictRequest): + try: + from langflow.processing.process import process_graph_cached + + exported_flow: ExportedFlow = predict_request.exported_flow + graph_data: GraphData = exported_flow.data + data = graph_data.dict() + response = process_graph_cached(data, predict_request.message) + return PredictResponse(result=response.get("result", "")) + except Exception as e: + # Log stack trace + logger.exception(e) + raise HTTPException(status_code=500, detail=str(e)) from e + + +# get endpoint to return version of langflow +@router.get("/version") +def get_version(): + return {"version": version("langflow")} diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py new file mode 100644 index 000000000..f73b0642d --- /dev/null +++ b/src/backend/langflow/api/v1/schemas.py @@ -0,0 +1,70 @@ +from typing import Any, Dict, List, Union + +from pydantic import BaseModel, validator + + +class GraphData(BaseModel): + """Data inside the exported flow.""" + + nodes: List[Dict[str, Any]] + edges: List[Dict[str, Any]] + + +class ExportedFlow(BaseModel): + """Exported flow from LangFlow.""" + + description: str + name: str + id: str + data: GraphData + + +class PredictRequest(BaseModel): + """Predict request schema.""" + + message: str + exported_flow: ExportedFlow + + +class PredictResponse(BaseModel): + """Predict response schema.""" + + result: str + + +class ChatMessage(BaseModel): + """Chat message schema.""" + + is_bot: bool = False + message: Union[str, None] = None + type: str = "human" + + +class ChatResponse(ChatMessage): + """Chat response schema.""" + + intermediate_steps: str + type: str + is_bot: bool = True + files: list = [] + + @validator("type") + def validate_message_type(cls, v): + if v not in ["start", "stream", "end", "error", "info", "file"]: + raise ValueError("type must be start, stream, end, error, info, or file") + return v + + +class FileResponse(ChatMessage): + """File response schema.""" + + data: Any + data_type: str + type: str = "file" + is_bot: bool = True + + @validator("data_type") + def validate_data_type(cls, v): + if v not in ["image", "csv"]: + raise ValueError("data_type must be image or csv") + return v diff --git a/src/backend/langflow/api/v1/validate.py b/src/backend/langflow/api/v1/validate.py new file mode 100644 index 000000000..009cb9a30 --- /dev/null +++ b/src/backend/langflow/api/v1/validate.py @@ -0,0 +1,57 @@ +import json + +from fastapi import APIRouter, HTTPException + +from langflow.api.v1.base import ( + Code, + CodeValidationResponse, + Prompt, + PromptValidationResponse, + validate_prompt, +) +from langflow.graph.vertex.types import VectorStoreVertex +from langflow.graph import Graph +from 
langflow.utils.logger import logger +from langflow.utils.validate import validate_code + +# build router +router = APIRouter(prefix="/validate", tags=["validate"]) + + +@router.post("/code", status_code=200, response_model=CodeValidationResponse) +def post_validate_code(code: Code): + try: + errors = validate_code(code.code) + return CodeValidationResponse( + imports=errors.get("imports", {}), + function=errors.get("function", {}), + ) + except Exception as e: + return HTTPException(status_code=500, detail=str(e)) + + +@router.post("/prompt", status_code=200, response_model=PromptValidationResponse) +def post_validate_prompt(prompt: Prompt): + try: + return validate_prompt(prompt.template) + except Exception as e: + logger.exception(e) + raise HTTPException(status_code=500, detail=str(e)) from e + + +# validate node +@router.post("/node/{node_id}", status_code=200) +def post_validate_node(node_id: str, data: dict): + try: + # build graph + graph = Graph.from_payload(data) + # validate node + node = graph.get_node(node_id) + if node is None: + raise ValueError(f"Node {node_id} not found") + if not isinstance(node, VectorStoreVertex): + node.build() + return json.dumps({"valid": True, "params": str(node._built_object_repr())}) + except Exception as e: + logger.exception(e) + return json.dumps({"valid": False, "params": str(e)}) From 7f4eea1e593f29438986b7d21568b06664bb5c13 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Jun 2023 10:02:21 -0300 Subject: [PATCH 07/12] =?UTF-8?q?=F0=9F=9A=80=20feat(chat):=20add=20ChatMa?= =?UTF-8?q?nager=20and=20ChatHistory=20classes=20to=20manage=20chat=20hist?= =?UTF-8?q?ory=20and=20active=20connections=20=E2=9C=A8=20feat(utils.py):?= =?UTF-8?q?=20add=20process=5Fgraph=20function=20to=20process=20graph=20da?= =?UTF-8?q?ta=20and=20generate=20result=20and=20thought=20The=20ChatManage?= =?UTF-8?q?r=20class=20manages=20active=20connections=20and=20chat=20histo?= =?UTF-8?q?ry.=20The=20ChatHistory=20class=20manages=20the=20chat=20histor?= =?UTF-8?q?y=20for=20a=20client.=20The=20process=5Fgraph=20function=20proc?= =?UTF-8?q?esses=20graph=20data=20and=20generates=20a=20result=20and=20tho?= =?UTF-8?q?ught.=20This=20function=20is=20used=20in=20the=20ChatManager=20?= =?UTF-8?q?class=20to=20generate=20a=20response=20back=20to=20the=20fronte?= =?UTF-8?q?nd.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/langflow/chat/__init__.py | 0 src/backend/langflow/chat/manager.py | 190 ++++++++++++++++++++++++++ src/backend/langflow/chat/utils.py | 41 ++++++ 3 files changed, 231 insertions(+) create mode 100644 src/backend/langflow/chat/__init__.py create mode 100644 src/backend/langflow/chat/manager.py create mode 100644 src/backend/langflow/chat/utils.py diff --git a/src/backend/langflow/chat/__init__.py b/src/backend/langflow/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/langflow/chat/manager.py b/src/backend/langflow/chat/manager.py new file mode 100644 index 000000000..d24057b68 --- /dev/null +++ b/src/backend/langflow/chat/manager.py @@ -0,0 +1,190 @@ +from collections import defaultdict +from fastapi import WebSocket, status +from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse +from langflow.cache import cache_manager +from langflow.cache.manager import Subject +from langflow.chat.utils import process_graph +from langflow.interface.utils import pil_to_base64 +from langflow.utils.logger import logger + + +import asyncio 
+import json +from typing import Dict, List + + +class ChatHistory(Subject): + def __init__(self): + super().__init__() + self.history: Dict[str, List[ChatMessage]] = defaultdict(list) + + def add_message(self, client_id: str, message: ChatMessage): + """Add a message to the chat history.""" + + self.history[client_id].append(message) + + if not isinstance(message, FileResponse): + self.notify() + + def get_history(self, client_id: str, filter_messages=True) -> List[ChatMessage]: + """Get the chat history for a client.""" + if history := self.history.get(client_id, []): + if filter_messages: + return [msg for msg in history if msg.type not in ["start", "stream"]] + return history + else: + return [] + + def empty_history(self, client_id: str): + """Empty the chat history for a client.""" + self.history[client_id] = [] + + +class ChatManager: + def __init__(self): + self.active_connections: Dict[str, WebSocket] = {} + self.chat_history = ChatHistory() + self.cache_manager = cache_manager + self.cache_manager.attach(self.update) + + def on_chat_history_update(self): + """Send the last chat message to the client.""" + client_id = self.cache_manager.current_client_id + if client_id in self.active_connections: + chat_response = self.chat_history.get_history( + client_id, filter_messages=False + )[-1] + if chat_response.is_bot: + # Process FileResponse + if isinstance(chat_response, FileResponse): + # If data_type is pandas, convert to csv + if chat_response.data_type == "pandas": + chat_response.data = chat_response.data.to_csv() + elif chat_response.data_type == "image": + # Base64 encode the image + chat_response.data = pil_to_base64(chat_response.data) + # get event loop + loop = asyncio.get_event_loop() + + coroutine = self.send_json(client_id, chat_response) + asyncio.run_coroutine_threadsafe(coroutine, loop) + + def update(self): + if self.cache_manager.current_client_id in self.active_connections: + self.last_cached_object_dict = self.cache_manager.get_last() + # Add a new ChatResponse with the data + chat_response = FileResponse( + message=None, + type="file", + data=self.last_cached_object_dict["obj"], + data_type=self.last_cached_object_dict["type"], + ) + + self.chat_history.add_message( + self.cache_manager.current_client_id, chat_response + ) + + async def connect(self, client_id: str, websocket: WebSocket): + await websocket.accept() + self.active_connections[client_id] = websocket + + def disconnect(self, client_id: str): + self.active_connections.pop(client_id, None) + + async def send_message(self, client_id: str, message: str): + websocket = self.active_connections[client_id] + await websocket.send_text(message) + + async def send_json(self, client_id: str, message: ChatMessage): + websocket = self.active_connections[client_id] + await websocket.send_json(message.dict()) + + async def process_message(self, client_id: str, payload: Dict): + # Process the graph data and chat message + chat_message = payload.pop("message", "") + chat_message = ChatMessage(message=chat_message) + self.chat_history.add_message(client_id, chat_message) + + graph_data = payload + start_resp = ChatResponse(message=None, type="start", intermediate_steps="") + await self.send_json(client_id, start_resp) + + is_first_message = len(self.chat_history.get_history(client_id=client_id)) <= 1 + # Generate result and thought + try: + logger.debug("Generating result and thought") + + result, intermediate_steps = await process_graph( + graph_data=graph_data, + is_first_message=is_first_message, + 
chat_message=chat_message, + websocket=self.active_connections[client_id], + ) + except Exception as e: + # Log stack trace + logger.exception(e) + self.chat_history.empty_history(client_id) + raise e + # Send a response back to the frontend, if needed + intermediate_steps = intermediate_steps or "" + history = self.chat_history.get_history(client_id, filter_messages=False) + file_responses = [] + if history: + # Iterate backwards through the history + for msg in reversed(history): + if isinstance(msg, FileResponse): + if msg.data_type == "image": + # Base64 encode the image + msg.data = pil_to_base64(msg.data) + file_responses.append(msg) + if msg.type == "start": + break + + response = ChatResponse( + message=result, + intermediate_steps=intermediate_steps.strip(), + type="end", + files=file_responses, + ) + await self.send_json(client_id, response) + self.chat_history.add_message(client_id, response) + + async def handle_websocket(self, client_id: str, websocket: WebSocket): + await self.connect(client_id, websocket) + + try: + chat_history = self.chat_history.get_history(client_id) + # iterate and make BaseModel into dict + chat_history = [chat.dict() for chat in chat_history] + await websocket.send_json(chat_history) + + while True: + json_payload = await websocket.receive_json() + try: + payload = json.loads(json_payload) + except TypeError: + payload = json_payload + if "clear_history" in payload: + self.chat_history.history[client_id] = [] + continue + + with self.cache_manager.set_client_id(client_id): + await self.process_message(client_id, payload) + + except Exception as e: + # Handle any exceptions that might occur + logger.exception(e) + # send a message to the client + await self.active_connections[client_id].close( + code=status.WS_1011_INTERNAL_ERROR, reason=str(e)[:120] + ) + self.disconnect(client_id) + finally: + try: + connection = self.active_connections.get(client_id) + if connection: + await connection.close(code=1000, reason="Client disconnected") + self.disconnect(client_id) + except Exception as e: + logger.exception(e) + self.disconnect(client_id) diff --git a/src/backend/langflow/chat/utils.py b/src/backend/langflow/chat/utils.py new file mode 100644 index 000000000..410a442be --- /dev/null +++ b/src/backend/langflow/chat/utils.py @@ -0,0 +1,41 @@ +from fastapi import WebSocket +from langflow.api.v1.schemas import ChatMessage +from langflow.processing.process import ( + load_or_build_langchain_object, +) +from langflow.processing.base import get_result_and_steps +from langflow.interface.utils import try_setting_streaming_options +from langflow.utils.logger import logger + + +from typing import Dict + + +async def process_graph( + graph_data: Dict, + is_first_message: bool, + chat_message: ChatMessage, + websocket: WebSocket, +): + langchain_object = load_or_build_langchain_object(graph_data, is_first_message) + langchain_object = try_setting_streaming_options(langchain_object, websocket) + logger.debug("Loaded langchain object") + + if langchain_object is None: + # Raise user facing error + raise ValueError( + "There was an error loading the langchain_object. Please, check all the nodes and try again." 
+        )
+
+    # Generate result and thought
+    try:
+        logger.debug("Generating result and thought")
+        result, intermediate_steps = await get_result_and_steps(
+            langchain_object, chat_message.message or "", websocket=websocket
+        )
+        logger.debug("Generated result and intermediate_steps")
+        return result, intermediate_steps
+    except Exception as e:
+        # Log stack trace
+        logger.exception(e)
+        raise e

From 3bfee4d4455af8ae18a947ca31785f7978e1384d Mon Sep 17 00:00:00 2001
From: Gabriel Luiz Freitas Almeida
Date: Tue, 6 Jun 2023 10:05:01 -0300
Subject: [PATCH 08/12] =?UTF-8?q?=F0=9F=9A=80=20feat(graph):=20add=20from?=
 =?UTF-8?q?=5Fpayload=20class=20method=20to=20Graph=20class=20=F0=9F=9A=80?=
 =?UTF-8?q?=20feat(utils.py):=20import=20extract=5Finput=5Fvariables=5Ffro?=
 =?UTF-8?q?m=5Fprompt=20from=20langflow.interface.utils=20The=20`from=5Fpa?=
 =?UTF-8?q?yload`=20class=20method=20is=20added=20to=20the=20`Graph`=20cla?=
 =?UTF-8?q?ss=20to=20create=20a=20graph=20from=20a=20payload.=20This=20met?=
 =?UTF-8?q?hod=20takes=20a=20dictionary=20as=20input=20and=20returns=20a?=
 =?UTF-8?q?=20`Graph`=20object.=20The=20`extract=5Finput=5Fvariables=5Ffro?=
 =?UTF-8?q?m=5Fprompt`=20function=20is=20imported=20from=20`langflow.inter?=
 =?UTF-8?q?face.utils`=20to=20extract=20input=20variables=20from=20a=20pro?=
 =?UTF-8?q?mpt.=20This=20function=20is=20used=20in=20other=20parts=20of=20?=
 =?UTF-8?q?the=20codebase=20to=20extract=20input=20variables=20from=20prom?=
 =?UTF-8?q?pts.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/backend/langflow/graph/graph/base.py  | 20 ++++++++++++++++++++
 src/backend/langflow/graph/graph/utils.py |  0
 src/backend/langflow/graph/utils.py       |  8 ++------
 3 files changed, 22 insertions(+), 6 deletions(-)
 create mode 100644 src/backend/langflow/graph/graph/utils.py

diff --git a/src/backend/langflow/graph/graph/base.py b/src/backend/langflow/graph/graph/base.py
index 020f539ec..5fd00d09b 100644
--- a/src/backend/langflow/graph/graph/base.py
+++ b/src/backend/langflow/graph/graph/base.py
@@ -24,6 +24,26 @@ class Graph:
         self._edges = edges
         self._build_graph()
 
+    @classmethod
+    def from_payload(cls, payload: Dict) -> "Graph":
+        """
+        Creates a graph from a payload.
+
+        Args:
+            payload (Dict): The payload to create the graph from.
+
+        Returns:
+            Graph: The created graph.
+ """ + if "data" in payload: + payload = payload["data"] + try: + nodes = payload["nodes"] + edges = payload["edges"] + return cls(nodes, edges) + except KeyError as exc: + raise ValueError("Invalid payload") from exc + def _build_graph(self) -> None: """Builds the graph from the nodes and edges.""" self.nodes = self._build_vertices() diff --git a/src/backend/langflow/graph/graph/utils.py b/src/backend/langflow/graph/graph/utils.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/backend/langflow/graph/utils.py b/src/backend/langflow/graph/utils.py index e22b27cf5..b78b2f961 100644 --- a/src/backend/langflow/graph/utils.py +++ b/src/backend/langflow/graph/utils.py @@ -1,6 +1,7 @@ -import re from typing import Any, Union +from langflow.interface.utils import extract_input_variables_from_prompt + def validate_prompt(prompt: str): """Validate prompt.""" @@ -15,11 +16,6 @@ def fix_prompt(prompt: str): return prompt + " {input}" -def extract_input_variables_from_prompt(prompt: str) -> list[str]: - """Extract input variables from prompt.""" - return re.findall(r"{(.*?)}", prompt) - - def flatten_list(list_of_lists: list[Union[list, Any]]) -> list: """Flatten list of lists.""" new_list = [] From 228f938cd8371ab5ecd5aa2d5b68622ad40ab03b Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 6 Jun 2023 10:05:46 -0300 Subject: [PATCH 09/12] =?UTF-8?q?=F0=9F=94=A8=20refactor(types.py):=20move?= =?UTF-8?q?=20extract=5Finput=5Fvariables=5Ffrom=5Fprompt=20import=20to=20?= =?UTF-8?q?interface.utils=20module=20=F0=9F=94=A8=20refactor(custom.py,?= =?UTF-8?q?=20loading.py,=20prompts/custom.py,=20run.py):=20update=20impor?= =?UTF-8?q?t=20statements=20to=20use=20extract=5Finput=5Fvariables=5Ffrom?= =?UTF-8?q?=5Fprompt=20from=20interface.utils=20module=20=F0=9F=94=A8=20re?= =?UTF-8?q?factor(run.py):=20remove=20unused=20imports=20and=20functions?= =?UTF-8?q?=20=F0=9F=94=A8=20refactor(utils.py):=20add=20type=20hinting=20?= =?UTF-8?q?to=20extract=5Finput=5Fvariables=5Ffrom=5Fprompt=20function=20a?= =?UTF-8?q?nd=20remove=20unused=20imports=20The=20extract=5Finput=5Fvariab?= =?UTF-8?q?les=5Ffrom=5Fprompt=20function=20has=20been=20moved=20to=20the?= =?UTF-8?q?=20interface.utils=20module=20to=20improve=20code=20organizatio?= =?UTF-8?q?n.=20The=20import=20statements=20in=20the=20affected=20modules?= =?UTF-8?q?=20have=20been=20updated=20to=20reflect=20this=20change.=20Unus?= =?UTF-8?q?ed=20imports=20and=20functions=20have=20been=20removed=20from?= =?UTF-8?q?=20the=20run.py=20module.=20Type=20hinting=20has=20been=20added?= =?UTF-8?q?=20to=20the=20extract=5Finput=5Fvariables=5Ffrom=5Fprompt=20fun?= =?UTF-8?q?ction=20in=20the=20interface.utils=20module.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🚀 feat(processing): add processing module with get_result_and_steps and fix_memory_inputs functions The processing module was added to the project with two functions: get_result_and_steps and fix_memory_inputs. The get_result_and_steps function extracts the result and thought from a LangChain object and returns them. The fix_memory_inputs function checks if a LangChain object has a memory attribute and if that memory key exists in the object's input variables. If not, it gets a possible new memory key using the get_memory_key function and updates the memory keys using the update_memory_keys function. 
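For reference, the relocated helper is a single regular expression, so its behaviour is easy to pin down. A minimal usage sketch, not part of the diff, assuming langflow.interface.utils from this branch is importable:

    from langflow.interface.utils import extract_input_variables_from_prompt

    template = "Translate {text} into {language}."
    print(extract_input_variables_from_prompt(template))
    # ['text', 'language']

The same result can be reproduced standalone with re.findall(r"{(.*?)}", template), which is exactly what the relocated function does.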
---
 src/backend/langflow/graph/vertex/types.py   |   3 +-
 .../langflow/interface/chains/custom.py      |   2 +-
 src/backend/langflow/interface/loading.py    |  33 ---
 .../langflow/interface/prompts/custom.py     |   2 +-
 src/backend/langflow/interface/run.py        | 191 +-----------------
 src/backend/langflow/interface/utils.py      |   6 +
 src/backend/langflow/processing/__init__.py  |   0
 src/backend/langflow/processing/base.py      |  55 +++++
 src/backend/langflow/processing/process.py   | 172 ++++++++++++++++
 9 files changed, 238 insertions(+), 226 deletions(-)
 create mode 100644 src/backend/langflow/processing/__init__.py
 create mode 100644 src/backend/langflow/processing/base.py
 create mode 100644 src/backend/langflow/processing/process.py

diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py
index b81e72439..4eb20f416 100644
--- a/src/backend/langflow/graph/vertex/types.py
+++ b/src/backend/langflow/graph/vertex/types.py
@@ -1,7 +1,8 @@
 from typing import Any, Dict, List, Optional, Union
 
 from langflow.graph.vertex.base import Vertex
-from langflow.graph.utils import extract_input_variables_from_prompt, flatten_list
+from langflow.graph.utils import flatten_list
+from langflow.interface.utils import extract_input_variables_from_prompt
 
 
 class AgentVertex(Vertex):
diff --git a/src/backend/langflow/interface/chains/custom.py b/src/backend/langflow/interface/chains/custom.py
index cb76a53c8..ba4ba8b62 100644
--- a/src/backend/langflow/interface/chains/custom.py
+++ b/src/backend/langflow/interface/chains/custom.py
@@ -5,7 +5,7 @@ from langchain.memory.buffer import ConversationBufferMemory
 from langchain.schema import BaseMemory
 from pydantic import Field, root_validator
 
-from langflow.graph.utils import extract_input_variables_from_prompt
+from langflow.interface.utils import extract_input_variables_from_prompt
 
 DEFAULT_SUFFIX = """"
 Current conversation:
diff --git a/src/backend/langflow/interface/loading.py b/src/backend/langflow/interface/loading.py
index 16a7b186c..eb4623f5a 100644
--- a/src/backend/langflow/interface/loading.py
+++ b/src/backend/langflow/interface/loading.py
@@ -12,7 +12,6 @@ from langchain.agents.load_tools import (
     _LLM_TOOLS,
 )
 from langchain.agents.loading import load_agent_from_config
-from langflow.graph import Graph
 from langchain.agents.tools import Tool
 from langchain.base_language import BaseLanguageModel
 from langchain.callbacks.base import BaseCallbackManager
@@ -22,7 +21,6 @@ from pydantic import ValidationError
 
 from langflow.interface.agents.custom import CUSTOM_AGENTS
 from langflow.interface.importing.utils import get_function, import_by_type
-from langflow.interface.run import fix_memory_inputs
 from langflow.interface.toolkits.base import toolkits_creator
 from langflow.interface.types import get_type_list
 from langflow.interface.utils import load_file_into_dict
@@ -163,37 +161,6 @@ def instantiate_utility(node_type, class_object, params):
     return class_object(**params)
 
 
-def load_flow_from_json(path: str, build=True):
-    """Load flow from json file"""
-    # This is done to avoid circular imports
-
-    with open(path, "r", encoding="utf-8") as f:
-        flow_graph = json.load(f)
-    data_graph = flow_graph["data"]
-    nodes = data_graph["nodes"]
-    # Substitute ZeroShotPrompt with PromptTemplate
-    # nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
-    # Add input variables
-    # nodes = payload.extract_input_variables(nodes)
-
-    # Nodes, edges and root node
-    edges = data_graph["edges"]
-    graph = Graph(nodes, edges)
-    if build:
-        langchain_object = graph.build()
-        if hasattr(langchain_object, "verbose"):
-            langchain_object.verbose = True
-
-        if hasattr(langchain_object, "return_intermediate_steps"):
-            # https://github.com/hwchase17/langchain/issues/2068
-            # Deactivating until we have a frontend solution
-            # to display intermediate steps
-            langchain_object.return_intermediate_steps = False
-        fix_memory_inputs(langchain_object)
-        return langchain_object
-    return graph
-
-
 def replace_zero_shot_prompt_with_prompt_template(nodes):
     """Replace ZeroShotPrompt with PromptTemplate"""
     for node in nodes:
diff --git a/src/backend/langflow/interface/prompts/custom.py b/src/backend/langflow/interface/prompts/custom.py
index b1dbef370..286210271 100644
--- a/src/backend/langflow/interface/prompts/custom.py
+++ b/src/backend/langflow/interface/prompts/custom.py
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional, Type
 from langchain.prompts import PromptTemplate
 from pydantic import root_validator
 
-from langflow.graph.utils import extract_input_variables_from_prompt
+from langflow.interface.utils import extract_input_variables_from_prompt
 
 # Steps to create a BaseCustomPrompt:
 # 1. Create a prompt template that endes with:
diff --git a/src/backend/langflow/interface/run.py b/src/backend/langflow/interface/run.py
index c2483416f..89f71fd8b 100644
--- a/src/backend/langflow/interface/run.py
+++ b/src/backend/langflow/interface/run.py
@@ -1,10 +1,3 @@
-import contextlib
-import io
-from typing import Any, Dict, List, Tuple
-
-from langchain.schema import AgentAction
-
-from langflow.api.callback import AsyncStreamingLLMCallbackHandler, StreamingLLMCallbackHandler  # type: ignore
 from langflow.cache.base import compute_dict_hash, load_cache, memoize_dict
 from langflow.graph import Graph
 from langflow.utils.logger import logger
@@ -24,15 +17,6 @@ def load_langchain_object(data_graph, is_first_message=False):
     return computed_hash, langchain_object
 
 
-def load_or_build_langchain_object(data_graph, is_first_message=False):
-    """
-    Load langchain object from cache if it exists, otherwise build it.
-    """
-    if is_first_message:
-        build_langchain_object_with_caching.clear_cache()
-    return build_langchain_object_with_caching(data_graph)
-
-
 @memoize_dict(maxsize=10)
 def build_langchain_object_with_caching(data_graph):
     """
@@ -40,16 +24,10 @@
     """
     logger.debug("Building langchain object")
-    graph = build_graph(data_graph)
+    graph = Graph.from_payload(data_graph)
     return graph.build()
 
 
-def build_graph(data_graph):
-    nodes = data_graph["nodes"]
-    edges = data_graph["edges"]
-    return Graph(nodes, edges)
-
-
 def build_langchain_object(data_graph):
     """
     Build langchain object from data_graph.
@@ -66,29 +44,6 @@
     return graph.build()
 
 
-def process_graph_cached(data_graph: Dict[str, Any], message: str):
-    """
-    Process graph by extracting input variables and replacing ZeroShotPrompt
-    with PromptTemplate,then run the graph and return the result and thought.
-    """
-    # Load langchain object
-    is_first_message = len(data_graph.get("chatHistory", [])) == 0
-    langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
-    logger.debug("Loaded langchain object")
-
-    if langchain_object is None:
-        # Raise user facing error
-        raise ValueError(
-            "There was an error loading the langchain_object. Please, check all the nodes and try again."
-        )
-
-    # Generate result and thought
-    logger.debug("Generating result and thought")
-    result, thought = get_result_and_thought(langchain_object, message)
-    logger.debug("Generated result and thought")
-    return {"result": str(result), "thought": thought.strip()}
-
-
 def get_memory_key(langchain_object):
     """
     Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.
@@ -124,147 +79,3 @@ def update_memory_keys(langchain_object, possible_new_mem_key):
     langchain_object.memory.input_key = input_key
     langchain_object.memory.output_key = output_key
     langchain_object.memory.memory_key = possible_new_mem_key
-
-
-def fix_memory_inputs(langchain_object):
-    """
-    Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the
-    object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
-    get_memory_key function and updates the memory keys using the update_memory_keys function.
-    """
-    if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
-        try:
-            if langchain_object.memory.memory_key in langchain_object.input_variables:
-                return
-        except AttributeError:
-            input_variables = (
-                langchain_object.prompt.input_variables
-                if hasattr(langchain_object, "prompt")
-                else langchain_object.input_keys
-            )
-            if langchain_object.memory.memory_key in input_variables:
-                return
-
-    possible_new_mem_key = get_memory_key(langchain_object)
-    if possible_new_mem_key is not None:
-        update_memory_keys(langchain_object, possible_new_mem_key)
-
-
-async def get_result_and_steps(langchain_object, message: str, **kwargs):
-    """Get result and thought from extracted json"""
-
-    try:
-        if hasattr(langchain_object, "verbose"):
-            langchain_object.verbose = True
-        chat_input = None
-        memory_key = ""
-        if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
-            memory_key = langchain_object.memory.memory_key
-
-        if hasattr(langchain_object, "input_keys"):
-            for key in langchain_object.input_keys:
-                if key not in [memory_key, "chat_history"]:
-                    chat_input = {key: message}
-        else:
-            chat_input = message  # type: ignore
-
-        if hasattr(langchain_object, "return_intermediate_steps"):
-            # https://github.com/hwchase17/langchain/issues/2068
-            # Deactivating until we have a frontend solution
-            # to display intermediate steps
-            langchain_object.return_intermediate_steps = True
-
-        fix_memory_inputs(langchain_object)
-        try:
-            async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
-            output = await langchain_object.acall(chat_input, callbacks=async_callbacks)
-        except Exception as exc:
-            # make the error message more informative
-            logger.debug(f"Error: {str(exc)}")
-            sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
-            output = langchain_object(chat_input, callbacks=sync_callbacks)
-
-        intermediate_steps = (
-            output.get("intermediate_steps", []) if isinstance(output, dict) else []
-        )
-
-        result = (
-            output.get(langchain_object.output_keys[0])
-            if isinstance(output, dict)
-            else output
-        )
-        thought = format_actions(intermediate_steps) if intermediate_steps else ""
-    except Exception as exc:
-        raise ValueError(f"Error: {str(exc)}") from exc
-    return result, thought
-
-
-def get_result_and_thought(langchain_object, message: str):
-    """Get result and thought from extracted json"""
-    try:
-        if hasattr(langchain_object, "verbose"):
-            langchain_object.verbose = True
-        chat_input = None
-        memory_key = ""
-        if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
-            memory_key = langchain_object.memory.memory_key
-
-        if hasattr(langchain_object, "input_keys"):
-            for key in langchain_object.input_keys:
-                if key not in [memory_key, "chat_history"]:
-                    chat_input = {key: message}
-        else:
-            chat_input = message  # type: ignore
-
-        if hasattr(langchain_object, "return_intermediate_steps"):
-            # https://github.com/hwchase17/langchain/issues/2068
-            # Deactivating until we have a frontend solution
-            # to display intermediate steps
-            langchain_object.return_intermediate_steps = False
-
-        fix_memory_inputs(langchain_object)
-
-        with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
-            try:
-                # if hasattr(langchain_object, "acall"):
-                #     output = await langchain_object.acall(chat_input)
-                # else:
-                output = langchain_object(chat_input)
-            except ValueError as exc:
-                # make the error message more informative
-                logger.debug(f"Error: {str(exc)}")
-                output = langchain_object.run(chat_input)
-
-        intermediate_steps = (
-            output.get("intermediate_steps", []) if isinstance(output, dict) else []
-        )
-
-        result = (
-            output.get(langchain_object.output_keys[0])
-            if isinstance(output, dict)
-            else output
-        )
-        if intermediate_steps:
-            thought = format_actions(intermediate_steps)
-        else:
-            thought = output_buffer.getvalue()
-
-    except Exception as exc:
-        raise ValueError(f"Error: {str(exc)}") from exc
-    return result, thought
-
-
-def format_actions(actions: List[Tuple[AgentAction, str]]) -> str:
-    """Format a list of (AgentAction, answer) tuples into a string."""
-    output = []
-    for action, answer in actions:
-        log = action.log
-        tool = action.tool
-        tool_input = action.tool_input
-        output.append(f"Log: {log}")
-        if "Action" not in log and "Action Input" not in log:
-            output.append(f"Tool: {tool}")
-            output.append(f"Tool Input: {tool_input}")
-        output.append(f"Answer: {answer}")
-        output.append("")  # Add a blank line
-    return "\n".join(output)
diff --git a/src/backend/langflow/interface/utils.py b/src/backend/langflow/interface/utils.py
index 2b7c5acd1..32c605654 100644
--- a/src/backend/langflow/interface/utils.py
+++ b/src/backend/langflow/interface/utils.py
@@ -2,6 +2,7 @@ import base64
 import json
 import os
 from io import BytesIO
+import re
 
 import yaml
 from langchain.base_language import BaseLanguageModel
@@ -48,3 +49,8 @@ def try_setting_streaming_options(langchain_object, websocket):
                 llm.streaming = True
 
     return langchain_object
+
+
+def extract_input_variables_from_prompt(prompt: str) -> list[str]:
+    """Extract input variables from prompt."""
+    return re.findall(r"{(.*?)}", prompt)
diff --git a/src/backend/langflow/processing/__init__.py b/src/backend/langflow/processing/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/backend/langflow/processing/base.py b/src/backend/langflow/processing/base.py
new file mode 100644
index 000000000..97b0d5be0
--- /dev/null
+++ b/src/backend/langflow/processing/base.py
@@ -0,0 +1,55 @@
+from langflow.api.v1.callback import (
+    AsyncStreamingLLMCallbackHandler,
+    StreamingLLMCallbackHandler,
+)
+from langflow.processing.process import fix_memory_inputs, format_actions
+from langflow.utils.logger import logger
+
+
+async def get_result_and_steps(langchain_object, message: str, **kwargs):
+    """Get result and thought from extracted json"""
+
+    try:
+        if hasattr(langchain_object, "verbose"):
+            langchain_object.verbose = True
+        chat_input = None
+        memory_key = ""
+        if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
+            memory_key = langchain_object.memory.memory_key
+
+        if hasattr(langchain_object, "input_keys"):
+            for key in langchain_object.input_keys:
+                if key not in [memory_key, "chat_history"]:
+                    chat_input = {key: message}
+        else:
+            chat_input = message  # type: ignore
+
+        if hasattr(langchain_object, "return_intermediate_steps"):
+            # https://github.com/hwchase17/langchain/issues/2068
+            # Deactivating until we have a frontend solution
+            # to display intermediate steps
+            langchain_object.return_intermediate_steps = True
+
+        fix_memory_inputs(langchain_object)
+        try:
+            async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
+            output = await langchain_object.acall(chat_input, callbacks=async_callbacks)
+        except Exception as exc:
+            # make the error message more informative
+            logger.debug(f"Error: {str(exc)}")
+            sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
+            output = langchain_object(chat_input, callbacks=sync_callbacks)
+
+        intermediate_steps = (
+            output.get("intermediate_steps", []) if isinstance(output, dict) else []
+        )
+
+        result = (
+            output.get(langchain_object.output_keys[0])
+            if isinstance(output, dict)
+            else output
+        )
+        thought = format_actions(intermediate_steps) if intermediate_steps else ""
+    except Exception as exc:
+        raise ValueError(f"Error: {str(exc)}") from exc
+    return result, thought
diff --git a/src/backend/langflow/processing/process.py b/src/backend/langflow/processing/process.py
new file mode 100644
index 000000000..3b8852e00
--- /dev/null
+++ b/src/backend/langflow/processing/process.py
@@ -0,0 +1,172 @@
+import contextlib
+import io
+from langchain.schema import AgentAction
+import json
+from langflow.interface.run import (
+    build_langchain_object_with_caching,
+    get_memory_key,
+    update_memory_keys,
+)
+from langflow.utils.logger import logger
+from langflow.graph import Graph
+
+
+from typing import Any, Dict, List, Tuple
+
+
+def fix_memory_inputs(langchain_object):
+    """
+    Given a LangChain object, this function checks if it has a memory attribute and if that memory key exists in the
+    object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
+    get_memory_key function and updates the memory keys using the update_memory_keys function.
+    """
+    if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
+        try:
+            if langchain_object.memory.memory_key in langchain_object.input_variables:
+                return
+        except AttributeError:
+            input_variables = (
+                langchain_object.prompt.input_variables
+                if hasattr(langchain_object, "prompt")
+                else langchain_object.input_keys
+            )
+            if langchain_object.memory.memory_key in input_variables:
+                return
+
+    possible_new_mem_key = get_memory_key(langchain_object)
+    if possible_new_mem_key is not None:
+        update_memory_keys(langchain_object, possible_new_mem_key)
+
+
+def format_actions(actions: List[Tuple[AgentAction, str]]) -> str:
+    """Format a list of (AgentAction, answer) tuples into a string."""
+    output = []
+    for action, answer in actions:
+        log = action.log
+        tool = action.tool
+        tool_input = action.tool_input
+        output.append(f"Log: {log}")
+        if "Action" not in log and "Action Input" not in log:
+            output.append(f"Tool: {tool}")
+            output.append(f"Tool Input: {tool_input}")
+        output.append(f"Answer: {answer}")
+        output.append("")  # Add a blank line
+    return "\n".join(output)
+
+
+def get_result_and_thought(langchain_object, message: str):
+    """Get result and thought from extracted json"""
+    try:
+        if hasattr(langchain_object, "verbose"):
+            langchain_object.verbose = True
+        chat_input = None
+        memory_key = ""
+        if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
+            memory_key = langchain_object.memory.memory_key
+
+        if hasattr(langchain_object, "input_keys"):
+            for key in langchain_object.input_keys:
+                if key not in [memory_key, "chat_history"]:
+                    chat_input = {key: message}
+        else:
+            chat_input = message  # type: ignore
+
+        if hasattr(langchain_object, "return_intermediate_steps"):
+            # https://github.com/hwchase17/langchain/issues/2068
+            # Deactivating until we have a frontend solution
+            # to display intermediate steps
+            langchain_object.return_intermediate_steps = False
+
+        fix_memory_inputs(langchain_object)
+
+        with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
+            try:
+                # if hasattr(langchain_object, "acall"):
+                #     output = await langchain_object.acall(chat_input)
+                # else:
+                output = langchain_object(chat_input)
+            except ValueError as exc:
+                # make the error message more informative
+                logger.debug(f"Error: {str(exc)}")
+                output = langchain_object.run(chat_input)
+
+        intermediate_steps = (
+            output.get("intermediate_steps", []) if isinstance(output, dict) else []
+        )
+
+        result = (
+            output.get(langchain_object.output_keys[0])
+            if isinstance(output, dict)
+            else output
+        )
+        if intermediate_steps:
+            thought = format_actions(intermediate_steps)
+        else:
+            thought = output_buffer.getvalue()
+
+    except Exception as exc:
+        raise ValueError(f"Error: {str(exc)}") from exc
+    return result, thought
+
+
+def load_or_build_langchain_object(data_graph, is_first_message=False):
+    """
+    Load langchain object from cache if it exists, otherwise build it.
+    """
+    if is_first_message:
+        build_langchain_object_with_caching.clear_cache()
+    return build_langchain_object_with_caching(data_graph)
+
+
+def process_graph_cached(data_graph: Dict[str, Any], message: str):
+    """
+    Process graph by extracting input variables and replacing ZeroShotPrompt
+    with PromptTemplate, then run the graph and return the result and thought.
+    """
+    # Load langchain object
+    is_first_message = len(data_graph.get("chatHistory", [])) == 0
+    langchain_object = load_or_build_langchain_object(data_graph, is_first_message)
+    logger.debug("Loaded langchain object")
+
+    if langchain_object is None:
+        # Raise user facing error
+        raise ValueError(
+            "There was an error loading the langchain_object. Please, check all the nodes and try again."
+        )
+
+    # Generate result and thought
+    logger.debug("Generating result and thought")
+    result, thought = get_result_and_thought(langchain_object, message)
+    logger.debug("Generated result and thought")
+    return {"result": str(result), "thought": thought.strip()}
+
+
+def load_flow_from_json(path: str, build=True):
+    """Load flow from json file"""
+    # This is done to avoid circular imports
+
+    with open(path, "r", encoding="utf-8") as f:
+        flow_graph = json.load(f)
+    data_graph = flow_graph["data"]
+    nodes = data_graph["nodes"]
+    # Substitute ZeroShotPrompt with PromptTemplate
+    # nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
+    # Add input variables
+    # nodes = payload.extract_input_variables(nodes)
+
+    # Nodes, edges and root node
+    edges = data_graph["edges"]
+    graph = Graph(nodes, edges)
+    if build:
+        langchain_object = graph.build()
+        if hasattr(langchain_object, "verbose"):
+            langchain_object.verbose = True
+
+        if hasattr(langchain_object, "return_intermediate_steps"):
+            # https://github.com/hwchase17/langchain/issues/2068
+            # Deactivating until we have a frontend solution
+            # to display intermediate steps
+            langchain_object.return_intermediate_steps = False
+        fix_memory_inputs(langchain_object)
+        return langchain_object
+    return graph

From 478bb446c3fd9ee6077c7abad6698a583e075e5b Mon Sep 17 00:00:00 2001
From: Gabriel Luiz Freitas Almeida
Date: Tue, 6 Jun 2023 10:06:14 -0300
Subject: [PATCH 10/12] =?UTF-8?q?=F0=9F=90=9B=20fix(frontend):=20add=20mis?=
 =?UTF-8?q?sing=20api/v1=20prefix=20to=20API=20routes=20=F0=9F=90=9B=20fix?=
 =?UTF-8?q?(frontend):=20add=20missing=20api/v1=20prefix=20to=20WebSocket?=
 =?UTF-8?q?=20URL=20=F0=9F=90=9B=20fix(frontend):=20add=20missing=20api/v1?=
 =?UTF-8?q?=20prefix=20to=20Vite=20proxy=20target=20The=20API=20routes,=20?=
 =?UTF-8?q?WebSocket=20URL,=20and=20Vite=20proxy=20target=20were=20missing?=
 =?UTF-8?q?=20the=20"api/v1"=20prefix,=20causing=20the=20frontend=20to=20n?=
 =?UTF-8?q?ot=20be=20able=20to=20communicate=20with=20the=20backend.=20Thi?=
 =?UTF-8?q?s=20commit=20adds=20the=20missing=20prefix=20to=20all=20three?=
 =?UTF-8?q?=20locations=20to=20fix=20the=20issue.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/frontend/src/controllers/API/index.ts   | 4 ++--
 src/frontend/src/modals/chatModal/index.tsx | 4 ++--
 src/frontend/vite.config.ts                 | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/frontend/src/controllers/API/index.ts b/src/frontend/src/controllers/API/index.ts
index f6f46404b..0cffd04bf 100644
--- a/src/frontend/src/controllers/API/index.ts
+++ b/src/frontend/src/controllers/API/index.ts
@@ -14,13 +14,13 @@ export async function sendAll(data: sendAllProps) {
 
 export async function checkCode(
   code: string
 ): Promise> {
-  return await axios.post("/validate/code", { code });
+  return await axios.post("api/v1/validate/code", { code });
 }
 
 export async function checkPrompt(
   template: string
 ): Promise> {
-  return await axios.post("/validate/prompt", { template });
+  return await axios.post("api/v1/validate/prompt", { template });
 }
 
 export async function getExamples(): Promise {
diff --git a/src/frontend/src/modals/chatModal/index.tsx b/src/frontend/src/modals/chatModal/index.tsx
index cf2b52aac..39bb72994 100644
--- a/src/frontend/src/modals/chatModal/index.tsx
+++ b/src/frontend/src/modals/chatModal/index.tsx
@@ -182,10 +182,10 @@ export default function ChatModal({
     try {
       const urlWs =
         process.env.NODE_ENV === "development"
-          ? `ws://localhost:7860/chat/${id.current}`
+          ? `ws://localhost:7860/api/v1/chat/${id.current}`
           : `${window.location.protocol === "https:" ? "wss" : "ws"}://${
               window.location.host
-            }/chat/${id.current}`;
+            }/api/v1/chat/${id.current}`;
       const newWs = new WebSocket(urlWs);
       newWs.onopen = () => {
         console.log("WebSocket connection established!");
diff --git a/src/frontend/vite.config.ts b/src/frontend/vite.config.ts
index 172b37733..d4fa2248b 100644
--- a/src/frontend/vite.config.ts
+++ b/src/frontend/vite.config.ts
@@ -11,7 +11,7 @@ const apiRoutes = [
 ];
 
 // Use environment variable to determine the target.
-const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860";
+const target = process.env.VITE_PROXY_TARGET || "http://127.0.0.1:7860/api/v1";
 
 const proxyTargets = apiRoutes.reduce((proxyObj, route) => {
   proxyObj[route] = {
From 6b5539232fa98fcb8323aa66289d6414b4fa1dd2 Mon Sep 17 00:00:00 2001
From: Gabriel Luiz Freitas Almeida
Date: Tue, 6 Jun 2023 10:06:40 -0300
Subject: [PATCH 11/12] =?UTF-8?q?=F0=9F=9A=80=20chore(server,=20tests):=20?=
 =?UTF-8?q?update=20API=20endpoint=20URLs=20to=20include=20version=20numbe?=
 =?UTF-8?q?r=20The=20API=20endpoint=20URLs=20have=20been=20updated=20to=20?=
 =?UTF-8?q?include=20the=20version=20number=20to=20improve=20the=20API's?=
 =?UTF-8?q?=20versioning=20and=20maintainability.=20The=20changes=20were?=
 =?UTF-8?q?=20made=20to=20the=20server.ts=20file=20and=20the=20tests=20tha?=
 =?UTF-8?q?t=20use=20the=20API=20endpoints.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

🐛 fix(tests): update API endpoint paths in test files

The API endpoint paths in the test files were outdated and have been
updated to reflect the current API version. This ensures that the tests
are running against the correct endpoints and that the tests are
up-to-date with the current API version.
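The tests below spell out "api/v1" at every call site. A small conftest.py
sketch showing how the prefix could instead be centralized; the v1() helper
and API_PREFIX constant are assumptions for illustration, not part of this
series (create_app is the app factory used earlier in the series):

    # conftest.py sketch -- a minimal illustration, assuming langflow.main
    # still exposes create_app(); the v1() helper is hypothetical.
    import pytest
    from fastapi.testclient import TestClient

    from langflow.main import create_app

    API_PREFIX = "api/v1"

    def v1(path: str) -> str:
        """Prefix a route with the current API version."""
        return f"{API_PREFIX}/{path.lstrip('/')}"

    @pytest.fixture
    def client():
        # fresh app per test so state does not leak between tests
        return TestClient(create_app())

With that in place a test reads client.get(v1("all")), and a future bump to
api/v2 touches one constant instead of dozens of call sites.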
---
 tests/test_agents_template.py  | 10 +++++-----
 tests/test_cache.py            |  6 +++---
 tests/test_chains_template.py  | 16 ++++++++--------
 tests/test_endpoints.py        | 20 ++++++++++----------
 tests/test_graph.py            |  2 +-
 tests/test_llms_template.py    |  8 ++++----
 tests/test_loading.py          |  2 +-
 tests/test_prompts_template.py |  8 ++++----
 tests/test_websocket.py        |  8 ++++----
 9 files changed, 40 insertions(+), 40 deletions(-)

diff --git a/tests/test_agents_template.py b/tests/test_agents_template.py
index 7aa8de176..8e181711f 100644
--- a/tests/test_agents_template.py
+++ b/tests/test_agents_template.py
@@ -5,7 +5,7 @@ from langflow.settings import settings
 # check that all agents are in settings.agents
 # are in json_response["agents"]
 def test_agents_settings(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     agents = json_response["agents"]
@@ -13,7 +13,7 @@
 
 
 def test_zero_shot_agent(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     agents = json_response["agents"]
@@ -52,7 +52,7 @@ def test_zero_shot_agent(client: TestClient):
 
 
 def test_json_agent(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     agents = json_response["agents"]
@@ -87,7 +87,7 @@ def test_json_agent(client: TestClient):
 
 
 def test_csv_agent(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     agents = json_response["agents"]
@@ -126,7 +126,7 @@ def test_csv_agent(client: TestClient):
 
 
 def test_initialize_agent(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     agents = json_response["agents"]
diff --git a/tests/test_cache.py b/tests/test_cache.py
index 3d3e951fc..3214e7d15 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -1,10 +1,10 @@
 import json
 
+from langflow.graph import Graph
+from langflow.processing.process import load_or_build_langchain_object
 import pytest
 
 from langflow.interface.run import (
-    build_graph,
     build_langchain_object_with_caching,
-    load_or_build_langchain_object,
 )
@@ -62,7 +62,7 @@ def test_build_langchain_object_with_caching(basic_data_graph):
 
 # Test build_graph
 def test_build_graph(basic_data_graph):
-    graph = build_graph(basic_data_graph)
+    graph = Graph.from_payload(basic_data_graph)
     assert graph is not None
     assert len(graph.nodes) == len(basic_data_graph["nodes"])
     assert len(graph.edges) == len(basic_data_graph["edges"])
diff --git a/tests/test_chains_template.py b/tests/test_chains_template.py
index c958cf64d..0c7af56ad 100644
--- a/tests/test_chains_template.py
+++ b/tests/test_chains_template.py
@@ -3,7 +3,7 @@ from langflow.settings import settings
 
 
 def test_chains_settings(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]
@@ -12,7 +12,7 @@
 
 # Test the ConversationChain object
 def test_conversation_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]
@@ -94,7 +94,7 @@ def test_conversation_chain(client: TestClient):
 
 
 def test_llm_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]
@@ -152,7 +152,7 @@ def test_llm_chain(client: TestClient):
 
 
 def test_llm_checker_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]
@@ -228,7 +228,7 @@ def test_llm_checker_chain(client: TestClient):
 
 
 def test_llm_math_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]
@@ -306,7 +306,7 @@ def test_llm_math_chain(client: TestClient):
 
 
 def test_series_character_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]
@@ -368,7 +368,7 @@ def test_series_character_chain(client: TestClient):
 
 
 def test_mid_journey_prompt_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]
@@ -407,7 +407,7 @@ def test_mid_journey_prompt_chain(client: TestClient):
 
 
 def test_time_travel_guide_chain(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     chains = json_response["chains"]
diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py
index 83f6c62b1..9e07dfb24 100644
--- a/tests/test_endpoints.py
+++ b/tests/test_endpoints.py
@@ -4,7 +4,7 @@ from langflow.interface.tools.constants import CUSTOM_TOOLS
 
 
 def test_get_all(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     # We need to test the custom nodes
@@ -21,7 +21,7 @@ import math
 
 def square(x):
     return x ** 2
 """
-    response1 = client.post("/validate/code", json={"code": code1})
+    response1 = client.post("api/v1/validate/code", json={"code": code1})
     assert response1.status_code == 200
     assert response1.json() == {"imports": {"errors": []}, "function": {"errors": []}}
 
@@ -32,7 +32,7 @@ import non_existent_module
 
 def square(x):
     return x ** 2
 """
-    response2 = client.post("/validate/code", json={"code": code2})
+    response2 = client.post("api/v1/validate/code", json={"code": code2})
     assert response2.status_code == 200
     assert response2.json() == {
         "imports": {"errors": ["No module named 'non_existent_module'"]},
@@ -46,7 +46,7 @@ import math
 
 def square(x)
     return x ** 2
 """
-    response3 = client.post("/validate/code", json={"code": code3})
+    response3 = client.post("api/v1/validate/code", json={"code": code3})
     assert response3.status_code == 200
     assert response3.json() == {
         "imports": {"errors": []},
@@ -54,11 +54,11 @@ def square(x)
     }
 
     # Test case with invalid JSON payload
-    response4 = client.post("/validate/code", json={"invalid_key": code1})
+    response4 = client.post("api/v1/validate/code", json={"invalid_key": code1})
     assert response4.status_code == 422
 
     # Test case with an empty code string
-    response5 = client.post("/validate/code", json={"code": ""})
+    response5 = client.post("api/v1/validate/code", json={"code": ""})
     assert response5.status_code == 200
     assert response5.json() == {"imports": {"errors": []}, "function": {"errors": []}}
 
@@ -69,7 +69,7 @@ import math
 
 def square(x)
     return x ** 2
 """
-    response6 = client.post("/validate/code", json={"code": code6})
+    response6 = client.post("api/v1/validate/code", json={"code": code6})
     assert response6.status_code == 200
     assert response6.json() == {
         "imports": {"errors": []},
@@ -95,13 +95,13 @@
 INVALID_PROMPT = "This is an invalid prompt without any input variable."
 
 
 def test_valid_prompt(client: TestClient):
-    response = client.post("/validate/prompt", json={"template": VALID_PROMPT})
+    response = client.post("api/v1/validate/prompt", json={"template": VALID_PROMPT})
     assert response.status_code == 200
     assert response.json() == {"input_variables": ["product"]}
 
 
 def test_invalid_prompt(client: TestClient):
-    response = client.post("/validate/prompt", json={"template": INVALID_PROMPT})
+    response = client.post("api/v1/validate/prompt", json={"template": INVALID_PROMPT})
     assert response.status_code == 200
     assert response.json() == {"input_variables": []}
 
@@ -116,7 +116,7 @@ def test_invalid_prompt(client: TestClient):
     ],
 )
 def test_various_prompts(client, prompt, expected_input_variables):
-    response = client.post("/validate/prompt", json={"template": prompt})
+    response = client.post("api/v1/validate/prompt", json={"template": prompt})
     assert response.status_code == 200
     assert response.json() == {
         "input_variables": expected_input_variables,
diff --git a/tests/test_graph.py b/tests/test_graph.py
index 8c6560d54..69a926cc3 100644
--- a/tests/test_graph.py
+++ b/tests/test_graph.py
@@ -16,7 +16,7 @@ from langflow.graph.vertex.types import (
     ToolVertex,
     WrapperVertex,
 )
-from langflow.interface.run import get_result_and_thought
+from langflow.processing.process import get_result_and_thought
 from langflow.utils.payload import get_root_node
 
 # Test cases for the graph module
diff --git a/tests/test_llms_template.py b/tests/test_llms_template.py
index ccf2f6388..da0b94318 100644
--- a/tests/test_llms_template.py
+++ b/tests/test_llms_template.py
@@ -3,7 +3,7 @@ from langflow.settings import settings
 
 
 def test_llms_settings(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     llms = json_response["llms"]
@@ -11,7 +11,7 @@ def test_llms_settings(client: TestClient):
 
 
 # def test_hugging_face_hub(client: TestClient):
-#     response = client.get("/all")
+#     response = client.get("api/v1/all")
 #     assert response.status_code == 200
 #     json_response = response.json()
 #     language_models = json_response["llms"]
@@ -103,7 +103,7 @@ def test_llms_settings(client: TestClient):
 
 
 def test_openai(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     language_models = json_response["llms"]
@@ -333,7 +333,7 @@ def test_openai(client: TestClient):
 
 
 def test_chat_open_ai(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     language_models = json_response["llms"]
diff --git a/tests/test_loading.py b/tests/test_loading.py
index 872314699..885eb7a82 100644
--- a/tests/test_loading.py
+++ b/tests/test_loading.py
@@ -2,7 +2,7 @@ import json
 
 import pytest
 from langchain.chains.base import Chain
-from langflow import load_flow_from_json
+from langflow.processing.process import load_flow_from_json
 from langflow.graph import Graph
 from langflow.utils.payload import get_root_node
diff --git a/tests/test_prompts_template.py b/tests/test_prompts_template.py
index 83da2f14d..a8562898c 100644
--- a/tests/test_prompts_template.py
+++ b/tests/test_prompts_template.py
@@ -3,7 +3,7 @@ from langflow.settings import settings
 
 
 def test_prompts_settings(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     prompts = json_response["prompts"]
@@ -11,7 +11,7 @@ def test_prompts_settings(client: TestClient):
 
 
 def test_prompt_template(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     prompts = json_response["prompts"]
@@ -89,7 +89,7 @@ def test_prompt_template(client: TestClient):
 
 
 def test_few_shot_prompt_template(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     prompts = json_response["prompts"]
@@ -168,7 +168,7 @@ def test_few_shot_prompt_template(client: TestClient):
 
 
 def test_zero_shot_prompt(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     prompts = json_response["prompts"]
diff --git a/tests/test_websocket.py b/tests/test_websocket.py
index 5b60d0fed..611faff79 100644
--- a/tests/test_websocket.py
+++ b/tests/test_websocket.py
@@ -5,17 +5,17 @@ from fastapi.testclient import TestClient
 
 
 def test_websocket_connection(client: TestClient):
-    with client.websocket_connect("/chat/test_client") as websocket:
+    with client.websocket_connect("api/v1/chat/test_client") as websocket:
         assert websocket.scope["client"] == ["testclient", 50000]
-        assert websocket.scope["path"] == "/chat/test_client"
+        assert websocket.scope["path"] == "/api/v1/chat/test_client"
 
 
 def test_chat_history(client: TestClient):
     # Mock the process_graph function to return a specific value
-    with patch("langflow.api.chat_manager.process_graph") as mock_process_graph:
+    with patch("langflow.chat.manager.process_graph") as mock_process_graph:
         mock_process_graph.return_value = ("Hello, I'm a mock response!", "")
-        with client.websocket_connect("/chat/test_client") as websocket:
+        with client.websocket_connect("api/v1/chat/test_client") as websocket:
            # First message should be the history
            history = websocket.receive_json()
            assert history == []  # Empty history

From 1e854fc4694b6404b6f09050be9566f8fd1d6784 Mon Sep 17 00:00:00 2001
From: Gabriel Luiz Freitas Almeida
Date: Tue, 6 Jun 2023 10:06:51 -0300
Subject: [PATCH 12/12] update endpoint

---
 tests/test_vectorstore_template.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_vectorstore_template.py b/tests/test_vectorstore_template.py
index 5b1d7e5bc..0aa823786 100644
--- a/tests/test_vectorstore_template.py
+++ b/tests/test_vectorstore_template.py
@@ -5,7 +5,7 @@ from langflow.settings import settings
 # check that all agents are in settings.agents
 # are in json_response["agents"]
 def test_vectorstores_settings(client: TestClient):
-    response = client.get("/all")
+    response = client.get("api/v1/all")
     assert response.status_code == 200
     json_response = response.json()
     vectorstores = json_response["vectorstores"]
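After this series, load_flow_from_json lives in langflow.processing.process.
A quick usage sketch of the relocated loader; it assumes ./flow.json is a
flow exported from the langflow UI and that any credentials the flow needs
(e.g. OPENAI_API_KEY) are already set in the environment:

    from langflow.processing.process import load_flow_from_json

    # build=True (the default) builds the graph and returns the LangChain
    # object; build=False would return the Graph itself for inspection.
    flow = load_flow_from_json("flow.json")

    # The built object behaves like any LangChain chain or agent; depending
    # on the flow, the input may need to be a dict keyed by input variable.
    print(flow("Hello, world!"))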