fix stream

This commit is contained in:
italojohnny 2024-05-31 19:27:23 -03:00
commit be03ef094d
7 changed files with 53 additions and 30 deletions

View file

@ -169,7 +169,7 @@ async def build_vertex(
next_runnable_vertices,
top_level_vertices,
result_dict,
log_message,
params,
valid,
_,
vertex,
@ -185,10 +185,10 @@ async def build_vertex(
except Exception as exc:
logger.exception(f"Error building vertex: {exc}")
log_message = format_exception_message(exc)
params = format_exception_message(exc)
valid = False
result_data_response = ResultDataResponse(results={})
artifacts = {}
# If there's an error building the vertex
# we need to clear the cache
await chat_service.clear_cache(flow_id_str)
@ -203,21 +203,21 @@ async def build_vertex(
flow_id=flow_id_str,
vertex_id=vertex_id,
valid=valid,
logs=result_data_response.logs,
params=params,
data=result_data_response,
artifacts=artifacts,
)
timedelta = time.perf_counter() - start_time
duration = format_elapsed_time(timedelta)
result_data_response.duration = duration
result_data_response.timedelta = timedelta
async with chat_service._cache_locks[flow_id] as lock:
vertex.add_build_time(timedelta)
inactivated_vertices = None
inactivated_vertices = list(graph.inactivated_vertices)
graph.reset_inactivated_vertices()
graph.reset_activated_vertices()
await chat_service.set_cache(flow_id=flow_id, data=graph, lock=lock)
vertex.add_build_time(timedelta)
inactivated_vertices = None
inactivated_vertices = list(graph.inactivated_vertices)
graph.reset_inactivated_vertices()
graph.reset_activated_vertices()
await chat_service.set_cache(flow_id_str, graph)
# graph.stop_vertex tells us if the user asked
# to stop the build of the graph at a certain vertex
@ -231,6 +231,7 @@ async def build_vertex(
next_vertices_ids=next_runnable_vertices,
top_level_vertices=top_level_vertices,
valid=valid,
params=params,
id=vertex.id,
data=result_data_response,
)

View file

@ -2,7 +2,6 @@ from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langflow.utils.schemas import ChatOutputResponse
from typing_extensions import TypedDict
from uuid import UUID
@ -252,7 +251,8 @@ class Log(TypedDict):
class ResultDataResponse(BaseModel):
results: Optional[Any] = Field(default_factory=dict)
logs: List[Log | None] = Field(default_factory=list)
messages: List[ChatOutputResponse | None] = Field(default_factory=list)
message: Optional[Any] = Field(default_factory=dict)
artifacts: Optional[Any] = Field(default_factory=dict)
timedelta: Optional[float] = None
duration: Optional[str] = None
@ -263,6 +263,8 @@ class VertexBuildResponse(BaseModel):
next_vertices_ids: Optional[List[str]] = None
top_level_vertices: Optional[List[str]] = None
valid: bool
params: Optional[Any] = Field(default_factory=dict)
"""JSON string of the params."""
data: ResultDataResponse
"""Mapping of vertex ids to result dict containing the param name and result value."""
timestamp: Optional[datetime] = Field(default_factory=lambda: datetime.now(timezone.utc))

View file

@ -496,9 +496,9 @@ class InterfaceVertex(Vertex):
flow_id=self.graph.flow_id,
vertex_id=self.id,
valid=True,
logs=self._built_object_repr(),
params=self._built_object_repr(),
data=self.result,
messages=self.artifacts,
artifacts=self.artifacts,
)
self._validate_built_object()

View file

@ -76,6 +76,7 @@ class MessageModel(BaseModel):
session_id: str
message: str
files: list[str] = []
artifacts: dict
class Config:
from_attributes = True
@ -87,6 +88,12 @@ class MessageModel(BaseModel):
return json.loads(v)
return v
@field_validator("artifacts", mode="before")
def validate_target_args(cls, v):
if isinstance(v, str):
return json.loads(v)
return v
@classmethod
def from_record(cls, record: "Record", flow_id: Optional[str] = None):
# first check if the record has all the required fields
@ -107,6 +114,12 @@ class MessageModel(BaseModel):
class MessageModelResponse(MessageModel):
index: Optional[int] = Field(default=None)
@field_validator("artifacts", mode="before")
def serialize_artifacts(v):
if isinstance(v, str):
return json.loads(v)
return v
@field_validator("index", mode="before")
def validate_id(cls, v):
if isinstance(v, float):
@ -122,15 +135,16 @@ class VertexBuildModel(BaseModel):
id: Optional[str] = Field(default=None, alias="id")
flow_id: str
valid: bool
logs: Any
params: Any
data: dict
artifacts: dict
timestamp: datetime = Field(default_factory=datetime.now)
class Config:
from_attributes = True
populate_by_name = True
@field_serializer("data")
@field_serializer("data", "artifacts")
def serialize_dict(v):
if isinstance(v, dict):
# check if the value of each key is a BaseModel or a list of BaseModels
@ -144,8 +158,8 @@ class VertexBuildModel(BaseModel):
return v.model_dump_json()
return v
@field_validator("logs", mode="before")
def validate_logs(cls, v):
@field_validator("params", mode="before")
def validate_params(cls, v):
if isinstance(v, str):
try:
return json.loads(v)
@ -153,7 +167,7 @@ class VertexBuildModel(BaseModel):
return v
return v
@field_serializer("logs")
@field_serializer("params")
def serialize_params(v):
if isinstance(v, list) and all(isinstance(i, BaseModel) for i in v):
return json.dumps([i.model_dump() for i in v])
@ -165,11 +179,17 @@ class VertexBuildModel(BaseModel):
return json.loads(v)
return v
@field_validator("artifacts", mode="before")
def validate_artifacts(cls, v):
if isinstance(v, str):
return json.loads(v)
elif isinstance(v, BaseModel):
return v.model_dump()
return v
class VertexBuildResponseModel(VertexBuildModel):
messages: list[MessageModel] = []
@field_serializer("data")
@field_serializer("data", "artifacts")
def serialize_dict(v):
return v

View file

@ -69,7 +69,7 @@ class MonitorService(Service):
valid: Optional[bool] = None,
order_by: Optional[str] = "timestamp",
):
query = "SELECT id, flow_id, valid, logs, data, timestamp FROM vertex_builds"
query = "SELECT index,flow_id, valid, params, data, artifacts, timestamp FROM vertex_builds"
conditions = []
if flow_id:
conditions.append(f"flow_id = '{flow_id}'")

View file

@ -146,9 +146,9 @@ async def log_vertex_build(
flow_id: str,
vertex_id: str,
valid: bool,
logs: Any,
params: Any,
data: "ResultDataResponse",
messages: Optional[dict] = None,
artifacts: Optional[dict] = None,
):
try:
monitor_service = get_monitor_service()
@ -157,9 +157,9 @@ async def log_vertex_build(
"flow_id": flow_id,
"id": vertex_id,
"valid": valid,
"logs": logs,
"params": params,
"data": data.model_dump(),
"messages": messages or {},
"artifacts": artifacts or {},
"timestamp": monitor_service.get_timestamp(),
}
monitor_service.add_row(table_name="vertex_builds", data=row)

View file

@ -90,9 +90,9 @@ async def build_vertex(
flow_id=flow_id,
vertex_id=vertex_id,
valid=valid,
logs=params,
params=params,
data=result_dict,
messages=artifacts,
artifacts=artifacts,
)
# Emit the vertex build response