add function to determine artifact type
This commit is contained in:
parent
7f6385a609
commit
0dcfc21e05
1 changed files with 23 additions and 0 deletions
|
|
@ -1,6 +1,8 @@
|
|||
from typing import Any, Union
|
||||
from enum import Enum
|
||||
|
||||
from langchain_core.documents import Document
|
||||
from langflow.schema.schema import Record
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
|
|
@ -14,6 +16,12 @@ class UnbuiltResult:
|
|||
pass
|
||||
|
||||
|
||||
class ArtifactType(str, Enum):
|
||||
TEXT = "text"
|
||||
RECORD = "record"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
def validate_prompt(prompt: str):
|
||||
"""Validate prompt."""
|
||||
if extract_input_variables_from_prompt(prompt):
|
||||
|
|
@ -50,3 +58,18 @@ def serialize_field(value):
|
|||
elif isinstance(value, str):
|
||||
return {"result": value}
|
||||
return value
|
||||
|
||||
|
||||
def get_artifact_type(build_result: Any) -> str:
|
||||
result = None
|
||||
match build_result:
|
||||
case Record():
|
||||
result = ArtifactType.RECORD
|
||||
|
||||
case str():
|
||||
result = ArtifactType.TEXT
|
||||
|
||||
case _:
|
||||
result = ArtifactType.UNKNOWN
|
||||
|
||||
return result.value
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue