feat: allow passing a Component to the set method (#3597)
* refactor: Add _find_matching_output_method to Component class * feat: allow components to be passed in set method * fix: Add test for graph set method with valid component * set value variable to the output callable * refactor: Update test_component.py to use set_component method This commit refactors the test_component.py file in the custom_component directory. The test_set_invalid_input() function has been renamed to test_set_component() to better reflect its purpose. Additionally, the test_set_component() function now sets the agent parameter using the set_component() method instead of raising a ValueError. This change improves the readability and maintainability of the code. * refactor: Fix formatting issue in _build_error_string_from_matching_pairs The _build_error_string_from_matching_pairs method in the Component class had a formatting issue when input types were empty. This commit fixes the issue by adding a check for empty input types and providing an empty list as a fallback. This improves the accuracy and readability of the error string generated by the method. * fix(component.py): add validation to ensure output method is a string to prevent potential runtime errors
This commit is contained in:
parent
1a0bc3b968
commit
3e19a3fd36
3 changed files with 50 additions and 6 deletions
|
|
@ -310,9 +310,40 @@ class Component(CustomComponent):
|
|||
)
|
||||
return method_is_output
|
||||
|
||||
def _build_error_string_from_matching_pairs(self, matching_pairs: list[tuple[Output, Input]]):
|
||||
text = ""
|
||||
for output, input_ in matching_pairs:
|
||||
text += f"{output.name}[{','.join(output.types)}]->{input_.name}[{','.join(input_.input_types or [])}]\n"
|
||||
return text
|
||||
|
||||
def _find_matching_output_method(self, value: "Component"):
|
||||
# get all outputs of the value component
|
||||
outputs = value.outputs
|
||||
# check if the any of the types in the output.types matches ONLY one input in the current component
|
||||
matching_pairs = []
|
||||
for output in outputs:
|
||||
for input_ in self.inputs:
|
||||
for output_type in output.types:
|
||||
if input_.input_types and output_type in input_.input_types:
|
||||
matching_pairs.append((output, input_))
|
||||
if len(matching_pairs) > 1:
|
||||
matching_pairs_str = self._build_error_string_from_matching_pairs(matching_pairs)
|
||||
raise ValueError(
|
||||
f"There are multiple outputs from {value.__class__.__name__} that can connect to inputs in {self.__class__.__name__}: {matching_pairs_str}"
|
||||
)
|
||||
output, input_ = matching_pairs[0]
|
||||
if not isinstance(output.method, str):
|
||||
raise ValueError(f"Method {output.method} is not a valid output of {value.__class__.__name__}")
|
||||
return getattr(value, output.method)
|
||||
|
||||
def _process_connection_or_parameter(self, key, value):
|
||||
_input = self._get_or_create_input(key)
|
||||
# We need to check if callable AND if it is a method from a class that inherits from Component
|
||||
if isinstance(value, Component):
|
||||
# We need to find the Output that can connect to an input of the current component
|
||||
# if there's more than one output that matches, we need to raise an error
|
||||
# because we don't know which one to connect to
|
||||
value = self._find_matching_output_method(value)
|
||||
if callable(value) and self._inherits_from_component(value):
|
||||
try:
|
||||
self._method_is_valid_output(value)
|
||||
|
|
|
|||
|
|
@ -18,11 +18,9 @@ def test_set_invalid_output():
|
|||
chatoutput.set(input_value=chatinput.build_config)
|
||||
|
||||
|
||||
def test_set_invalid_input():
|
||||
def test_set_component():
|
||||
crewai_agent = CrewAIAgentComponent()
|
||||
task = SequentialTaskComponent()
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="You set CrewAI Agent as value for `agent`. You should pass one of the following: 'build_output'",
|
||||
):
|
||||
task.set(agent=crewai_agent)
|
||||
task.set(agent=crewai_agent)
|
||||
assert task._edges[0]["source"] == crewai_agent._id
|
||||
assert crewai_agent in task._components
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@ from collections import deque
|
|||
|
||||
import pytest
|
||||
|
||||
from langflow.components.agents.ToolCallingAgent import ToolCallingAgentComponent
|
||||
from langflow.components.inputs.ChatInput import ChatInput
|
||||
from langflow.components.outputs.ChatOutput import ChatOutput
|
||||
from langflow.components.outputs.TextOutput import TextOutputComponent
|
||||
from langflow.components.tools.YfinanceTool import YfinanceToolComponent
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.graph.constants import Finish
|
||||
|
||||
|
|
@ -139,3 +141,16 @@ def test_graph_functional_start_end():
|
|||
assert len(results) == len(ids) + 1
|
||||
assert all(result.vertex.id in ids for result in results if hasattr(result, "vertex"))
|
||||
assert results[-1] == Finish()
|
||||
|
||||
|
||||
def test_graph_set_with_invalid_component():
|
||||
chat_input = ChatInput(_id="chat_input")
|
||||
chat_output = ChatOutput(input_value="test", _id="chat_output")
|
||||
with pytest.raises(ValueError, match="There are multiple outputs"):
|
||||
chat_output.set(sender_name=chat_input)
|
||||
|
||||
|
||||
def test_graph_set_with_valid_component():
|
||||
tool = YfinanceToolComponent()
|
||||
tool_calling_agent = ToolCallingAgentComponent()
|
||||
tool_calling_agent.set(tools=[tool])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue