From c786e9970dd107b9ab27b3c8b52ea982a1f9eb94 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 26 Feb 2024 23:01:20 -0300 Subject: [PATCH] Format exception message --- src/backend/langflow/api/utils.py | 54 +++++++++++++++++++++++++---- src/backend/langflow/api/v1/chat.py | 29 ++++++++++++---- 2 files changed, 69 insertions(+), 14 deletions(-) diff --git a/src/backend/langflow/api/utils.py b/src/backend/langflow/api/utils.py index bd71e6dcb..703fbd9e3 100644 --- a/src/backend/langflow/api/utils.py +++ b/src/backend/langflow/api/utils.py @@ -22,7 +22,9 @@ API_WORDS = ["api", "key", "token"] def has_api_terms(word: str): - return "api" in word and ("key" in word or ("token" in word and "tokens" not in word)) + return "api" in word and ( + "key" in word or ("token" in word and "tokens" not in word) + ) def remove_api_keys(flow: dict): @@ -32,7 +34,11 @@ def remove_api_keys(flow: dict): node_data = node.get("data").get("node") template = node_data.get("template") for value in template.values(): - if isinstance(value, dict) and has_api_terms(value["name"]) and value.get("password"): + if ( + isinstance(value, dict) + and has_api_terms(value["name"]) + and value.get("password") + ): value["value"] = None return flow @@ -53,7 +59,9 @@ def build_input_keys_response(langchain_object, artifacts): input_keys_response["input_keys"][key] = value # If the object has memory, that memory will have a memory_variables attribute # memory variables should be removed from the input keys - if hasattr(langchain_object, "memory") and hasattr(langchain_object.memory, "memory_variables"): + if hasattr(langchain_object, "memory") and hasattr( + langchain_object.memory, "memory_variables" + ): # Remove memory variables from input keys input_keys_response["input_keys"] = { key: value @@ -63,7 +71,9 @@ def build_input_keys_response(langchain_object, artifacts): # Add memory variables to memory_keys input_keys_response["memory_keys"] = langchain_object.memory.memory_variables - if hasattr(langchain_object, "prompt") and hasattr(langchain_object.prompt, "template"): + if hasattr(langchain_object, "prompt") and hasattr( + langchain_object.prompt, "template" + ): input_keys_response["template"] = langchain_object.prompt.template return input_keys_response @@ -98,7 +108,11 @@ def raw_frontend_data_is_valid(raw_frontend_data): def is_valid_data(frontend_node, raw_frontend_data): """Check if the data is valid for processing.""" - return frontend_node and "template" in frontend_node and raw_frontend_data_is_valid(raw_frontend_data) + return ( + frontend_node + and "template" in frontend_node + and raw_frontend_data_is_valid(raw_frontend_data) + ) def update_template_values(frontend_template, raw_template): @@ -138,7 +152,9 @@ def get_file_path_value(file_path): # If the path is not in the cache dir, return empty string # This is to prevent access to files outside the cache dir # If the path is not a file, return empty string - if not path.exists() or not str(path).startswith(user_cache_dir("langflow", "langflow")): + if not path.exists() or not str(path).startswith( + user_cache_dir("langflow", "langflow") + ): return "" return file_path @@ -169,7 +185,9 @@ async def check_langflow_version(component: StoreComponentCreate): langflow_version = get_lf_version_from_pypi() if langflow_version is None: - raise HTTPException(status_code=500, detail="Unable to verify the latest version of Langflow") + raise HTTPException( + status_code=500, detail="Unable to verify the latest version of Langflow" + ) elif langflow_version != component.last_tested_version: warnings.warn( f"Your version of Langflow ({component.last_tested_version}) is outdated. " @@ -230,3 +248,25 @@ def build_and_cache_graph( graph = graph.update(other_graph) chat_service.set_cache(flow_id, graph) return graph + + +def format_syntax_error_message(exc: SyntaxError) -> str: + """Format a SyntaxError message for returning to the frontend.""" + return f"Syntax error in code. Error on line {exc.lineno}: {exc.text.strip()}" + + +def get_causing_exception(exc: Exception) -> Exception: + """Get the causing exception from an exception.""" + if hasattr(exc, "__cause__") and exc.__cause__: + return get_causing_exception(exc.__cause__) + return exc + + +def format_exception_message(exc: Exception) -> str: + """Format an exception message for returning to the frontend.""" + # We need to check if the __cause__ is a SyntaxError + # If it is, we need to return the message of the SyntaxError + causing_exception = get_causing_exception(exc) + if isinstance(causing_exception, SyntaxError): + return format_syntax_error_message(causing_exception) + return str(exc) diff --git a/src/backend/langflow/api/v1/chat.py b/src/backend/langflow/api/v1/chat.py index d04588cba..1163ee749 100644 --- a/src/backend/langflow/api/v1/chat.py +++ b/src/backend/langflow/api/v1/chat.py @@ -14,7 +14,11 @@ from fastapi.responses import StreamingResponse from loguru import logger from sqlmodel import Session -from langflow.api.utils import build_and_cache_graph, format_elapsed_time +from langflow.api.utils import ( + build_and_cache_graph, + format_elapsed_time, + format_exception_message, +) from langflow.api.v1.schemas import ( ResultData, StreamData, @@ -45,9 +49,13 @@ async def chat( user = await get_current_user_for_websocket(websocket, db) await websocket.accept() if not user: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" + ) elif not user.is_active: - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" + ) if client_id in chat_service.cache_service: await chat_service.handle_websocket(client_id, websocket) @@ -63,7 +71,9 @@ async def chat( logger.error(f"Error in chat websocket: {exc}") messsage = exc.detail if isinstance(exc, HTTPException) else str(exc) if "Could not validate credentials" in str(exc): - await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized") + await websocket.close( + code=status.WS_1008_POLICY_VIOLATION, reason="Unauthorized" + ) else: await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=messsage) @@ -133,8 +143,12 @@ async def build_vertex( cache = chat_service.get_cache(flow_id) if not cache: # If there's no cache - logger.warning(f"No cache found for {flow_id}. Building graph starting at {vertex_id}") - graph = build_and_cache_graph(flow_id=flow_id, session=next(get_session()), chat_service=chat_service) + logger.warning( + f"No cache found for {flow_id}. Building graph starting at {vertex_id}" + ) + graph = build_and_cache_graph( + flow_id=flow_id, session=next(get_session()), chat_service=chat_service + ) else: graph = cache.get("result") result_dict = {} @@ -165,7 +179,8 @@ async def build_vertex( raise ValueError(f"No result found for vertex {vertex_id}") except Exception as exc: - params = str(exc) + # + params = format_exception_message(exc) valid = False result_dict = ResultData(results={}) artifacts = {}