refactor: Update TextOperatorComponent to use langflow.template and langflow.schema

The TextOperatorComponent in TextOperator.py has been refactored to use the langflow.template.Input, langflow.template.Output, and langflow.schema.Record classes for improved code structure and maintainability.

Note: The commit message has been generated based on the provided code changes and recent commits.
This commit is contained in:
ogabrielluiz 2024-06-04 12:56:10 -03:00
commit dd4b617b48
2 changed files with 62 additions and 47 deletions

View file

@ -1,50 +1,64 @@
from typing import Optional, Union
from typing import Union
from langflow.custom import CustomComponent
from langflow.custom import Component
from langflow.field_typing import Text
from langflow.schema import Record
from langflow.template import Input, Output
class TextOperatorComponent(CustomComponent):
class TextOperatorComponent(Component):
display_name = "Text Operator"
description = "Compares two text inputs based on a specified condition such as equality or inequality, with optional case sensitivity."
def build_config(self) -> dict:
return {
"input_text": {
"display_name": "Input Text",
"info": "The primary text input for the operation.",
},
"match_text": {
"display_name": "Match Text",
"info": "The text input to compare against.",
},
"operator": {
"display_name": "Operator",
"info": "The operator to apply for comparing the texts.",
"options": ["equals", "not equals", "contains", "starts with", "ends with", "exists"],
},
"case_sensitive": {
"display_name": "Case Sensitive",
"info": "If true, the comparison will be case sensitive.",
"field_type": "bool",
"default": False,
},
"true_output": {
"display_name": "Output",
"info": "The output to return or display when the comparison is true.",
"input_types": ["Text", "Record"], # Allow both text and record types
},
}
inputs = [
Input(name="input_text", type=str, display_name="Input Text", info="The primary text input for the operation."),
Input(name="match_text", type=str, display_name="Match Text", info="The text input to compare against."),
Input(
name="operator",
type=str,
display_name="Operator",
info="The operator to apply for comparing the texts.",
options=["equals", "not equals", "contains", "starts with", "ends with", "exists"],
),
Input(
name="case_sensitive",
type=bool,
display_name="Case Sensitive",
info="If true, the comparison will be case sensitive.",
default=False,
),
Input(
name="true_output",
type=Union[str, Record],
display_name="True Output",
info="The output to return or display when the comparison is true.",
input_types=["Text", "Record"],
),
Input(
name="false_output",
type=Union[str, Record],
display_name="False Output",
info="The output to return or display when the comparison is false.",
input_types=["Text", "Record"],
),
]
outputs = [
Output(name="True Result", method="result_response"),
Output(name="False Result", method="result_response"),
]
def true_response(self) -> Union[Text, Record]:
return self.true_output if self.true_output else self.input_text
def false_response(self) -> Union[Text, Record]:
return self.false_output if self.false_output else self.input_text
def result_response(self) -> Union[Text, Record]:
input_text = self.input_text
match_text = self.match_text
operator = self.operator
case_sensitive = self.case_sensitive
def build(
self,
input_text: Text,
match_text: Text,
operator: Text,
case_sensitive: bool = False,
true_output: Optional[Text] = "",
) -> Union[Text, Record]:
if not input_text or not match_text:
raise ValueError("Both 'input_text' and 'match_text' must be provided and non-empty.")
@ -64,13 +78,9 @@ class TextOperatorComponent(CustomComponent):
elif operator == "ends with":
result = input_text.endswith(match_text)
output_record = true_output if true_output else input_text
if result:
self.status = output_record
return output_record
self.status = self.true_response()
return self.true_response()
else:
self.status = "Comparison failed, stopping execution."
self.stop()
return output_record
self.status = self.false_response()
return self.false_response()

View file

@ -4,7 +4,6 @@ from uuid import UUID, uuid4
import pytest
from fastapi import status
from fastapi.testclient import TestClient
from langflow.custom.directory_reader.directory_reader import DirectoryReader
from langflow.services.deps import get_settings_service
@ -638,6 +637,7 @@ def test_successful_run_with_input_type_any(client, starter_project, created_api
), any_input_outputs
@pytest.mark.api_key_required
def test_run_with_inputs_and_outputs(client, starter_project, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
flow_id = starter_project["id"]
@ -665,6 +665,7 @@ def test_invalid_flow_id(client, created_api_key):
# Check if the error detail is as expected
@pytest.mark.api_key_required
def test_run_flow_with_caching_success(client: TestClient, starter_project, created_api_key):
flow_id = starter_project["id"]
headers = {"x-api-key": created_api_key.api_key}
@ -682,6 +683,7 @@ def test_run_flow_with_caching_success(client: TestClient, starter_project, crea
assert "session_id" in data
@pytest.mark.api_key_required
def test_run_flow_with_caching_invalid_flow_id(client: TestClient, created_api_key):
invalid_flow_id = uuid4()
headers = {"x-api-key": created_api_key.api_key}
@ -693,6 +695,7 @@ def test_run_flow_with_caching_invalid_flow_id(client: TestClient, created_api_k
assert f"Flow identifier {invalid_flow_id} not found" in data["detail"]
@pytest.mark.api_key_required
def test_run_flow_with_caching_invalid_input_format(client: TestClient, starter_project, created_api_key):
flow_id = starter_project["id"]
headers = {"x-api-key": created_api_key.api_key}
@ -701,6 +704,7 @@ def test_run_flow_with_caching_invalid_input_format(client: TestClient, starter_
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
@pytest.mark.api_key_required
def test_run_flow_with_session_id(client, starter_project, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
flow_id = starter_project["id"]
@ -732,6 +736,7 @@ def test_run_flow_with_invalid_session_id(client, starter_project, created_api_k
assert f"Session {payload['session_id']} not found" in data["detail"]
@pytest.mark.api_key_required
def test_run_flow_with_invalid_tweaks(client, starter_project, created_api_key):
headers = {"x-api-key": created_api_key.api_key}
flow_id = starter_project["id"]