Refactor model_to_sql_column_definitions function and add log_vertex_build function

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-01-30 23:06:02 -03:00
commit 84f810f28a

View file

@ -1,10 +1,15 @@
from typing import Any, Dict, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, Optional, Type
import duckdb
from langflow.services.deps import get_monitor_service
from loguru import logger
from pydantic import BaseModel
from langflow.services.deps import get_monitor_service
if TYPE_CHECKING:
from langflow.api.v1.schemas import ResultDict
INDEX_KEY = "index"
def get_table_schema_as_dict(conn: duckdb.DuckDBPyConnection, table_name: str) -> dict:
@ -14,8 +19,12 @@ def get_table_schema_as_dict(conn: duckdb.DuckDBPyConnection, table_name: str) -
def model_to_sql_column_definitions(model: Type[BaseModel]) -> dict:
columns = {}
for field_name, field_type in model.__fields__.items():
field_info = field_type.type_
for field_name, field_type in model.model_fields.items():
if hasattr(field_type.annotation, "__args__"):
field_args = field_type.annotation.__args__
else:
field_args = []
field_info = field_args[0] if field_args else field_type.annotation
if field_info.__name__ == "int":
sql_type = "INTEGER"
elif field_info.__name__ == "str":
@ -26,6 +35,8 @@ def model_to_sql_column_definitions(model: Type[BaseModel]) -> dict:
sql_type = "BOOLEAN"
elif field_info.__name__ == "dict":
sql_type = "JSON"
elif field_info.__name__ == "Any":
sql_type = "VARCHAR"
else:
continue # Skip types we don't handle
columns[field_name] = sql_type
@ -52,7 +63,7 @@ def drop_and_create_table_if_schema_mismatch(db_path: str, table_name: str, mode
conn.execute(f"CREATE SEQUENCE seq_{table_name} START 1;")
except duckdb.CatalogException:
pass
desired_schema["id"] = f"INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_{table_name}')"
desired_schema[INDEX_KEY] = f"INTEGER PRIMARY KEY DEFAULT NEXTVAL('seq_{table_name}')"
columns_sql = ", ".join(f"{name} {data_type}" for name, data_type in desired_schema.items())
create_table_sql = f"CREATE TABLE {table_name} ({columns_sql})"
conn.execute(create_table_sql)
@ -69,7 +80,7 @@ def add_row_to_table(
# Extract data for the insert statement
validated_dict = validated_data.model_dump(exclude_unset=True)
keys = [key for key in validated_dict.keys() if key != "id"]
keys = [key for key in validated_dict.keys() if key != INDEX_KEY]
columns = ", ".join(keys)
values_placeholders = ", ".join(["?" for _ in keys])
@ -107,3 +118,27 @@ async def log_message(
monitor_service.add_row(table_name="messages", data=row)
except Exception as e:
logger.error(f"Error logging message: {e}")
async def log_vertex_build(
flow_id: str,
vertex_id: str,
valid: bool,
params: Any,
data: "ResultDict",
artifacts: Optional[dict] = None,
):
try:
monitor_service = get_monitor_service()
row = {
"flow_id": flow_id,
"id": vertex_id,
"valid": valid,
"params": params,
"data": data.model_dump(),
"artifacts": artifacts or {},
# "timestamp": monitor_service.get_timestamp(),
}
monitor_service.add_row(table_name="vertex_builds", data=row)
except Exception as e:
logger.error(f"Error logging vertex build: {e}")