refactor: Update langflow custom imports and base classes
This commit updates the langflow custom imports and base classes in the code. It adds the "Component" import and base class to the langflow custom __init__.py file. It also updates the langflow template __init__.py file to include the "Input", "Output", "FrontendNode", and "Template" imports and base classes. Additionally, it modifies the langflow base ChatComponent class to inherit from the Component class. These changes improve the organization and functionality of the langflow custom and template modules. Note: The commit message has been generated based on the provided code changes and recent commits.
This commit is contained in:
parent
1fa3a6a379
commit
7fb5644a87
18 changed files with 6760 additions and 6719 deletions
|
|
@ -20,6 +20,7 @@ from langflow.api.v1.schemas import (
|
|||
UploadFileResponse,
|
||||
)
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
from langflow.graph.graph.base import Graph
|
||||
from langflow.graph.schema import RunOutputs
|
||||
|
|
@ -475,7 +476,7 @@ async def custom_component(
|
|||
raw_code: CustomComponentRequest,
|
||||
user: User = Depends(get_current_active_user),
|
||||
):
|
||||
component = CustomComponent(code=raw_code.code)
|
||||
component = Component(code=raw_code.code)
|
||||
|
||||
built_frontend_node, _ = build_custom_component_template(component, user_id=user.id)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,13 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import Text
|
||||
from langflow.helpers.record import records_to_text
|
||||
from langflow.memory import store_message
|
||||
from langflow.schema import Record
|
||||
|
||||
|
||||
class ChatComponent(CustomComponent):
|
||||
class ChatComponent(Component):
|
||||
display_name = "Chat Component"
|
||||
description = "Use as base for chat components."
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,7 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
from langflow.base.io.chat import ChatComponent
|
||||
from langflow.field_typing import Text
|
||||
from langflow.schema import Record
|
||||
from langflow.template import Input, Output
|
||||
|
||||
|
||||
class ChatInput(ChatComponent):
|
||||
|
|
@ -10,28 +9,32 @@ class ChatInput(ChatComponent):
|
|||
description = "Get chat inputs from the Playground."
|
||||
icon = "ChatInput"
|
||||
|
||||
def build_config(self):
|
||||
build_config = super().build_config()
|
||||
build_config["input_value"] = {
|
||||
"input_types": [],
|
||||
"display_name": "Message",
|
||||
"multiline": True,
|
||||
}
|
||||
inputs = [
|
||||
Input(name="input_value", type=str, display_name="Message", multiline=True, input_types=[]),
|
||||
Input(name="sender", type=str, display_name="Sender Type", options=["Machine", "User"]),
|
||||
Input(name="sender_name", type=str, display_name="Sender Name"),
|
||||
Input(name="session_id", type=str, display_name="Session ID"),
|
||||
]
|
||||
outputs = [
|
||||
Output(name="Message", method="text_response"),
|
||||
Output(name="Record", method="record_response"),
|
||||
]
|
||||
|
||||
return build_config
|
||||
def text_response(self) -> Text:
|
||||
result = self.message
|
||||
if self.session_id and isinstance(result, (Record, str)):
|
||||
self.store_message(result, self.session_id, self.sender, self.sender_name)
|
||||
return result
|
||||
|
||||
def build(
|
||||
self,
|
||||
sender: Optional[str] = "User",
|
||||
sender_name: Optional[str] = "User",
|
||||
input_value: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
return_record: Optional[bool] = False,
|
||||
) -> Union[Text, Record]:
|
||||
return super().build_no_record(
|
||||
sender=sender,
|
||||
sender_name=sender_name,
|
||||
input_value=input_value,
|
||||
session_id=session_id,
|
||||
return_record=return_record,
|
||||
def record_response(self) -> Record:
|
||||
record = Record(
|
||||
data={
|
||||
"message": self.message,
|
||||
"sender": self.sender,
|
||||
"sender_name": self.sender_name,
|
||||
"session_id": self.session_id,
|
||||
}
|
||||
)
|
||||
if self.session_id and isinstance(record, (Record, str)):
|
||||
self.store_message(record, self.session_id, self.sender, self.sender_name)
|
||||
return record
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from langflow.custom.custom_component import CustomComponent
|
||||
from langflow.custom.custom_component.component import Component
|
||||
|
||||
__all__ = ["CustomComponent"]
|
||||
__all__ = ["CustomComponent", "Component"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,94 @@
|
|||
import operator
|
||||
import warnings
|
||||
from typing import Any, ClassVar, Optional
|
||||
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from fastapi import HTTPException
|
||||
|
||||
from langflow.custom.attributes import ATTR_FUNC_MAPPING
|
||||
from langflow.custom.code_parser import CodeParser
|
||||
from langflow.custom.eval import eval_custom_component_code
|
||||
from langflow.utils import validate
|
||||
|
||||
|
||||
class ComponentCodeNullError(HTTPException):
|
||||
pass
|
||||
|
||||
|
||||
class ComponentFunctionEntrypointNameNullError(HTTPException):
|
||||
pass
|
||||
|
||||
|
||||
class BaseComponent:
|
||||
ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
|
||||
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided."
|
||||
|
||||
code: Optional[str] = None
|
||||
_function_entrypoint_name: str = "build"
|
||||
field_config: dict = {}
|
||||
_user_id: Optional[str]
|
||||
|
||||
def __init__(self, **data):
|
||||
self.cache = TTLCache(maxsize=1024, ttl=60)
|
||||
for key, value in data.items():
|
||||
if key == "user_id":
|
||||
setattr(self, "_user_id", value)
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key == "_user_id" and hasattr(self, "_user_id"):
|
||||
warnings.warn("user_id is immutable and cannot be changed.")
|
||||
super().__setattr__(key, value)
|
||||
|
||||
@cachedmethod(cache=operator.attrgetter("cache"))
|
||||
def get_code_tree(self, code: str):
|
||||
parser = CodeParser(code)
|
||||
return parser.parse_code()
|
||||
|
||||
def get_function(self):
|
||||
if not self.code:
|
||||
raise ComponentCodeNullError(
|
||||
status_code=400,
|
||||
detail={"error": self.ERROR_CODE_NULL, "traceback": ""},
|
||||
)
|
||||
|
||||
if not self._function_entrypoint_name:
|
||||
raise ComponentFunctionEntrypointNameNullError(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": self.ERROR_FUNCTION_ENTRYPOINT_NAME_NULL,
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
return validate.create_function(self.code, self._function_entrypoint_name)
|
||||
|
||||
def build_template_config(self) -> dict:
|
||||
"""
|
||||
Builds the template configuration for the custom component.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the template configuration.
|
||||
"""
|
||||
if not self.code:
|
||||
return {}
|
||||
|
||||
cc_class = eval_custom_component_code(self.code)
|
||||
component_instance = cc_class()
|
||||
template_config = {}
|
||||
|
||||
for attribute, func in ATTR_FUNC_MAPPING.items():
|
||||
if hasattr(component_instance, attribute):
|
||||
value = getattr(component_instance, attribute)
|
||||
if value is not None:
|
||||
template_config[attribute] = func(value=value)
|
||||
|
||||
for key in template_config.copy():
|
||||
if key not in ATTR_FUNC_MAPPING.keys():
|
||||
template_config.pop(key, None)
|
||||
|
||||
return template_config
|
||||
|
||||
def build(self, *args: Any, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
|
@ -1,94 +1,17 @@
|
|||
import operator
|
||||
import warnings
|
||||
from typing import Any, ClassVar, Optional
|
||||
from typing import ClassVar, List, Optional
|
||||
|
||||
from cachetools import TTLCache, cachedmethod
|
||||
from fastapi import HTTPException
|
||||
from langflow.template.field.base import Input, Output
|
||||
|
||||
from langflow.custom.attributes import ATTR_FUNC_MAPPING
|
||||
from langflow.custom.code_parser import CodeParser
|
||||
from langflow.custom.eval import eval_custom_component_code
|
||||
from langflow.utils import validate
|
||||
from .custom_component import CustomComponent
|
||||
|
||||
|
||||
class ComponentCodeNullError(HTTPException):
|
||||
pass
|
||||
class Component(CustomComponent):
|
||||
inputs: Optional[List[Input]] = None
|
||||
outputs: Optional[List[Output]] = None
|
||||
code_class_base_inheritance: ClassVar[str] = "Component"
|
||||
|
||||
|
||||
class ComponentFunctionEntrypointNameNullError(HTTPException):
|
||||
pass
|
||||
|
||||
|
||||
class Component:
|
||||
ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
|
||||
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided."
|
||||
|
||||
code: Optional[str] = None
|
||||
_function_entrypoint_name: str = "build"
|
||||
field_config: dict = {}
|
||||
_user_id: Optional[str]
|
||||
|
||||
def __init__(self, **data):
|
||||
self.cache = TTLCache(maxsize=1024, ttl=60)
|
||||
for key, value in data.items():
|
||||
if key == "user_id":
|
||||
setattr(self, "_user_id", value)
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key == "_user_id" and hasattr(self, "_user_id"):
|
||||
warnings.warn("user_id is immutable and cannot be changed.")
|
||||
super().__setattr__(key, value)
|
||||
|
||||
@cachedmethod(cache=operator.attrgetter("cache"))
|
||||
def get_code_tree(self, code: str):
|
||||
parser = CodeParser(code)
|
||||
return parser.parse_code()
|
||||
|
||||
def get_function(self):
|
||||
if not self.code:
|
||||
raise ComponentCodeNullError(
|
||||
status_code=400,
|
||||
detail={"error": self.ERROR_CODE_NULL, "traceback": ""},
|
||||
)
|
||||
|
||||
if not self._function_entrypoint_name:
|
||||
raise ComponentFunctionEntrypointNameNullError(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": self.ERROR_FUNCTION_ENTRYPOINT_NAME_NULL,
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
return validate.create_function(self.code, self._function_entrypoint_name)
|
||||
|
||||
def build_template_config(self) -> dict:
|
||||
"""
|
||||
Builds the template configuration for the custom component.
|
||||
|
||||
Returns:
|
||||
A dictionary representing the template configuration.
|
||||
"""
|
||||
if not self.code:
|
||||
return {}
|
||||
|
||||
cc_class = eval_custom_component_code(self.code)
|
||||
component_instance = cc_class()
|
||||
template_config = {}
|
||||
|
||||
for attribute, func in ATTR_FUNC_MAPPING.items():
|
||||
if hasattr(component_instance, attribute):
|
||||
value = getattr(component_instance, attribute)
|
||||
if value is not None:
|
||||
template_config[attribute] = func(value=value)
|
||||
|
||||
for key in template_config.copy():
|
||||
if key not in ATTR_FUNC_MAPPING.keys():
|
||||
template_config.pop(key, None)
|
||||
|
||||
return template_config
|
||||
|
||||
def build(self, *args: Any, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
def set_attributes(self, params: dict):
|
||||
for key, value in params.items():
|
||||
if key in self.__dict__:
|
||||
raise ValueError(f"Key {key} already exists in {self.__class__.__name__}")
|
||||
setattr(self, key, value)
|
||||
|
|
|
|||
|
|
@ -12,13 +12,12 @@ from langflow.custom.code_parser.utils import (
|
|||
extract_inner_type_from_generic_alias,
|
||||
extract_union_types_from_generic_alias,
|
||||
)
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.custom.custom_component.base_component import BaseComponent
|
||||
from langflow.helpers.flow import list_flows, load_flow, run_flow
|
||||
from langflow.schema import Record
|
||||
from langflow.schema.dotdict import dotdict
|
||||
from langflow.services.deps import get_storage_service, get_variable_service, session_scope
|
||||
from langflow.services.storage.service import StorageService
|
||||
from langflow.template.field.base import Input, Output
|
||||
from langflow.utils import validate
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -27,7 +26,7 @@ if TYPE_CHECKING:
|
|||
from langflow.services.storage.service import StorageService
|
||||
|
||||
|
||||
class CustomComponent(Component):
|
||||
class CustomComponent(BaseComponent):
|
||||
"""
|
||||
Represents a custom component in Langflow.
|
||||
|
||||
|
|
@ -80,9 +79,6 @@ class CustomComponent(Component):
|
|||
"""The status of the component. This is displayed on the frontend. Defaults to None."""
|
||||
_flows_records: Optional[List[Record]] = None
|
||||
|
||||
inputs: Optional[List[Input]] = None
|
||||
outputs: Optional[List[Output]] = None
|
||||
|
||||
def build_inputs(self, user_id: Optional[Union[str, UUID]] = None):
|
||||
"""
|
||||
Builds the inputs for the custom component.
|
||||
|
|
@ -100,12 +96,6 @@ class CustomComponent(Component):
|
|||
build_config = {_input.name: _input.model_dump(by_alias=True, exclude_none=True) for _input in self.inputs}
|
||||
return build_config
|
||||
|
||||
def set_attributes(self, params: dict):
|
||||
for key, value in params.items():
|
||||
if key in self.__dict__:
|
||||
raise ValueError(f"Key {key} already exists in {self.__class__.__name__}")
|
||||
setattr(self, key, value)
|
||||
|
||||
def update_state(self, name: str, value: Any):
|
||||
if not self.vertex:
|
||||
raise ValueError("Vertex is not set")
|
||||
|
|
@ -493,4 +483,3 @@ class CustomComponent(Component):
|
|||
Any: The result of the build process.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from pydantic import BaseModel
|
|||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.custom.code_parser.utils import extract_inner_type
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.custom.directory_reader.utils import (
|
||||
abuild_custom_component_list_from_path,
|
||||
build_custom_component_list_from_path,
|
||||
|
|
@ -24,7 +25,7 @@ from langflow.field_typing.range_spec import RangeSpec
|
|||
from langflow.helpers.custom import format_type
|
||||
from langflow.schema import dotdict
|
||||
from langflow.template.field.base import Input
|
||||
from langflow.template.frontend_node.custom_components import CustomComponentFrontendNode
|
||||
from langflow.template.frontend_node.custom_components import ComponentFrontendNode, CustomComponentFrontendNode
|
||||
from langflow.utils import validate
|
||||
from langflow.utils.util import get_base_classes
|
||||
|
||||
|
|
@ -325,7 +326,7 @@ def build_custom_component_template_from_inputs(
|
|||
custom_component: CustomComponent, user_id: Optional[Union[str, UUID]] = None
|
||||
):
|
||||
# The List of Inputs fills the role of the build_config and the entrypoint_args
|
||||
frontend_node = CustomComponentFrontendNode.from_inputs(**custom_component.template_config)
|
||||
frontend_node = ComponentFrontendNode.from_inputs(**custom_component.template_config)
|
||||
field_config = run_build_inputs(
|
||||
custom_component,
|
||||
user_id=user_id,
|
||||
|
|
@ -336,6 +337,8 @@ def build_custom_component_template_from_inputs(
|
|||
return_types = custom_component.get_method_return_type(output.method)
|
||||
return_types = [format_type(return_type) for return_type in return_types]
|
||||
output.add_types(return_types)
|
||||
# ! This should be removed when we have a better way to handle this
|
||||
frontend_node.get_base_classes_from_outputs()
|
||||
|
||||
return frontend_node.to_dict(add_name=False), custom_component
|
||||
|
||||
|
|
@ -384,7 +387,7 @@ def create_component_template(component):
|
|||
component_code = component["code"]
|
||||
component_output_types = component["output_types"]
|
||||
|
||||
component_extractor = CustomComponent(code=component_code)
|
||||
component_extractor = Component(code=component_code)
|
||||
|
||||
component_template, _ = build_custom_component_template(component_extractor)
|
||||
if not component_template["output_types"] and component_output_types:
|
||||
|
|
|
|||
|
|
@ -11,10 +11,11 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class SourceHandle(BaseModel):
|
||||
baseClasses: List[str] = Field(..., description="List of base classes for the source handle.")
|
||||
baseClasses: Optional[List[str]] = Field(None, description="List of base classes for the source handle.")
|
||||
dataType: str = Field(..., description="Data type for the source handle.")
|
||||
id: str = Field(..., description="Unique identifier for the source handle.")
|
||||
conditionalPath: Optional[bool] = Field(None, description="Conditional path for the source handle.")
|
||||
name: str = Field(..., description="Name of the source handle.")
|
||||
output_types: List[str] = Field(..., description="List of output types for the source handle.")
|
||||
|
||||
|
||||
class TargetHandle(BaseModel):
|
||||
|
|
@ -49,11 +50,11 @@ class Edge:
|
|||
|
||||
def validate_handles(self, source, target) -> None:
|
||||
if self.target_handle.inputTypes is None:
|
||||
self.valid_handles = self.target_handle.type in self.source_handle.baseClasses
|
||||
self.valid_handles = self.target_handle.type in self.source_handle.output_types
|
||||
else:
|
||||
self.valid_handles = (
|
||||
any(baseClass in self.target_handle.inputTypes for baseClass in self.source_handle.baseClasses)
|
||||
or self.target_handle.type in self.source_handle.baseClasses
|
||||
any(output_type in self.target_handle.inputTypes for output_type in self.source_handle.output_types)
|
||||
or self.target_handle.type in self.source_handle.output_types
|
||||
)
|
||||
if not self.valid_handles:
|
||||
logger.debug(self.source_handle)
|
||||
|
|
@ -70,16 +71,29 @@ class Edge:
|
|||
def validate_edge(self, source, target) -> None:
|
||||
# Validate that the outputs of the source node are valid inputs
|
||||
# for the target node
|
||||
self.source_types = source.output
|
||||
# .outputs is a list of Output objects as dictionaries
|
||||
# meaning: check for "types" key in each dictionary
|
||||
self.source_types = [output for output in source.outputs if output["name"] == self.source_handle.name]
|
||||
self.target_reqs = target.required_inputs + target.optional_inputs
|
||||
# Both lists contain strings and sometimes a string contains the value we are
|
||||
# looking for e.g. comgin_out=["Chain"] and target_reqs=["LLMChain"]
|
||||
# so we need to check if any of the strings in source_types is in target_reqs
|
||||
self.valid = any(output in target_req for output in self.source_types for target_req in self.target_reqs)
|
||||
self.valid = any(
|
||||
any(output_type in target_req for output_type in output["types"])
|
||||
for output in self.source_types
|
||||
for target_req in self.target_reqs
|
||||
)
|
||||
# Get what type of input the target node is expecting
|
||||
|
||||
# Update the matched type to be the first found match
|
||||
self.matched_type = next(
|
||||
(output for output in self.source_types if output in self.target_reqs),
|
||||
(
|
||||
output_type
|
||||
for output in self.source_types
|
||||
for output_type in output["types"]
|
||||
for target_req in self.target_reqs
|
||||
if output_type in target_req
|
||||
),
|
||||
None,
|
||||
)
|
||||
no_matched_type = self.matched_type is None
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
File diff suppressed because one or more lines are too long
|
|
@ -0,0 +1,10 @@
|
|||
from langflow.template.field.base import Input, Output
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.template.base import Template
|
||||
|
||||
__all__ = [
|
||||
"Input",
|
||||
"Output",
|
||||
"FrontendNode",
|
||||
"Template",
|
||||
]
|
||||
|
|
@ -101,6 +101,9 @@ class FrontendNode(BaseModel):
|
|||
def add_extra_base_classes(self) -> None:
|
||||
pass
|
||||
|
||||
def get_base_classes_from_outputs(self) -> list[str]:
|
||||
self.base_classes = [output_type for output in self.outputs for output_type in output.types]
|
||||
|
||||
def add_base_class(self, base_class: Union[str, List[str]]) -> None:
|
||||
"""Adds a base class to the frontend node."""
|
||||
if isinstance(base_class, str):
|
||||
|
|
|
|||
|
|
@ -62,6 +62,8 @@ export default function ParameterComponent({
|
|||
index,
|
||||
outputName,
|
||||
}: ParameterComponentType): JSX.Element {
|
||||
console.log("title", title);
|
||||
console.log("data", data);
|
||||
const infoHtml = useRef<HTMLDivElement & ReactNode>(null);
|
||||
const nodes = useFlowStore((state) => state.nodes);
|
||||
const edges = useFlowStore((state) => state.edges);
|
||||
|
|
@ -72,7 +74,6 @@ export default function ParameterComponent({
|
|||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
const [errorDuplicateKey, setErrorDuplicateKey] = useState(false);
|
||||
const setFilterEdge = useFlowStore((state) => state.setFilterEdge);
|
||||
|
||||
const { handleOnNewValue: handleOnNewValueHook } = useHandleOnNewValue(
|
||||
data,
|
||||
name,
|
||||
|
|
|
|||
|
|
@ -55,14 +55,14 @@ export default function GenericNode({
|
|||
const [nodeName, setNodeName] = useState(data.node!.display_name);
|
||||
const [inputDescription, setInputDescription] = useState(false);
|
||||
const [nodeDescription, setNodeDescription] = useState(
|
||||
data.node?.description!,
|
||||
data.node?.description!
|
||||
);
|
||||
const [isOutdated, setIsOutdated] = useState(false);
|
||||
const buildStatus = useFlowStore(
|
||||
(state) => state.flowBuildStatus[data.id]?.status,
|
||||
(state) => state.flowBuildStatus[data.id]?.status
|
||||
);
|
||||
const lastRunTime = useFlowStore(
|
||||
(state) => state.flowBuildStatus[data.id]?.timestamp,
|
||||
(state) => state.flowBuildStatus[data.id]?.timestamp
|
||||
);
|
||||
const [validationStatus, setValidationStatus] =
|
||||
useState<validationStatusType | null>(null);
|
||||
|
|
@ -115,7 +115,7 @@ export default function GenericNode({
|
|||
|
||||
updateNodeInternals(data.id);
|
||||
},
|
||||
[data.id, data.node, setNode, setIsOutdated],
|
||||
[data.id, data.node, setNode, setIsOutdated]
|
||||
);
|
||||
|
||||
if (!data.node!.template) {
|
||||
|
|
@ -255,7 +255,7 @@ export default function GenericNode({
|
|||
const isDark = useDarkStore((state) => state.dark);
|
||||
const renderIconStatus = (
|
||||
buildStatus: BuildStatus | undefined,
|
||||
validationStatus: validationStatusType | null,
|
||||
validationStatus: validationStatusType | null
|
||||
) => {
|
||||
if (buildStatus === BuildStatus.BUILDING) {
|
||||
return <Loading className="text-medium-indigo" />;
|
||||
|
|
@ -296,7 +296,7 @@ export default function GenericNode({
|
|||
};
|
||||
const getSpecificClassFromBuildStatus = (
|
||||
buildStatus: BuildStatus | undefined,
|
||||
validationStatus: validationStatusType | null,
|
||||
validationStatus: validationStatusType | null
|
||||
) => {
|
||||
let isInvalid = validationStatus && !validationStatus.valid;
|
||||
|
||||
|
|
@ -320,11 +320,11 @@ export default function GenericNode({
|
|||
selected: boolean,
|
||||
showNode: boolean,
|
||||
buildStatus: BuildStatus | undefined,
|
||||
validationStatus: validationStatusType | null,
|
||||
validationStatus: validationStatusType | null
|
||||
) => {
|
||||
const specificClassFromBuildStatus = getSpecificClassFromBuildStatus(
|
||||
buildStatus,
|
||||
validationStatus,
|
||||
validationStatus
|
||||
);
|
||||
|
||||
const baseBorderClass = getBaseBorderClass(selected);
|
||||
|
|
@ -333,7 +333,7 @@ export default function GenericNode({
|
|||
baseBorderClass,
|
||||
nodeSizeClass,
|
||||
"generic-node-div",
|
||||
specificClassFromBuildStatus,
|
||||
specificClassFromBuildStatus
|
||||
);
|
||||
return names;
|
||||
};
|
||||
|
|
@ -393,7 +393,7 @@ export default function GenericNode({
|
|||
selected,
|
||||
showNode,
|
||||
buildStatus,
|
||||
validationStatus,
|
||||
validationStatus
|
||||
)}
|
||||
>
|
||||
{data.node?.beta && showNode && (
|
||||
|
|
@ -538,7 +538,7 @@ export default function GenericNode({
|
|||
}
|
||||
title={getFieldTitle(
|
||||
data.node?.template!,
|
||||
templateField,
|
||||
templateField
|
||||
)}
|
||||
info={data.node?.template[templateField].info}
|
||||
name={templateField}
|
||||
|
|
@ -566,7 +566,7 @@ export default function GenericNode({
|
|||
proxy={data.node?.template[templateField].proxy}
|
||||
showNode={showNode}
|
||||
/>
|
||||
),
|
||||
)
|
||||
)}
|
||||
{/* <ParameterComponent
|
||||
index={0}
|
||||
|
|
@ -725,7 +725,7 @@ export default function GenericNode({
|
|||
!data.node?.description) &&
|
||||
nameEditable
|
||||
? "font-light italic"
|
||||
: "",
|
||||
: ""
|
||||
)}
|
||||
onDoubleClick={(e) => {
|
||||
setInputDescription(true);
|
||||
|
|
@ -787,13 +787,13 @@ export default function GenericNode({
|
|||
}
|
||||
title={getFieldTitle(
|
||||
data.node?.template!,
|
||||
templateField,
|
||||
templateField
|
||||
)}
|
||||
info={data.node?.template[templateField].info}
|
||||
name={templateField}
|
||||
tooltipTitle={
|
||||
data.node?.template[templateField].input_types?.join(
|
||||
"\n",
|
||||
"\n"
|
||||
) ?? data.node?.template[templateField].type
|
||||
}
|
||||
required={data.node!.template[templateField].required}
|
||||
|
|
@ -820,7 +820,7 @@ export default function GenericNode({
|
|||
<div
|
||||
className={classNames(
|
||||
Object.keys(data.node!.template).length < 1 ? "hidden" : "",
|
||||
"flex-max-width justify-center",
|
||||
"flex-max-width justify-center"
|
||||
)}
|
||||
>
|
||||
{" "}
|
||||
|
|
@ -842,7 +842,7 @@ export default function GenericNode({
|
|||
nodeColors[types[data.type]] ??
|
||||
nodeColors.unknown
|
||||
}
|
||||
title={output.selected ?? output.types[0]}
|
||||
title={output.name}
|
||||
tooltipTitle={output.selected ?? output.types[0]}
|
||||
id={{
|
||||
output_types: [output.selected ?? output.types[0]],
|
||||
|
|
|
|||
|
|
@ -4,10 +4,9 @@ from uuid import uuid4
|
|||
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langflow.custom import CustomComponent
|
||||
from langflow.custom.code_parser.code_parser import CodeParser, CodeSyntaxError
|
||||
from langflow.custom.custom_component.component import Component, ComponentCodeNullError
|
||||
from langflow.custom.custom_component.base_component import BaseComponent, ComponentCodeNullError
|
||||
from langflow.custom.utils import build_custom_component_template
|
||||
from langflow.services.database.models.flow import Flow, FlowCreate
|
||||
|
||||
|
|
@ -77,7 +76,7 @@ def test_component_init():
|
|||
"""
|
||||
Test the initialization of the Component class.
|
||||
"""
|
||||
component = Component(code=code_default, function_entrypoint_name="build")
|
||||
component = BaseComponent(code=code_default, function_entrypoint_name="build")
|
||||
assert component.code == code_default
|
||||
assert component.function_entrypoint_name == "build"
|
||||
|
||||
|
|
@ -86,7 +85,7 @@ def test_component_get_code_tree():
|
|||
"""
|
||||
Test the get_code_tree method of the Component class.
|
||||
"""
|
||||
component = Component(code=code_default, function_entrypoint_name="build")
|
||||
component = BaseComponent(code=code_default, function_entrypoint_name="build")
|
||||
tree = component.get_code_tree(component.code)
|
||||
assert "imports" in tree
|
||||
|
||||
|
|
@ -96,7 +95,7 @@ def test_component_code_null_error():
|
|||
Test the get_function method raises the
|
||||
ComponentCodeNullError when the code is empty.
|
||||
"""
|
||||
component = Component(code="", function_entrypoint_name="")
|
||||
component = BaseComponent(code="", function_entrypoint_name="")
|
||||
with pytest.raises(ComponentCodeNullError):
|
||||
component.get_function()
|
||||
|
||||
|
|
@ -200,7 +199,7 @@ def test_component_get_function_valid():
|
|||
Test the get_function method of the Component
|
||||
class with valid code and function_entrypoint_name.
|
||||
"""
|
||||
component = Component(code="def build(): pass", function_entrypoint_name="build")
|
||||
component = BaseComponent(code="def build(): pass", function_entrypoint_name="build")
|
||||
my_function = component.get_function()
|
||||
assert callable(my_function)
|
||||
|
||||
|
|
@ -357,7 +356,7 @@ def test_component_get_code_tree_syntax_error():
|
|||
Test the get_code_tree method of the Component class
|
||||
raises the CodeSyntaxError when given incorrect syntax.
|
||||
"""
|
||||
component = Component(code="import os as", function_entrypoint_name="build")
|
||||
component = BaseComponent(code="import os as", function_entrypoint_name="build")
|
||||
with pytest.raises(CodeSyntaxError):
|
||||
component.get_code_tree(component.code)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue