add function to determine artifact type
This commit is contained in:
parent
be03ef094d
commit
a5f9ec0339
1 changed files with 31 additions and 1 deletions
|
|
@ -1,6 +1,8 @@
|
|||
from typing import Any, Union
|
||||
from typing import Any, Union, Generator
|
||||
from enum import Enum
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langflow.schema.schema import Record
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
|
|
@ -14,6 +16,14 @@ class UnbuiltResult:
|
|||
pass
|
||||
|
||||
|
||||
class ArtifactType(str, Enum):
|
||||
TEXT = "text"
|
||||
RECORD = "record"
|
||||
OBJECT = "object"
|
||||
STREAM = "stream"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
def validate_prompt(prompt: str):
|
||||
"""Validate prompt."""
|
||||
if extract_input_variables_from_prompt(prompt):
|
||||
|
|
@ -50,3 +60,23 @@ def serialize_field(value):
|
|||
elif isinstance(value, str):
|
||||
return {"result": value}
|
||||
return value
|
||||
|
||||
|
||||
def get_artifact_type(custom_component, build_result) -> str:
|
||||
result = ArtifactType.UNKNOWN
|
||||
value = custom_component.repr_value
|
||||
match value:
|
||||
case Record():
|
||||
result = ArtifactType.RECORD
|
||||
|
||||
case str():
|
||||
result = ArtifactType.TEXT
|
||||
|
||||
case dict():
|
||||
result = ArtifactType.OBJECT
|
||||
|
||||
if result == ArtifactType.UNKNOWN:
|
||||
if isinstance(build_result, Generator):
|
||||
result = ArtifactType.STREAM
|
||||
|
||||
return result.value
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue