From c99f2a35bd6ef91c5f52af965cc65c68c8aacbec Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Thu, 28 Nov 2024 17:58:50 -0300 Subject: [PATCH] fix: adds better boolean check for DataFrame and fixes output display (#4933) * feat: Add DataFrameInput to inputs module * feat: add DataFrame support and refactor array processing * feat: add truth value testing for DataFrame class * refactor: remove Python 2 compatibility method from DataFrame class --- src/backend/base/langflow/inputs/__init__.py | 2 ++ src/backend/base/langflow/schema/artifact.py | 22 +++++++++++-------- src/backend/base/langflow/schema/dataframe.py | 7 ++++++ 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/src/backend/base/langflow/inputs/__init__.py b/src/backend/base/langflow/inputs/__init__.py index b659ed3a9..cd4241ae5 100644 --- a/src/backend/base/langflow/inputs/__init__.py +++ b/src/backend/base/langflow/inputs/__init__.py @@ -1,6 +1,7 @@ from .inputs import ( BoolInput, CodeInput, + DataFrameInput, DataInput, DefaultPromptField, DictInput, @@ -51,4 +52,5 @@ __all__ = [ "SliderInput", "StrInput", "TableInput", + "DataFrameInput", ] diff --git a/src/backend/base/langflow/schema/artifact.py b/src/backend/base/langflow/schema/artifact.py index 0e95b041d..64cbb7bb5 100644 --- a/src/backend/base/langflow/schema/artifact.py +++ b/src/backend/base/langflow/schema/artifact.py @@ -6,6 +6,7 @@ from loguru import logger from pydantic import BaseModel from langflow.schema.data import Data +from langflow.schema.dataframe import DataFrame from langflow.schema.encoders import CUSTOM_ENCODERS from langflow.schema.message import Message from langflow.schema.serialize import recursive_serialize_or_str @@ -40,9 +41,8 @@ def get_artifact_type(value, build_result=None) -> str: case dict(): result = ArtifactType.OBJECT - case list(): + case list() | DataFrame(): result = ArtifactType.ARRAY - if result == ArtifactType.UNKNOWN and ( (build_result and isinstance(build_result, Generator)) or (isinstance(value, Message) and isinstance(value.text, Generator)) @@ -52,17 +52,21 @@ def get_artifact_type(value, build_result=None) -> str: return result.value +def _to_list_of_dicts(raw): + _raw = [] + for item in raw: + if hasattr(item, "dict") or hasattr(item, "model_dump"): + _raw.append(recursive_serialize_or_str(item)) + else: + _raw.append(str(item)) + return _raw + + def post_process_raw(raw, artifact_type: str): if artifact_type == ArtifactType.STREAM.value: raw = "" elif artifact_type == ArtifactType.ARRAY.value: - _raw = [] - for item in raw: - if hasattr(item, "dict") or hasattr(item, "model_dump"): - _raw.append(recursive_serialize_or_str(item)) - else: - _raw.append(str(item)) - raw = _raw + raw = raw.to_dict(orient="records") if isinstance(raw, DataFrame) else _to_list_of_dicts(raw) elif artifact_type == ArtifactType.UNKNOWN.value and raw is not None: if isinstance(raw, BaseModel | dict): try: diff --git a/src/backend/base/langflow/schema/dataframe.py b/src/backend/base/langflow/schema/dataframe.py index bd027351e..bc99470b8 100644 --- a/src/backend/base/langflow/schema/dataframe.py +++ b/src/backend/base/langflow/schema/dataframe.py @@ -96,3 +96,10 @@ class DataFrame(pandas_DataFrame): return DataFrame(*args, **kwargs).__finalize__(self) return _c + + def __bool__(self): + """Truth value testing for the DataFrame. + + Returns True if the DataFrame has at least one row, False otherwise. + """ + return not self.empty