refactor: Update langflow custom imports and base classes

This commit updates the langflow custom imports and base classes in the code. It adds the "Component" import and base class to the langflow custom __init__.py file. It also updates the langflow template __init__.py file to include the "Input", "Output", "FrontendNode", and "Template" imports and base classes. Additionally, it modifies the langflow base ChatComponent class to inherit from the Component class. These changes improve the organization and functionality of the langflow custom and template modules.

Note: The commit message has been generated based on the provided code changes and recent commits.
This commit is contained in:
ogabrielluiz 2024-06-02 20:35:59 -03:00
commit 7fb5644a87
18 changed files with 6760 additions and 6719 deletions

View file

@ -20,6 +20,7 @@ from langflow.api.v1.schemas import (
UploadFileResponse,
)
from langflow.custom import CustomComponent
from langflow.custom.custom_component.component import Component
from langflow.custom.utils import build_custom_component_template
from langflow.graph.graph.base import Graph
from langflow.graph.schema import RunOutputs
@ -475,7 +476,7 @@ async def custom_component(
raw_code: CustomComponentRequest,
user: User = Depends(get_current_active_user),
):
component = CustomComponent(code=raw_code.code)
component = Component(code=raw_code.code)
built_frontend_node, _ = build_custom_component_template(component, user_id=user.id)

View file

@ -1,13 +1,13 @@
from typing import Optional, Union
from langflow.custom import CustomComponent
from langflow.custom import Component
from langflow.field_typing import Text
from langflow.helpers.record import records_to_text
from langflow.memory import store_message
from langflow.schema import Record
class ChatComponent(CustomComponent):
class ChatComponent(Component):
display_name = "Chat Component"
description = "Use as base for chat components."

View file

