From a5f9ec03399abff696866a7143d576b829762ba5 Mon Sep 17 00:00:00 2001 From: italojohnny Date: Fri, 31 May 2024 19:15:23 -0300 Subject: [PATCH] add function to determine artifact type --- src/backend/base/langflow/graph/utils.py | 32 +++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/src/backend/base/langflow/graph/utils.py b/src/backend/base/langflow/graph/utils.py index 83e2177b1..86fbebc4a 100644 --- a/src/backend/base/langflow/graph/utils.py +++ b/src/backend/base/langflow/graph/utils.py @@ -1,6 +1,8 @@ -from typing import Any, Union +from typing import Any, Union, Generator +from enum import Enum from langchain_core.documents import Document +from langflow.schema.schema import Record from pydantic import BaseModel from langflow.interface.utils import extract_input_variables_from_prompt @@ -14,6 +16,14 @@ class UnbuiltResult: pass +class ArtifactType(str, Enum): + TEXT = "text" + RECORD = "record" + OBJECT = "object" + STREAM = "stream" + UNKNOWN = "unknown" + + def validate_prompt(prompt: str): """Validate prompt.""" if extract_input_variables_from_prompt(prompt): @@ -50,3 +60,23 @@ def serialize_field(value): elif isinstance(value, str): return {"result": value} return value + + +def get_artifact_type(custom_component, build_result) -> str: + result = ArtifactType.UNKNOWN + value = custom_component.repr_value + match value: + case Record(): + result = ArtifactType.RECORD + + case str(): + result = ArtifactType.TEXT + + case dict(): + result = ArtifactType.OBJECT + + if result == ArtifactType.UNKNOWN: + if isinstance(build_result, Generator): + result = ArtifactType.STREAM + + return result.value