From c8a72aaeca9248c33c55608fddb0516b7e46d442 Mon Sep 17 00:00:00 2001 From: ogabrielluiz Date: Wed, 5 Jun 2024 11:39:03 -0300 Subject: [PATCH] refactor: Add StreamURL and Log types to schema.py and update ChatOutputResponse in utils/schemas.py --- src/backend/base/langflow/api/v1/chat.py | 6 ++++-- src/backend/base/langflow/api/v1/schemas.py | 10 ++-------- src/backend/base/langflow/graph/graph/base.py | 5 +++-- src/backend/base/langflow/graph/schema.py | 17 ++++++++++++++++- src/backend/base/langflow/graph/utils.py | 12 +++++++++--- src/backend/base/langflow/graph/vertex/base.py | 5 +++-- src/backend/base/langflow/graph/vertex/types.py | 12 +++++++----- .../langflow/interface/initialize/loading.py | 7 ++++--- 8 files changed, 48 insertions(+), 26 deletions(-) diff --git a/src/backend/base/langflow/api/v1/chat.py b/src/backend/base/langflow/api/v1/chat.py index 9ee3d8f4f..6e2a4dd35 100644 --- a/src/backend/base/langflow/api/v1/chat.py +++ b/src/backend/base/langflow/api/v1/chat.py @@ -17,12 +17,12 @@ from langflow.api.utils import ( from langflow.api.v1.schemas import ( FlowDataRequest, InputValueRequest, - Log, ResultDataResponse, StreamData, VertexBuildResponse, VerticesOrderResponse, ) +from langflow.schema.schema import Log from langflow.services.auth.utils import get_current_active_user from langflow.services.chat.service import ChatService from langflow.services.deps import get_chat_service, get_session, get_session_service @@ -161,6 +161,7 @@ async def build_vertex( else: graph = cache.get("result") vertex = graph.get_vertex(vertex_id) + log_object = None try: lock = chat_service._cache_locks[flow_id_str] ( @@ -179,6 +180,7 @@ async def build_vertex( inputs_dict=inputs.model_dump() if inputs else {}, files=files, ) + result_data_response = ResultDataResponse(**result_dict.model_dump()) except Exception as exc: @@ -187,12 +189,12 @@ async def build_vertex( log_type = type(exc).__name__ valid = False result_data_response = ResultDataResponse(results={}) + log_object = Log(message=log_message, type=log_type) # If there's an error building the vertex # we need to clear the cache await chat_service.clear_cache(flow_id_str) - log_object = Log(message=log_message, type=log_type) result_data_response.logs.append(log_object) # Log the vertex build diff --git a/src/backend/base/langflow/api/v1/schemas.py b/src/backend/base/langflow/api/v1/schemas.py index d2f95b763..4e8915842 100644 --- a/src/backend/base/langflow/api/v1/schemas.py +++ b/src/backend/base/langflow/api/v1/schemas.py @@ -2,8 +2,6 @@ from datetime import datetime, timezone from enum import Enum from pathlib import Path from typing import Any, Dict, List, Optional, Union -from langflow.utils.schemas import ChatOutputResponse -from typing_extensions import TypedDict from uuid import UUID from pydantic import BaseModel, ConfigDict, Field, field_validator, model_serializer @@ -11,11 +9,12 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator, model_serial from langflow.graph.schema import RunOutputs from langflow.schema import dotdict from langflow.schema.graph import Tweaks -from langflow.schema.schema import InputType, OutputType +from langflow.schema.schema import InputType, Log, OutputType from langflow.services.database.models.api_key.model import ApiKeyRead from langflow.services.database.models.base import orjson_dumps from langflow.services.database.models.flow import FlowCreate, FlowRead from langflow.services.database.models.user import UserRead +from langflow.utils.schemas import ChatOutputResponse class BuildStatus(Enum): @@ -245,11 +244,6 @@ class VerticesOrderResponse(BaseModel): vertices_to_run: List[str] -class Log(TypedDict): - message: Union[dict, str] - type: str - - class ResultDataResponse(BaseModel): results: Optional[Any] = Field(default_factory=dict) logs: List[Log | None] = Field(default_factory=list) diff --git a/src/backend/base/langflow/graph/graph/base.py b/src/backend/base/langflow/graph/graph/base.py index bbc8417de..9639f8745 100644 --- a/src/backend/base/langflow/graph/graph/base.py +++ b/src/backend/base/langflow/graph/graph/base.py @@ -17,7 +17,6 @@ from langflow.graph.vertex.base import Vertex from langflow.graph.vertex.types import InterfaceVertex, StateVertex from langflow.schema import Record from langflow.schema.schema import INPUT_FIELD_NAME, InputType -from langflow.services.cache.utils import CacheMiss from langflow.services.chat.service import ChatService from langflow.services.deps import get_chat_service from langflow.services.monitor.utils import log_transaction @@ -734,7 +733,9 @@ class Graph: vertex = self.get_vertex(vertex_id) try: if not vertex.frozen or not vertex._built: - await vertex.build(user_id=user_id, inputs=inputs_dict,files=files, fallback_to_env_vars=fallback_to_env_vars) + await vertex.build( + user_id=user_id, inputs=inputs_dict, files=files, fallback_to_env_vars=fallback_to_env_vars + ) if vertex.result is not None: params = vertex.artifacts_raw diff --git a/src/backend/base/langflow/graph/schema.py b/src/backend/base/langflow/graph/schema.py index 60e7ab590..82cfa2930 100644 --- a/src/backend/base/langflow/graph/schema.py +++ b/src/backend/base/langflow/graph/schema.py @@ -1,15 +1,17 @@ from enum import Enum from typing import Any, List, Optional -from pydantic import BaseModel, Field, field_serializer +from pydantic import BaseModel, Field, field_serializer, model_validator from langflow.graph.utils import serialize_field +from langflow.schema.schema import Log, StreamURL from langflow.utils.schemas import ChatOutputResponse, ContainsEnumMeta class ResultData(BaseModel): results: Optional[Any] = Field(default_factory=dict) artifacts: Optional[Any] = Field(default_factory=dict) + logs: Optional[List[dict]] = Field(default_factory=list) messages: Optional[list[ChatOutputResponse]] = Field(default_factory=list) timedelta: Optional[float] = None duration: Optional[str] = None @@ -23,6 +25,19 @@ class ResultData(BaseModel): return {key: serialize_field(val) for key, val in value.items()} return serialize_field(value) + @model_validator(mode="before") + @classmethod + def validate_model(cls, values): + if not values.get("logs") and values.get("artifacts"): + # Build the log from the artifacts + message = values["artifacts"] + if "stream_url" in message: + stream_url = StreamURL(location=message["stream_url"]) + values["logs"] = [Log(message=stream_url, type=message["type"])] + else: + values["logs"] = [Log(message=message, type=message["type"])] + return values + class InterfaceComponentTypes(str, Enum, metaclass=ContainsEnumMeta): # ChatInput and ChatOutput are the only ones that are diff --git a/src/backend/base/langflow/graph/utils.py b/src/backend/base/langflow/graph/utils.py index bdb2be5c1..066d7511a 100644 --- a/src/backend/base/langflow/graph/utils.py +++ b/src/backend/base/langflow/graph/utils.py @@ -1,11 +1,11 @@ -from typing import Any, Union, Generator from enum import Enum +from typing import Any, Generator, Union from langchain_core.documents import Document -from langflow.schema.schema import Record from pydantic import BaseModel from langflow.interface.utils import extract_input_variables_from_prompt +from langflow.schema.schema import Record class UnbuiltObject: @@ -79,9 +79,15 @@ def get_artifact_type(custom_component, build_result) -> str: case list(): result = ArtifactType.ARRAY - if result == ArtifactType.UNKNOWN: if isinstance(build_result, Generator): result = ArtifactType.STREAM return result.value + + +def post_process_raw(raw, artifact_type: str): + if artifact_type == ArtifactType.STREAM.value: + raw = "" + + return raw diff --git a/src/backend/base/langflow/graph/vertex/base.py b/src/backend/base/langflow/graph/vertex/base.py index 3cd646d03..be5c85861 100644 --- a/src/backend/base/langflow/graph/vertex/base.py +++ b/src/backend/base/langflow/graph/vertex/base.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Dict, Iterator, from loguru import logger from langflow.graph.schema import INPUT_COMPONENTS, OUTPUT_COMPONENTS, InterfaceComponentTypes, ResultData -from langflow.graph.utils import UnbuiltObject, UnbuiltResult, ArtifactType +from langflow.graph.utils import ArtifactType, UnbuiltObject, UnbuiltResult from langflow.graph.vertex.utils import log_transaction from langflow.interface.initialize import loading from langflow.interface.listing import lazy_load_dict @@ -428,8 +428,10 @@ class Vertex: sender=artifacts.get("sender"), sender_name=artifacts.get("sender_name"), session_id=artifacts.get("session_id"), + stream_url=artifacts.get("stream_url"), files=[{"path": file} if isinstance(file, str) else file for file in artifacts.get("files", [])], component_id=self.id, + type=self.artifacts_type, ).model_dump(exclude_none=True) ] except KeyError: @@ -447,7 +449,6 @@ class Vertex: messages = self.extract_messages_from_artifacts(artifacts) else: messages = [] - result_dict = ResultData( results=result_dict, artifacts=artifacts, diff --git a/src/backend/base/langflow/graph/vertex/types.py b/src/backend/base/langflow/graph/vertex/types.py index 9418dde35..16a2a0f0e 100644 --- a/src/backend/base/langflow/graph/vertex/types.py +++ b/src/backend/base/langflow/graph/vertex/types.py @@ -2,11 +2,11 @@ import json from typing import AsyncIterator, Dict, Iterator, List import yaml -from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessage, AIMessageChunk from loguru import logger from langflow.graph.schema import CHAT_COMPONENTS, RECORDS_COMPONENTS, InterfaceComponentTypes -from langflow.graph.utils import UnbuiltObject, serialize_field +from langflow.graph.utils import ArtifactType, UnbuiltObject, serialize_field from langflow.graph.vertex.base import Vertex from langflow.schema import Record from langflow.schema.schema import INPUT_FIELD_NAME @@ -87,7 +87,7 @@ class InterfaceVertex(Vertex): if isinstance(message, str): message = unescape_string(message) stream_url = None - if isinstance(self._built_object, AIMessage): + if isinstance(self._built_object, (AIMessage, AIMessageChunk)): artifacts = ChatOutputResponse.from_message( self._built_object, sender=sender, @@ -109,13 +109,14 @@ class InterfaceVertex(Vertex): # it means that it is a stream of messages else: message = self._built_object - + artifact_type = ArtifactType.STREAM if stream_url is not None else ArtifactType.OBJECT artifacts = ChatOutputResponse( message=message, sender=sender, sender_name=sender_name, stream_url=stream_url, - files=files + files=files, + type=artifact_type.value, ) self.will_stream = stream_url is not None @@ -198,6 +199,7 @@ class InterfaceVertex(Vertex): sender=self.params.get("sender", ""), sender_name=self.params.get("sender_name", ""), files=[{"path": file} if isinstance(file, str) else file for file in self.params.get("files", [])], + type=ArtifactType.OBJECT.value, ).model_dump() self.params[INPUT_FIELD_NAME] = complete_message self._built_object = Record(text=complete_message, data=self.artifacts) diff --git a/src/backend/base/langflow/interface/initialize/loading.py b/src/backend/base/langflow/interface/initialize/loading.py index f9e360fb1..163587fa0 100644 --- a/src/backend/base/langflow/interface/initialize/loading.py +++ b/src/backend/base/langflow/interface/initialize/loading.py @@ -7,9 +7,8 @@ import orjson from loguru import logger from langflow.custom.eval import eval_custom_component_code +from langflow.graph.utils import get_artifact_type, post_process_raw from langflow.schema.schema import Record -from langflow.graph.utils import get_artifact_type - if TYPE_CHECKING: from langflow.custom import CustomComponent @@ -134,5 +133,7 @@ async def instantiate_custom_component(params, user_id, vertex, fallback_to_env_ elif hasattr(raw, "model_dump"): raw = raw.model_dump() - artifact = {"repr": custom_repr, "raw": raw, "type": get_artifact_type(custom_component, build_result)} + artifact_type = get_artifact_type(custom_component, build_result) + raw = post_process_raw(raw, artifact_type) + artifact = {"repr": custom_repr, "raw": raw, "type": artifact_type} return custom_component, build_result, artifact