@ -1,8 +1,7 @@
from typing import Optional, Union
from langflow.base.io.chat import ChatComponent
from langflow.field_typing import Text
from langflow.schema import Record
from langflow.template import Input, Output
class ChatInput(ChatComponent):
@ -10,28 +9,32 @@ class ChatInput(ChatComponent):
description = "Get chat inputs from the Playground."
icon = "ChatInput"
def build_config(self):
build_config = super().build_config()
build_config["input_value"] = {
"input_types": [],
"display_name": "Message",
"multiline": True,
}
inputs = [
Input(name="input_value", type=str, display_name="Message", multiline=True, input_types=[]),
Input(name="sender", type=str, display_name="Sender Type", options=["Machine", "User"]),
Input(name="sender_name", type=str, display_name="Sender Name"),
Input(name="session_id", type=str, display_name="Session ID"),
]
outputs = [
Output(name="Message", method="text_response"),
Output(name="Record", method="record_response"),
]
return build_config
def text_response(self) -> Text:
result = self.message
if self.session_id and isinstance(result, (Record, str)):
self.store_message(result, self.session_id, self.sender, self.sender_name)
return result
def build(
self,
sender: Optional[str] = "User",
sender_name: Optional[str] = "User",
input_value: Optional[str] = None,
session_id: Optional[str] = None,
return_record: Optional[bool] = False,
) -> Union[Text, Record]:
return super().build_no_record(
sender=sender,
sender_name=sender_name,
input_value=input_value,
session_id=session_id,
return_record=return_record,
def record_response(self) -> Record:
record = Record(
data={
"message": self.message,
"sender": self.sender,
"sender_name": self.sender_name,
"session_id": self.session_id,
}
)
if self.session_id and isinstance(record, (Record, str)):
self.store_message(record, self.session_id, self.sender, self.sender_name)
return record

View file

@ -1,3 +1,4 @@
from langflow.custom.custom_component import CustomComponent
from langflow.custom.custom_component.component import Component
__all__ = ["CustomComponent"]
__all__ = ["CustomComponent", "Component"]

View file

@ -0,0 +1,94 @@
import operator
import warnings
from typing import Any, ClassVar, Optional
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langflow.custom.attributes import ATTR_FUNC_MAPPING
from langflow.custom.code_parser import CodeParser
from langflow.custom.eval import eval_custom_component_code
from langflow.utils import validate
class ComponentCodeNullError(HTTPException):
pass
class ComponentFunctionEntrypointNameNullError(HTTPException):
pass
class BaseComponent:
ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided."
code: Optional[str] = None
_function_entrypoint_name: str = "build"
field_config: dict = {}
_user_id: Optional[str]
def __init__(self, **data):
self.cache = TTLCache(maxsize=1024, ttl=60)
for key, value in data.items():
if key == "user_id":
setattr(self, "_user_id", value)
else:
setattr(self, key, value)
def __setattr__(self, key, value):
if key == "_user_id" and hasattr(self, "_user_id"):
warnings.warn("user_id is immutable and cannot be changed.")
super().__setattr__(key, value)
@cachedmethod(cache=operator.attrgetter("cache"))
def get_code_tree(self, code: str):
parser = CodeParser(code)
return parser.parse_code()
def get_function(self):
if not self.code:
raise ComponentCodeNullError(
status_code=400,
detail={"error": self.ERROR_CODE_NULL, "traceback": ""},
)
if not self._function_entrypoint_name:
raise ComponentFunctionEntrypointNameNullError(
status_code=400,
detail={
"error": self.ERROR_FUNCTION_ENTRYPOINT_NAME_NULL,
"traceback": "",
},
)
return validate.create_function(self.code, self._function_entrypoint_name)
def build_template_config(self) -> dict:
"""
Builds the template configuration for the custom component.
Returns:
A dictionary representing the template configuration.
"""
if not self.code:
return {}
cc_class = eval_custom_component_code(self.code)
component_instance = cc_class()
template_config = {}
for attribute, func in ATTR_FUNC_MAPPING.items():
if hasattr(component_instance, attribute):
value = getattr(component_instance, attribute)
if value is not None:
template_config[attribute] = func(value=value)
for key in template_config.copy():
if key not in ATTR_FUNC_MAPPING.keys():
template_config.pop(key, None)
return template_config
def build(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError

View file

@ -1,94 +1,17 @@
import operator
import warnings
from typing import Any, ClassVar, Optional
from typing import ClassVar, List, Optional
from cachetools import TTLCache, cachedmethod
from fastapi import HTTPException
from langflow.template.field.base import Input, Output
from langflow.custom.attributes import ATTR_FUNC_MAPPING
from langflow.custom.code_parser import CodeParser
from langflow.custom.eval import eval_custom_component_code
from langflow.utils import validate
from .custom_component import CustomComponent
class ComponentCodeNullError(HTTPException):
pass
class Component(CustomComponent):
inputs: Optional[List[Input]] = None
outputs: Optional[List[Output]] = None
code_class_base_inheritance: ClassVar[str] = "Component"
class ComponentFunctionEntrypointNameNullError(HTTPException):
pass
class Component:
ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided."
code: Optional[str] = None
_function_entrypoint_name: str = "build"
field_config: dict = {}
_user_id: Optional[str]
def __init__(self, **data):
self.cache = TTLCache(maxsize=1024, ttl=60)
for key, value in data.items():
if key == "user_id":
setattr(self, "_user_id", value)
else:
setattr(self, key, value)
def __setattr__(self, key, value):
if key == "_user_id" and hasattr(self, "_user_id"):
warnings.warn("user_id is immutable and cannot be changed.")
super().__setattr__(key, value)
@cachedmethod(cache=operator.attrgetter("cache"))
def get_code_tree(self, code: str):
parser = CodeParser(code)
return parser.parse_code()
def get_function(self):
if not self.code:
raise ComponentCodeNullError(
status_code=400,
detail={"error": self.ERROR_CODE_NULL, "traceback": ""},
)
if not self._function_entrypoint_name:
raise ComponentFunctionEntrypointNameNullError(
status_code=400,
detail={
"error": self.ERROR_FUNCTION_ENTRYPOINT_NAME_NULL,
"traceback": "",
},
)
return validate.create_function(self.code, self._function_entrypoint_name)
def build_template_config(self) -> dict:
"""
Builds the template configuration for the custom component.
Returns:
A dictionary representing the template configuration.
"""
if not self.code:
return {}
cc_class = eval_custom_component_code(self.code)
component_instance = cc_class()
template_config = {}
for attribute, func in ATTR_FUNC_MAPPING.items():
if hasattr(component_instance, attribute):
value = getattr(component_instance, attribute)
if value is not None:
template_config[attribute] = func(value=value)
for key in template_config.copy():
if key not in ATTR_FUNC_MAPPING.keys():
template_config.pop(key, None)
return template_config
def build(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError
def set_attributes(self, params: dict):
for key, value in params.items():
if key in self.__dict__:
raise ValueError(f"Key {key} already exists in {self.__class__.__name__}")
setattr(self, key, value)

View file

@ -12,13 +12,12 @@ from langflow.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
from langflow.custom.custom_component.component import Component
from langflow.custom.custom_component.base_component import BaseComponent
from langflow.helpers.flow import list_flows, load_flow, run_flow
from langflow.schema import Record
from langflow.schema.dotdict import dotdict
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
from langflow.services.storage.service import StorageService
from langflow.template.field.base import Input, Output
from langflow.utils import validate
if TYPE_CHECKING:
@ -27,7 +26,7 @@ if TYPE_CHECKING:
from langflow.services.storage.service import StorageService
class CustomComponent(Component):
class CustomComponent(BaseComponent):
"""
Represents a custom component in Langflow.
@ -80,9 +79,6 @@ class CustomComponent(Component):
"""The status of the component. This is displayed on the frontend. Defaults to None."""
_flows_records: Optional[List[Record]] = None
inputs: Optional[List[Input]] = None
outputs: Optional[List[Output]] = None
def build_inputs(self, user_id: Optional[Union[str, UUID]] = None):
"""
Builds the inputs for the custom component.
@ -100,12 +96,6 @@ class CustomComponent(Component):
build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs}
return build_config
def set_attributes(self, params: dict):
for key, value in params.items():
if key in self.__dict__:
raise ValueError(f"Key {key} already exists in {self.__class__.__name__}")
setattr(self, key, value)
def update_state(self, name: str, value: Any):
if not self.vertex:
raise ValueError("Vertex is not set")
@ -493,4 +483,3 @@ class CustomComponent(Component):
Any: The result of the build process.
"""
raise NotImplementedError
raise NotImplementedError

View file

@ -12,6 +12,7 @@ from pydantic import BaseModel
from langflow.custom import CustomComponent
from langflow.custom.code_parser.utils import extract_inner_type
from langflow.custom.custom_component.component import Component
from langflow.custom.directory_reader.utils import (
abuild_custom_component_list_from_path,
build_custom_component_list_from_path,
@ -24,7 +25,7 @@ from langflow.field_typing.range_spec import RangeSpec
from langflow.helpers.custom import format_type
from langflow.schema import dotdict
from langflow.template.field.base import Input
from langflow.template.frontend_node.custom_components import CustomComponentFrontendNode
from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode
from langflow.utils import validate
from langflow.utils.util import get_base_classes
@ -325,7 +326,7 @@ def build_custom_component_template_from_inputs(
custom_component: CustomComponent, user_id: Optional[Union[str, UUID]] = None
):
# The List of Inputs fills the role of the build_config and the entrypoint_args
frontend_node = CustomComponentFrontendNode.from_inputs(**custom_component.template_config)
frontend_node = ComponentFrontendNode.from_inputs(**custom_component.template_config)
field_config = run_build_inputs(
custom_component,
user_id=user_id,
@ -336,6 +337,8 @@ def build_custom_component_template_from_inputs(
return_types = custom_component.get_method_return_type(output.method)
return_types = [format_type(return_type) for return_type in return_types]
output.add_types(return_types)
# ! This should be removed when we have a better way to handle this
frontend_node.get_base_classes_from_outputs()
return frontend_node.to_dict(add_name=False), custom_component
@ -384,7 +387,7 @@ def create_component_template(component):
component_code = component["code"]
component_output_types = component["output_types"]
component_extractor = CustomComponent(code=component_code)
component_extractor = Component(code=component_code)
component_template, _ = build_custom_component_template(component_extractor)
if not component_template["output_types"] and component_output_types:

View file

@ -11,10 +11,11 @@ if TYPE_CHECKING:
class SourceHandle(BaseModel):
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
baseClasses: Optional[List[str]] = Field(None, description="List of base classes for the source handle.")
dataType: str = Field(..., description="Data type for the source handle.")
id: str = Field(..., description="Unique identifier for the source handle.")
conditionalPath: Optional[bool] = Field(None, description="Conditional path for the source handle.")
name: str = Field(..., description="Name of the source handle.")
output_types: List[str] = Field(..., description="List of output types for the source handle.")
class TargetHandle(BaseModel):
@ -49,11 +50,11 @@ class Edge:
def validate_handles(self, source, target) -> None:
if self.target_handle.inputTypes is None:
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
self.valid_handles = self.target_handle.type in self.source_handle.output_types
else:
self.valid_handles = (
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
or self.target_handle.type in self.source_handle.baseClasses
any(output_type in self.target_handle.inputTypes for output_type in self.source_handle.output_types)
or self.target_handle.type in self.source_handle.output_types
)
if not self.valid_handles:
logger.debug(self.source_handle)
@ -70,16 +71,29 @@ class Edge:
def validate_edge(self, source, target) -> None:
# Validate that the outputs of the source node are valid inputs
# for the target node
self.source_types = source.output
# .outputs is a list of Output objects as dictionaries
# meaning: check for "types" key in each dictionary
self.source_types = [output for output in source.outputs if output["name"] == self.source_handle.name]
self.target_reqs = target.required_inputs + target.optional_inputs
# Both lists contain strings and sometimes a string contains the value we are
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
# so we need to check if any of the strings in source_types is in target_reqs
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
self.valid = any(
any(output_type in target_req for output_type in output["types"])
for output in self.source_types
for target_req in self.target_reqs
)
# Get what type of input the target node is expecting
# Update the matched type to be the first found match
self.matched_type = next(
(output for output in self.source_types if output in self.target_reqs),
(
output_type
for output in self.source_types
for output_type in output["types"]
for target_req in self.target_reqs
if output_type in target_req
),
None,
)
no_matched_type = self.matched_type is None

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,10 @@
from langflow.template.field.base import Input, Output
from langflow.template.frontend_node.base import FrontendNode
from langflow.template.template.base import Template
__all__ = [
"Input",
"Output",
"FrontendNode",
"Template",
]

View file

@ -101,6 +101,9 @@ class FrontendNode(BaseModel):
def add_extra_base_classes(self) -> None:
pass
def get_base_classes_from_outputs(self) -> list[str]:
self.base_classes = [output_type for output in self.outputs for output_type in output.types]
def add_base_class(self, base_class: Union[str, List[str]]) -> None:
"""Adds a base class to the frontend node."""
if isinstance(base_class, str):

View file

@ -62,6 +62,8 @@ export default function ParameterComponent({
index,
outputName,
}: ParameterComponentType): JSX.Element {
console.log("title", title);
console.log("data", data);
const infoHtml = useRef<HTMLDivElement & ReactNode>(null);
const nodes = useFlowStore((state) => state.nodes);
const edges = useFlowStore((state) => state.edges);
@ -72,7 +74,6 @@ export default function ParameterComponent({
const updateNodeInternals = useUpdateNodeInternals();
const [errorDuplicateKey, setErrorDuplicateKey] = useState(false);
const setFilterEdge = useFlowStore((state) => state.setFilterEdge);
const { handleOnNewValue: handleOnNewValueHook } = useHandleOnNewValue(
data,
name,

View file

@ -55,14 +55,14 @@ export default function GenericNode({
const [nodeName, setNodeName] = useState(data.node!.display_name);
const [inputDescription, setInputDescription] = useState(false);
const [nodeDescription, setNodeDescription] = useState(
data.node?.description!,
data.node?.description!
);
const [isOutdated, setIsOutdated] = useState(false);
const buildStatus = useFlowStore(
(state) => state.flowBuildStatus[data.id]?.status,
(state) => state.flowBuildStatus[data.id]?.status
);
const lastRunTime = useFlowStore(
(state) => state.flowBuildStatus[data.id]?.timestamp,
(state) => state.flowBuildStatus[data.id]?.timestamp
);
const [validationStatus, setValidationStatus] =
useState<validationStatusType | null>(null);
@ -115,7 +115,7 @@ export default function GenericNode({
updateNodeInternals(data.id);
},
[data.id, data.node, setNode, setIsOutdated],
[data.id, data.node, setNode, setIsOutdated]
);
if (!data.node!.template) {
@ -255,7 +255,7 @@ export default function GenericNode({
const isDark = useDarkStore((state) => state.dark);
const renderIconStatus = (
buildStatus: BuildStatus | undefined,
validationStatus: validationStatusType | null,
validationStatus: validationStatusType | null
) => {
if (buildStatus === BuildStatus.BUILDING) {
return <Loading className="text-medium-indigo" />;
@ -296,7 +296,7 @@ export default function GenericNode({
};
const getSpecificClassFromBuildStatus = (
buildStatus: BuildStatus | undefined,
validationStatus: validationStatusType | null,
validationStatus: validationStatusType | null
) => {
let isInvalid = validationStatus && !validationStatus.valid;
@ -320,11 +320,11 @@ export default function GenericNode({
selected: boolean,
showNode: boolean,
buildStatus: BuildStatus | undefined,
validationStatus: validationStatusType | null,
validationStatus: validationStatusType | null
) => {
const specificClassFromBuildStatus = getSpecificClassFromBuildStatus(
buildStatus,
validationStatus,
validationStatus
);
const baseBorderClass = getBaseBorderClass(selected);
@ -333,7 +333,7 @@ export default function GenericNode({
baseBorderClass,
nodeSizeClass,
"generic-node-div",
specificClassFromBuildStatus,
specificClassFromBuildStatus
);
return names;
};
@ -393,7 +393,7 @@ export default function GenericNode({
selected,
showNode,
buildStatus,
validationStatus,
validationStatus
)}
>
{data.node?.beta && showNode && (
@ -538,7 +538,7 @@ export default function GenericNode({
}
title={getFieldTitle(
data.node?.template!,
templateField,
templateField
)}
info={data.node?.template[templateField].info}
name={templateField}
@ -566,7 +566,7 @@ export default function GenericNode({
proxy={data.node?.template[templateField].proxy}
showNode={showNode}
/>
),
)
)}
{/* <ParameterComponent
index={0}
@ -725,7 +725,7 @@ export default function GenericNode({
!data.node?.description) &&
nameEditable
? "font-light italic"
: "",
: ""
)}
onDoubleClick={(e) => {
setInputDescription(true);
@ -787,13 +787,13 @@ export default function GenericNode({
}
title={getFieldTitle(
data.node?.template!,
templateField,
templateField
)}
info={data.node?.template[templateField].info}
name={templateField}
tooltipTitle={
data.node?.template[templateField].input_types?.join(
"\n",
"\n"
) ?? data.node?.template[templateField].type
}
required={data.node!.template[templateField].required}
@ -820,7 +820,7 @@ export default function GenericNode({
<div
className={classNames(
Object.keys(data.node!.template).length < 1 ? "hidden" : "",
"flex-max-width justify-center",
"flex-max-width justify-center"
)}
>
{" "}
@ -842,7 +842,7 @@ export default function GenericNode({
nodeColors[types[data.type]] ??
nodeColors.unknown
}
title={output.selected ?? output.types[0]}
title={output.name}
tooltipTitle={output.selected ?? output.types[0]}
id={{
output_types: [output.selected ?? output.types[0]],

View file

@ -4,10 +4,9 @@ from uuid import uuid4
import pytest
from langchain_core.documents import Document
from langflow.custom import CustomComponent
from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
from langflow.custom.custom_component.component import Component, ComponentCodeNullError
from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError
from langflow.custom.utils import build_custom_component_template
from langflow.services.database.models.flow import Flow, FlowCreate
@ -77,7 +76,7 @@ def test_component_init():
"""
Test the initialization of the Component class.
"""
component = Component(code=code_default, function_entrypoint_name="build")
component = BaseComponent(code=code_default, function_entrypoint_name="build")
assert component.code == code_default
assert component.function_entrypoint_name == "build"
@ -86,7 +85,7 @@ def test_component_get_code_tree():
"""
Test the get_code_tree method of the Component class.
"""
component = Component(code=code_default, function_entrypoint_name="build")
component = BaseComponent(code=code_default, function_entrypoint_name="build")
tree = component.get_code_tree(component.code)
assert "imports" in tree
@ -96,7 +95,7 @@ def test_component_code_null_error():
Test the get_function method raises the
ComponentCodeNullError when the code is empty.
"""
component = Component(code="", function_entrypoint_name="")
component = BaseComponent(code="", function_entrypoint_name="")
with pytest.raises(ComponentCodeNullError):
component.get_function()
@ -200,7 +199,7 @@ def test_component_get_function_valid():
Test the get_function method of the Component
class with valid code and function_entrypoint_name.
"""
component = Component(code="def build(): pass", function_entrypoint_name="build")
component = BaseComponent(code="def build(): pass", function_entrypoint_name="build")
my_function = component.get_function()
assert callable(my_function)
@ -357,7 +356,7 @@ def test_component_get_code_tree_syntax_error():
Test the get_code_tree method of the Component class
raises the CodeSyntaxError when given incorrect syntax.
"""
component = Component(code="import os as", function_entrypoint_name="build")
component = BaseComponent(code="import os as", function_entrypoint_name="build")
with pytest.raises(CodeSyntaxError):
component.get_code_tree(component.code)