refactor: Add StreamURL and Log types to schema.py and update ChatOutputResponse in utils/schemas.py

This commit is contained in:
ogabrielluiz 2024-06-05 11:39:03 -03:00
commit c8a72aaeca
8 changed files with 48 additions and 26 deletions

View file

@@ -17,12 +17,12 @@ from langflow.api.utils import (
from langflow.api.v1.schemas import (
FlowDataRequest,
InputValueRequest,
Log,
ResultDataResponse,
StreamData,
VertexBuildResponse,
VerticesOrderResponse,
)
from langflow.schema.schema import Log
from langflow.services.auth.utils import get_current_active_user
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service, get_session, get_session_service
@@ -161,6 +161,7 @@ async def build_vertex(
else:
graph = cache.get("result")
vertex = graph.get_vertex(vertex_id)
log_object = None
try:
lock = chat_service._cache_locks[flow_id_str]
(
@@ -179,6 +180,7 @@ async def build_vertex(
inputs_dict=inputs.model_dump() if inputs else {},
files=files,
)
result_data_response = ResultDataResponse(**result_dict.model_dump())
except Exception as exc:
@@ -187,12 +189,12 @@
log_type = type(exc).__name__
valid = False
result_data_response = ResultDataResponse(results={})
log_object = Log(message=log_message, type=log_type)
# If there's an error building the vertex
# we need to clear the cache
await chat_service.clear_cache(flow_id_str)
log_object = Log(message=log_message, type=log_type)
result_data_response.logs.append(log_object)
# Log the vertex build

View file

@@ -2,8 +2,6 @@ from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from langflow.utils.schemas import ChatOutputResponse
from typing_extensions import TypedDict
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_serializer
@@ -11,11 +9,12 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_serial
from langflow.graph.schema import RunOutputs
from langflow.schema import dotdict
from langflow.schema.graph import Tweaks
from langflow.schema.schema import InputType, OutputType
from langflow.schema.schema import InputType, Log, OutputType
from langflow.services.database.models.api_key.model import ApiKeyRead
from langflow.services.database.models.base import orjson_dumps
from langflow.services.database.models.flow import FlowCreate, FlowRead
from langflow.services.database.models.user import UserRead
from langflow.utils.schemas import ChatOutputResponse
class BuildStatus(Enum):
@@ -245,11 +244,6 @@ class VerticesOrderResponse(BaseModel):
vertices_to_run: List[str]
class Log(TypedDict):
    """A single log entry recorded while building a vertex."""

    # Payload of the entry: structured data (dict) or plain text.
    message: Union[dict, str]
    # Classification of the entry; the build_vertex error path passes the
    # exception class name here.
    type: str
class ResultDataResponse(BaseModel):
    """Serializable result of building a single vertex."""

    # Arbitrary build output; defaults to an empty dict (set to {} on
    # build failure by the caller).
    results: Optional[Any] = Field(default_factory=dict)
    # Log entries collected during the build; entries may be None.
    logs: List[Log | None] = Field(default_factory=list)

View file

@@ -17,7 +17,6 @@ from langflow.graph.vertex.base import Vertex
from langflow.graph.vertex.types import InterfaceVertex, StateVertex
from langflow.schema import Record
from langflow.schema.schema import INPUT_FIELD_NAME, InputType
from langflow.services.cache.utils import CacheMiss
from langflow.services.chat.service import ChatService
from langflow.services.deps import get_chat_service
from langflow.services.monitor.utils import log_transaction
@@ -734,7 +733,9 @@ class Graph:
vertex = self.get_vertex(vertex_id)
try:
if not vertex.frozen or not vertex._built:
await vertex.build(user_id=user_id, inputs=inputs_dict,files=files, fallback_to_env_vars=fallback_to_env_vars)
await vertex.build(
user_id=user_id, inputs=inputs_dict, files=files, fallback_to_env_vars=fallback_to_env_vars
)
if vertex.result is not None:
params = vertex.artifacts_raw

View file

@@ -1,15 +1,17 @@
from enum import Enum
from typing import Any, List, Optional
from pydantic import BaseModel, Field, field_serializer
from pydantic import BaseModel, Field, field_serializer, model_validator
from langflow.graph.utils import serialize_field
from langflow.schema.schema import Log, StreamURL
from langflow.utils.schemas import ChatOutputResponse, ContainsEnumMeta
class ResultData(BaseModel):
    """Outputs, artifacts, and logs produced by building one vertex."""

    # Raw build outputs; structure depends on the component — TODO confirm.
    results: Optional[Any] = Field(default_factory=dict)
    # Display artifacts; when this is a dict with a "type" key, a validator
    # in this class derives ``logs`` from it.
    artifacts: Optional[Any] = Field(default_factory=dict)
    # Log entries for the build, each serialized as a dict.
    logs: Optional[List[dict]] = Field(default_factory=list)
    # Chat messages extracted from the build, if any.
    messages: Optional[list[ChatOutputResponse]] = Field(default_factory=list)
    # Build duration in seconds — presumably wall-clock; verify against caller.
    timedelta: Optional[float] = None
    # Human-readable duration string.
    duration: Optional[str] = None
@@ -23,6 +25,19 @@ class ResultData(BaseModel):
return {key: serialize_field(val) for key, val in value.items()}
return serialize_field(value)
@model_validator(mode="before")
@classmethod
def validate_model(cls, values):
    """Backfill ``logs`` from ``artifacts`` when no logs were provided."""
    artifacts = values.get("artifacts")
    if artifacts and not values.get("logs"):
        # Synthesize a single log entry out of the artifacts payload;
        # stream artifacts are wrapped in a StreamURL object first.
        if "stream_url" in artifacts:
            payload = StreamURL(location=artifacts["stream_url"])
        else:
            payload = artifacts
        values["logs"] = [Log(message=payload, type=artifacts["type"])]
    return values
class InterfaceComponentTypes(str, Enum, metaclass=ContainsEnumMeta):
# ChatInput and ChatOutput are the only ones that are

View file

@@ -1,11 +1,11 @@
from typing import Any, Union, Generator
from enum import Enum
from typing import Any, Generator, Union
from langchain_core.documents import Document
from langflow.schema.schema import Record
from pydantic import BaseModel
from langflow.interface.utils import extract_input_variables_from_prompt
from langflow.schema.schema import Record
class UnbuiltObject:
@@ -79,9 +79,15 @@ def get_artifact_type(custom_component, build_result) -> str:
case list():
result = ArtifactType.ARRAY
if result == ArtifactType.UNKNOWN:
if isinstance(build_result, Generator):
result = ArtifactType.STREAM
return result.value
def post_process_raw(raw, artifact_type: str):
    """Normalize a raw build result for serialization.

    Stream artifacts cannot be serialized directly, so their raw value is
    replaced with an empty string; every other type passes through unchanged.
    """
    if artifact_type != ArtifactType.STREAM.value:
        return raw
    return ""

View file

@@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, Iterator,
from loguru import logger
from langflow.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData
from langflow.graph.utils import UnbuiltObject, UnbuiltResult, ArtifactType
from langflow.graph.utils import ArtifactType, UnbuiltObject, UnbuiltResult
from langflow.graph.vertex.utils import log_transaction
from langflow.interface.initialize import loading
from langflow.interface.listing import lazy_load_dict
@@ -428,8 +428,10 @@ class Vertex:
sender=artifacts.get("sender"),
sender_name=artifacts.get("sender_name"),
session_id=artifacts.get("session_id"),
stream_url=artifacts.get("stream_url"),
files=[{"path": file} if isinstance(file, str) else file for file in artifacts.get("files", [])],
component_id=self.id,
type=self.artifacts_type,
).model_dump(exclude_none=True)
]
except KeyError:
@@ -447,7 +449,6 @@ class Vertex:
messages = self.extract_messages_from_artifacts(artifacts)
else:
messages = []
result_dict = ResultData(
results=result_dict,
artifacts=artifacts,

View file

@@ -2,11 +2,11 @@ import json
from typing import AsyncIterator, Dict, Iterator, List
import yaml
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, AIMessageChunk
from loguru import logger
from langflow.graph.schema import CHAT_COMPONENTS, RECORDS_COMPONENTS, InterfaceComponentTypes
from langflow.graph.utils import UnbuiltObject, serialize_field
from langflow.graph.utils import ArtifactType, UnbuiltObject, serialize_field
from langflow.graph.vertex.base import Vertex
from langflow.schema import Record
from langflow.schema.schema import INPUT_FIELD_NAME
@@ -87,7 +87,7 @@ class InterfaceVertex(Vertex):
if isinstance(message, str):
message = unescape_string(message)
stream_url = None
if isinstance(self._built_object, AIMessage):
if isinstance(self._built_object, (AIMessage, AIMessageChunk)):
artifacts = ChatOutputResponse.from_message(
self._built_object,
sender=sender,
@@ -109,13 +109,14 @@
# it means that it is a stream of messages
else:
message = self._built_object
artifact_type = ArtifactType.STREAM if stream_url is not None else ArtifactType.OBJECT
artifacts = ChatOutputResponse(
message=message,
sender=sender,
sender_name=sender_name,
stream_url=stream_url,
files=files
files=files,
type=artifact_type.value,
)
self.will_stream = stream_url is not None
@@ -198,6 +199,7 @@
sender=self.params.get("sender", ""),
sender_name=self.params.get("sender_name", ""),
files=[{"path": file} if isinstance(file, str) else file for file in self.params.get("files", [])],
type=ArtifactType.OBJECT.value,
).model_dump()
self.params[INPUT_FIELD_NAME] = complete_message
self._built_object = Record(text=complete_message, data=self.artifacts)

View file

@@ -7,9 +7,8 @@ import orjson
from loguru import logger
from langflow.custom.eval import eval_custom_component_code
from langflow.graph.utils import get_artifact_type, post_process_raw
from langflow.schema.schema import Record
from langflow.graph.utils import get_artifact_type
if TYPE_CHECKING:
from langflow.custom import CustomComponent
@@ -134,5 +133,7 @@ async def instantiate_custom_component(params, user_id, vertex, fallback_to_env_
elif hasattr(raw, "model_dump"):
raw = raw.model_dump()
artifact = {"repr": custom_repr, "raw": raw, "type": get_artifact_type(custom_component, build_result)}
artifact_type = get_artifact_type(custom_component, build_result)
raw = post_process_raw(raw, artifact_type)
artifact = {"repr": custom_repr, "raw": raw, "type": artifact_type}
return custom_component, build_result, artifact