enhance the build_logs function

This commit is contained in:
italojohnny 2024-06-18 02:59:57 -03:00
commit a67fbf90f0

View file

@ -1,6 +1,9 @@
from typing import Literal
from typing import Literal, Union, Generator
from enum import Enum
from pydantic import BaseModel
from langflow.schema.message import Message
from langflow.schema import Data
from typing_extensions import TypedDict
INPUT_FIELD_NAME = "input_value"
@ -8,26 +11,88 @@ InputType = Literal["chat", "text", "any"]
OutputType = Literal["chat", "text", "any", "debug"]
class StreamURL(TypedDict):
class LogType(str, Enum):
MESSAGE = "message"
DATA = "data"
STREAM = "stream"
OBJECT = "object"
ARRAY = "array"
TEXT = "text"
UNKNOWN = "unknown"
class StreamURL(BaseModel):
location: str
class Log(TypedDict):
message: str | dict | StreamURL | list
class Log(BaseModel):
message: Union[StreamURL, dict, list, str]
type: str
def build_logs(vertex) -> dict:
logs = {}
for key in vertex.artifacts:
message = vertex.artifacts[key]["raw"]
_type = vertex.artifacts[key]["type"]
def get_type(payload):
result = LogType.UNKNOWN
match payload:
case Message():
result = LogType.MESSAGE
if "stream_url" in message and "type" in message:
stream_url = StreamURL(location=message["stream_url"])
log = Log(message=stream_url, type=_type)
elif _type:
log = Log(message=message, type=_type)
case Data():
result = LogType.DATA
logs[key] = [log]
return logs
case dict():
result = LogType.OBJECT
case list():
result = LogType.ARRAY
case str():
result = LogType.TEXT
if result == LogType.UNKNOWN:
if payload and isinstance(payload, Generator):
result = LogType.STREAM
elif isinstance(payload, Message) and isinstance(payload.text, Generator):
result = LogType.STREAM
return result
def get_message(payload):
message = None
if hasattr(payload, "data"):
message = payload.data
elif hasattr(payload, "model_dump"):
message = payload.model_dump()
if message is None and isinstance(payload, (dict, str, Data)):
message = payload.data if isinstance(payload, Data) else payload
return message or payload
def build_logs(vertex, result) -> dict:
logs = dict()
payload = result[0].repr_value
for index, output in enumerate(vertex.outputs):
message = get_message(payload)
_type = get_type(payload)
match _type:
case LogType.STREAM if "stream_url" in message:
message = StreamURL(location=message["stream_url"])
case LogType.STREAM:
message = ""
case LogType.MESSAGE if hasattr(message, "message"):
message = message.message
case LogType.UNKNOWN if message is None:
message = ""
name = output.get("name", f"output_{index}")
logs |= {name: Log(message=message, type=_type).model_dump()}
return {} # logs # TODO