feat: allow passing a Component to the set method (#3597)

* refactor: Add _find_matching_output_method to Component class

* feat: allow components to be passed in set method

* fix: Add test for graph set method with valid component

* set value variable to the output callable

* refactor: Update test_component.py to use set_component method

This commit refactors the test_component.py file in the custom_component directory. The test_set_invalid_input() function has been renamed to test_set_component() to better reflect its purpose. Additionally, the test_set_component() function now sets the agent parameter using the set_component() method instead of raising a ValueError. This change improves the readability and maintainability of the code.

* refactor: Fix formatting issue in _build_error_string_from_matching_pairs

The _build_error_string_from_matching_pairs method in the Component class had a formatting issue when input types were empty. This commit fixes the issue by adding a check for empty input types and providing an empty list as a fallback. This improves the accuracy and readability of the error string generated by the method.

* fix(component.py): add validation to ensure output method is a string to prevent potential runtime errors
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-08-28 16:31:14 -03:00 committed by GitHub
commit 3e19a3fd36
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 50 additions and 6 deletions

View file

@ -310,9 +310,40 @@ class Component(CustomComponent):
)
return method_is_output
def _build_error_string_from_matching_pairs(self, matching_pairs: list[tuple[Output, Input]]):
text = ""
for output, input_ in matching_pairs:
text += f"{output.name}[{','.join(output.types)}]->{input_.name}[{','.join(input_.input_types or [])}]\n"
return text
def _find_matching_output_method(self, value: "Component"):
# get all outputs of the value component
outputs = value.outputs
# check if the any of the types in the output.types matches ONLY one input in the current component
matching_pairs = []
for output in outputs:
for input_ in self.inputs:
for output_type in output.types:
if input_.input_types and output_type in input_.input_types:
matching_pairs.append((output, input_))
if len(matching_pairs) > 1:
matching_pairs_str = self._build_error_string_from_matching_pairs(matching_pairs)
raise ValueError(
f"There are multiple outputs from {value.__class__.__name__} that can connect to inputs in {self.__class__.__name__}: {matching_pairs_str}"
)
output, input_ = matching_pairs[0]
if not isinstance(output.method, str):
raise ValueError(f"Method {output.method} is not a valid output of {value.__class__.__name__}")
return getattr(value, output.method)
def _process_connection_or_parameter(self, key, value):
_input = self._get_or_create_input(key)
# We need to check if callable AND if it is a method from a class that inherits from Component
if isinstance(value, Component):
# We need to find the Output that can connect to an input of the current component
# if there's more than one output that matches, we need to raise an error
# because we don't know which one to connect to
value = self._find_matching_output_method(value)
if callable(value) and self._inherits_from_component(value):
try:
self._method_is_valid_output(value)

View file

@ -18,11 +18,9 @@ def test_set_invalid_output():
chatoutput.set(input_value=chatinput.build_config)
def test_set_invalid_input():
def test_set_component():
crewai_agent = CrewAIAgentComponent()
task = SequentialTaskComponent()
with pytest.raises(
ValueError,
match="You set CrewAI Agent as value for `agent`. You should pass one of the following: 'build_output'",
):
task.set(agent=crewai_agent)
task.set(agent=crewai_agent)
assert task._edges[0]["source"] == crewai_agent._id
assert crewai_agent in task._components

View file

@ -2,9 +2,11 @@ from collections import deque
import pytest
from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
from langflow.components.inputs.ChatInput import ChatInput
from langflow.components.outputs.ChatOutput import ChatOutput
from langflow.components.outputs.TextOutput import TextOutputComponent
from langflow.components.tools.YfinanceTool import YfinanceToolComponent
from langflow.graph.graph.base import Graph
from langflow.graph.graph.constants import Finish
@ -139,3 +141,16 @@ def test_graph_functional_start_end():
assert len(results) == len(ids) + 1
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
assert results[-1] == Finish()
def test_graph_set_with_invalid_component():
chat_input = ChatInput(_id="chat_input")
chat_output = ChatOutput(input_value="test", _id="chat_output")
with pytest.raises(ValueError, match="There are multiple outputs"):
chat_output.set(sender_name=chat_input)
def test_graph_set_with_valid_component():
tool = YfinanceToolComponent()
tool_calling_agent = ToolCallingAgentComponent()
tool_calling_agent.set(tools=[tool])