🔧 fix(types.py): import correct module for custom component creator
🔧 fix(types.py): fix import for CustomComponent class 🔧 fix(types.py): remove unused imports 🔧 fix(types.py): fix function signature for add_new_custom_field 🔧 fix(types.py): fix function signature for build_langchain_template_custom_component 🔧 fix(types.py): fix return type validation and error handling in build_langchain_template_custom_component 🔧 fix(types.py): fix appending base classes to frontend_node in build_langchain_template_custom_component 🔧 fix(types.py): fix return statement in build_langchain_template_custom_component The changes in this commit fix import statements, function signatures, and error handling in the types.py file. The correct module is now imported for the custom component creator. The import for the CustomComponent class is fixed. Unused imports are removed. The function signature for add_new_custom_field is fixed to include the field_config parameter. The function signature for build_langchain_template_custom_component is fixed to include the field_config parameter. The return type validation and error handling in build_langchain_template_custom_component are fixed to handle invalid return types. The base classes are correctly appended to the frontend_node in build_langchain_template_custom_component. The return statement in build_langchain_template_custom_component is fixed to return the frontend_node.
This commit is contained in:
parent
f41cd1905f
commit
024cd3398a
1 changed files with 47 additions and 23 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.custom.constants import LANGCHAIN_BASE_TYPES
|
||||
from langflow.interface.document_loaders.base import documentloader_creator
|
||||
from langflow.interface.embeddings.base import embedding_creator
|
||||
from langflow.interface.llms.base import llm_creator
|
||||
|
|
@ -12,7 +13,8 @@ from langflow.interface.utilities.base import utility_creator
|
|||
from langflow.interface.vector_store.base import vectorstore_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.tools.custom import CustomComponent
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
from langflow.interface.custom.custom import CustomComponent
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.tools import CustomComponentNode
|
||||
|
|
@ -24,9 +26,6 @@ from fastapi import HTTPException
|
|||
import traceback
|
||||
|
||||
# Used to get the base_classes list
|
||||
from langchain.chains import ConversationChain # noqa: F401
|
||||
from langchain.llms.base import BaseLLM # noqa: F401
|
||||
from langchain.tools import Tool # noqa: F401
|
||||
|
||||
|
||||
def get_type_list():
|
||||
|
|
@ -62,6 +61,7 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
|
|||
utility_creator,
|
||||
output_parser_creator,
|
||||
retriever_creator,
|
||||
custom_component_creator,
|
||||
]
|
||||
|
||||
all_types = {}
|
||||
|
|
@ -73,9 +73,16 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
|
|||
|
||||
|
||||
# TODO: Move to correct place
|
||||
def add_new_custom_field(template, field_name: str, field_type: str):
|
||||
def add_new_custom_field(
|
||||
template, field_name: str, field_type: str, field_config: dict
|
||||
):
|
||||
new_field = TemplateField(
|
||||
name=field_name, field_type=field_type, show=True, required=True, advanced=False
|
||||
name=field_name,
|
||||
field_type=field_type,
|
||||
show=True,
|
||||
required=True,
|
||||
advanced=False,
|
||||
**field_config,
|
||||
)
|
||||
template.get("template")[field_name] = new_field.to_dict()
|
||||
template.get("custom_fields")[field_name] = None
|
||||
|
|
@ -108,28 +115,45 @@ def add_code_field(template, raw_code):
|
|||
|
||||
def build_langchain_template_custom_component(extractor: CustomComponent):
|
||||
# Build base "CustomComponent" template
|
||||
template = CustomComponentNode().to_dict().get(type(extractor).__name__)
|
||||
frontend_node = CustomComponentNode().to_dict().get(type(extractor).__name__)
|
||||
|
||||
function_args, return_type = extractor.args_and_return_type
|
||||
function_args, return_type, template_config = extractor.args_and_return_type
|
||||
|
||||
if "display_name" in template_config and frontend_node is not None:
|
||||
frontend_node["display_name"] = template_config["display_name"]
|
||||
raw_code = extractor.code
|
||||
field_config = template_config.get("field_config", {})
|
||||
if function_args is not None:
|
||||
# Add extra fields
|
||||
for extra_field in function_args:
|
||||
def_field = extra_field[0]
|
||||
def_type = extra_field[1]
|
||||
|
||||
# Add extra fields
|
||||
for extra_field in function_args:
|
||||
def_field = extra_field[0]
|
||||
def_type = extra_field[1]
|
||||
if def_field != "self":
|
||||
# TODO: Validate type - if is possible to render into frontend
|
||||
if not def_type:
|
||||
def_type = "str"
|
||||
config = field_config.get(def_field, {})
|
||||
frontend_node = add_new_custom_field(
|
||||
frontend_node, def_field, def_type, config
|
||||
)
|
||||
|
||||
if def_field != "self":
|
||||
# TODO: Validate type - if is possible to render into frontend
|
||||
if not def_type:
|
||||
def_type = "str"
|
||||
|
||||
template = add_new_custom_field(template, def_field, def_type)
|
||||
|
||||
template = add_code_field(template, raw_code)
|
||||
frontend_node = add_code_field(frontend_node, raw_code)
|
||||
|
||||
# Get base classes from "return_type" and add to template.base_classes
|
||||
try:
|
||||
return_type_instance = globals()[return_type]
|
||||
if return_type not in LANGCHAIN_BASE_TYPES or return_type is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": (
|
||||
"Invalid return type should be one of: "
|
||||
f"{list(LANGCHAIN_BASE_TYPES.keys())}"
|
||||
),
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
)
|
||||
return_type_instance = LANGCHAIN_BASE_TYPES.get(return_type)
|
||||
base_classes = get_base_classes(return_type_instance)
|
||||
except (KeyError, AttributeError) as err:
|
||||
raise HTTPException(
|
||||
|
|
@ -138,9 +162,9 @@ def build_langchain_template_custom_component(extractor: CustomComponent):
|
|||
) from err
|
||||
|
||||
for base_class in base_classes:
|
||||
template.get("base_classes").append(base_class)
|
||||
frontend_node.get("base_classes").append(base_class)
|
||||
|
||||
return template
|
||||
return frontend_node
|
||||
|
||||
|
||||
langchain_types_dict = build_langchain_types_dict()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue