From b7e52f62be4d6852a4ab72b38de12e7acfe63c6a Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Tue, 27 Feb 2024 11:36:58 -0300 Subject: [PATCH] Refactor API schemas and update dependencies --- src/backend/langflow/api/v1/schemas.py | 18 +++++++----------- .../langflow/services/monitor/utils.py | 19 +++++++++++++------ src/backend/langflow/services/socket/utils.py | 13 ++++++++----- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index 23d8ddf9b..adb26202a 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -4,12 +4,12 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union from uuid import UUID -from langflow.api.utils import serialize_field +from pydantic import BaseModel, Field, field_validator + from langflow.services.database.models.api_key.model import ApiKeyRead from langflow.services.database.models.base import orjson_dumps from langflow.services.database.models.flow import FlowCreate, FlowRead from langflow.services.database.models.user import UserRead -from pydantic import BaseModel, Field, field_serializer, field_validator class BuildStatus(Enum): @@ -161,7 +161,9 @@ class StreamData(BaseModel): data: dict def __str__(self) -> str: - return f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n" + return ( + f"event: {self.event}\ndata: {orjson_dumps(self.data, indent_2=False)}\n\n" + ) class CustomComponentCode(BaseModel): @@ -220,18 +222,12 @@ class VerticesOrderResponse(BaseModel): ids: List[List[str]] -class ResultData(BaseModel): +class ResultDataResponse(BaseModel): results: Optional[Any] = Field(default_factory=dict) artifacts: Optional[Any] = Field(default_factory=dict) timedelta: Optional[float] = None duration: Optional[str] = None - @field_serializer("results") - def serialize_results(self, value): - if isinstance(value, dict): - return {key: serialize_field(val) for key, val in value.items()} - return serialize_field(value) - class VertexBuildResponse(BaseModel): id: Optional[str] = None @@ -239,7 +235,7 @@ class VertexBuildResponse(BaseModel): valid: bool params: Optional[str] """JSON string of the params.""" - data: ResultData + data: ResultDataResponse """Mapping of vertex ids to result dict containing the param name and result value.""" timestamp: Optional[datetime] = Field(default_factory=datetime.utcnow) """Timestamp of the build.""" diff --git a/src/backend/langflow/services/monitor/utils.py b/src/backend/langflow/services/monitor/utils.py index 87f58aa9f..d308e7653 100644 --- a/src/backend/langflow/services/monitor/utils.py +++ b/src/backend/langflow/services/monitor/utils.py @@ -1,12 +1,13 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union import duckdb -from langflow.services.deps import get_monitor_service from loguru import logger from pydantic import BaseModel +from langflow.services.deps import get_monitor_service + if TYPE_CHECKING: - from langflow.api.v1.schemas import ResultData + from langflow.api.v1.schemas import ResultDataResponse INDEX_KEY = "index" @@ -45,7 +46,9 @@ def model_to_sql_column_definitions(model: Type[BaseModel]) -> dict: return columns -def drop_and_create_table_if_schema_mismatch(db_path: str, table_name: str, model: Type[BaseModel]): +def drop_and_create_table_if_schema_mismatch( + db_path: str, table_name: str, model: Type[BaseModel] +): with duckdb.connect(db_path) as conn: # Get the current schema from the database try: @@ -66,8 +69,12 @@ def drop_and_create_table_if_schema_mismatch(db_path: str, table_name: str, mode conn.execute(f"CREATE SEQUENCE seq_{table_name} START 1;") except duckdb.CatalogException: pass - desired_schema[INDEX_KEY] = f"INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_{table_name}')" - columns_sql = ", ".join(f"{name} {data_type}" for name, data_type in desired_schema.items()) + desired_schema[INDEX_KEY] = ( + f"INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_{table_name}')" + ) + columns_sql = ", ".join( + f"{name} {data_type}" for name, data_type in desired_schema.items() + ) create_table_sql = f"CREATE TABLE {table_name} ({columns_sql})" conn.execute(create_table_sql) @@ -138,7 +145,7 @@ async def log_vertex_build( vertex_id: str, valid: bool, params: Any, - data: "ResultData", + data: "ResultDataResponse", artifacts: Optional[dict] = None, ): try: diff --git a/src/backend/langflow/services/socket/utils.py b/src/backend/langflow/services/socket/utils.py index 64ffdc15c..48208403a 100644 --- a/src/backend/langflow/services/socket/utils.py +++ b/src/backend/langflow/services/socket/utils.py @@ -2,14 +2,15 @@ import time from typing import Callable import socketio +from sqlmodel import select + from langflow.api.utils import format_elapsed_time -from langflow.api.v1.schemas import ResultData, VertexBuildResponse +from langflow.api.v1.schemas import ResultDataResponse, VertexBuildResponse from langflow.graph.graph.base import Graph from langflow.graph.vertex.base import StatelessVertex from langflow.services.database.models.flow.model import Flow from langflow.services.deps import get_session from langflow.services.monitor.utils import log_vertex_build -from sqlmodel import select def set_socketio_server(socketio_server): @@ -73,7 +74,7 @@ async def build_vertex( artifacts = vertex.artifacts timedelta = time.perf_counter() - start_time duration = format_elapsed_time(timedelta) - result_dict = ResultData( + result_dict = ResultDataResponse( results=result_dict, artifacts=artifacts, duration=duration, @@ -82,7 +83,7 @@ async def build_vertex( except Exception as exc: params = str(exc) valid = False - result_dict = ResultData(results={}) + result_dict = ResultDataResponse(results={}) artifacts = {} set_cache(flow_id, graph) await log_vertex_build( @@ -95,7 +96,9 @@ async def build_vertex( ) # Emit the vertex build response - response = VertexBuildResponse(valid=valid, params=params, id=vertex.id, data=result_dict) + response = VertexBuildResponse( + valid=valid, params=params, id=vertex.id, data=result_dict + ) await sio.emit("vertex_build", data=response.model_dump(), to=sid) except Exception as exc: