214 lines
7.3 KiB
Python
214 lines
7.3 KiB
Python
from typing import Any, Callable, List, Optional
|
|
from fastapi import HTTPException
|
|
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
|
from langflow.interface.custom.component import Component
|
|
from langflow.interface.custom.directory_reader import DirectoryReader
|
|
from langflow.services.utils import get_db_manager
|
|
|
|
from langflow.utils import validate
|
|
|
|
from langflow.services.database.utils import session_getter
|
|
from langflow.services.database.models.flow import Flow
|
|
from pydantic import Extra
|
|
import yaml
|
|
|
|
|
|
class CustomComponent(Component, extra=Extra.allow):
    """Component whose behaviour is defined by a user-supplied source string.

    The source held in ``code`` is parsed with ``get_code_tree``. The first
    class inheriting from ``code_class_base_inheritance`` that defines a
    ``function_entrypoint_name`` method is treated as the component. That
    method's arguments and return type drive the generated template.
    """

    # User-supplied component source code; may be None/empty.
    code: Optional[str]
    # Configuration returned as-is by ``build_config``.
    field_config: dict = {}

    # Base-class name a user class must inherit from to be recognised.
    code_class_base_inheritance = "CustomComponent"
    # Name of the entry-point method inside the user class.
    function_entrypoint_name = "build"
    function: Optional[Callable] = None
    # Return-type names accepted for the entry-point method.
    return_type_valid_list = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
    # Value rendered by ``custom_repr``; it is also handled when it holds a dict.
    repr_value: Optional[str] = ""

    def __init__(self, **data):
        super().__init__(**data)

    def custom_repr(self):
        """Return a string rendering of ``repr_value``.

        Dicts are rendered as YAML, strings are returned unchanged, and any
        other value goes through ``str()``.
        """
        if isinstance(self.repr_value, dict):
            return yaml.dump(self.repr_value)
        if isinstance(self.repr_value, str):
            return self.repr_value
        return str(self.repr_value)

    def build_config(self):
        """Return the component's field configuration (``field_config``)."""
        return self.field_config

    def _class_template_validation(self, code: str) -> bool:
        """Check that each known type hint used in ``code`` is also imported.

        Returns:
            True when ``code`` passes validation.

        Raises:
            HTTPException: status 400 when ``code`` is empty, or when one of the
                checked type hints is used in arguments but not imported.
        """
        TYPE_HINT_LIST = ["Optional", "Prompt", "PromptTemplate", "LLMChain"]

        if not code:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": self.ERROR_CODE_NULL,
                    "traceback": "",
                },
            )

        reader = DirectoryReader("", False)

        for type_hint in TYPE_HINT_LIST:
            # Check the hint of the current iteration. Checking a fixed
            # "Optional" here would leave every other hint unvalidated.
            if reader._is_type_hint_used_in_args(
                type_hint, code
            ) and not reader._is_type_hint_imported(type_hint, code):
                error_detail = {
                    "error": "Type hint Error",
                    "traceback": f"Type hint '{type_hint}' is used but not imported in the code.",
                }
                raise HTTPException(status_code=400, detail=error_detail)

        return True

    def is_check_valid(self) -> bool:
        """Return True if ``code`` validates, False if there is no code.

        Raises:
            HTTPException: propagated from ``_class_template_validation`` when
                the code is invalid.
        """
        return self._class_template_validation(self.code) if self.code else False

    def get_code_tree(self, code: str):
        return super().get_code_tree(code)

    def _get_entrypoint_method(self):
        """Return the parsed entry-point method entry, or None if absent.

        Only the first class (in code-tree order) whose bases include
        ``code_class_base_inheritance`` is inspected. Within it, the first
        method named ``function_entrypoint_name`` is returned.
        """
        if not self.code:
            return None
        tree = self.get_code_tree(self.code)

        component_classes = [
            cls
            for cls in tree["classes"]
            if self.code_class_base_inheritance in cls["bases"]
        ]
        if not component_classes:
            return None

        # Assume the first Component class is the one we're interested in
        component_class = component_classes[0]
        return next(
            (
                method
                for method in component_class["methods"]
                if method["name"] == self.function_entrypoint_name
            ),
            None,
        )

    @property
    def get_function_entrypoint_args(self) -> str:
        """Return the parsed ``args`` of the entry-point method, or "" if none."""
        build_method = self._get_entrypoint_method()
        if build_method is None:
            return ""
        return build_method["args"]

    @property
    def get_function_entrypoint_return_type(self) -> List[str]:
        """Return the entry-point's return type(s) that are in the valid list.

        A plain return type yields a one-element list when it is valid. A
        ``Union[...]`` is split on commas, and only the valid members are kept.
        Returns [] when there is no code, no entry point, or no annotation.
        """
        build_method = self._get_entrypoint_method()
        if build_method is None:
            return []

        return_type = build_method["return_type"]
        if not return_type:
            return []

        # If the return type is not a Union, then we just return it as a list
        if "Union" not in return_type:
            return [return_type] if return_type in self.return_type_valid_list else []

        # If the return type is a Union, strip the wrapper and split its members.
        # NOTE(review): nested generics inside a Union (e.g. List[X]) lose
        # their brackets here and will not match the valid list.
        inner = return_type.replace("Union", "").replace("[", "").replace("]", "")
        members = [item.strip() for item in inner.split(",")]
        return [item for item in members if item in self.return_type_valid_list]

    @property
    def get_main_class_name(self):
        """Return the name of the first class that inherits from the base class
        and defines the entry-point method, or "" if there is none."""
        tree = self.get_code_tree(self.code)

        base_name = self.code_class_base_inheritance
        method_name = self.function_entrypoint_name

        for item in tree.get("classes", []):
            if base_name in item["bases"] and any(
                method["name"] == method_name for method in item["methods"]
            ):
                return item["name"]
        return ""

    @property
    def build_template_config(self):
        """Build the template config from the main class's attributes.

        Uses an empty attribute list when no main class is found.
        """
        tree = self.get_code_tree(self.code)

        # Resolve the main class name once. It is a property that re-parses
        # the code, so it should not be re-evaluated for every class.
        main_class_name = self.get_main_class_name
        attributes = next(
            (
                main_class["attributes"]
                for main_class in tree.get("classes", [])
                if main_class["name"] == main_class_name
            ),
            [],
        )

        return super().build_template_config(attributes)

    @property
    def get_function(self):
        """Create the entry-point function from ``code`` via ``validate.create_function``."""
        return validate.create_function(self.code, self.function_entrypoint_name)

    def load_flow(
        self,
        flow_id: str,
        tweaks: Optional[dict] = None,
        *,
        get_session: Optional[Callable] = None,
    ) -> Any:
        """Load a stored flow by id and build its sorted vertices.

        Args:
            flow_id: Primary key of the ``Flow`` row.
            tweaks: Optional tweaks applied with ``process_tweaks`` before building.
            get_session: Optional session context-manager factory; defaults to
                ``session_getter``.

        Raises:
            ValueError: if the flow does not exist or has no data.
        """
        from langflow.processing.process import build_sorted_vertices_with_caching
        from langflow.processing.process import process_tweaks

        get_session = get_session or session_getter
        db_manager = get_db_manager()
        with get_session(db_manager) as session:
            graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None
        if not graph_data:
            raise ValueError(f"Flow {flow_id} not found")
        if tweaks:
            graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
        return build_sorted_vertices_with_caching(graph_data)

    def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
        """Return all ``Flow`` rows, using ``get_session`` or ``session_getter``."""
        get_session = get_session or session_getter
        db_manager = get_db_manager()
        with get_session(db_manager) as session:
            flows = session.query(Flow).all()
        return flows

    def get_flow(
        self,
        *,
        flow_name: Optional[str] = None,
        flow_id: Optional[str] = None,
        tweaks: Optional[dict] = None,
        get_session: Optional[Callable] = None,
    ) -> Any:
        """Look up a flow by id (preferred) or name and return ``load_flow``'s result.

        Raises:
            ValueError: if neither ``flow_name`` nor ``flow_id`` is given, or
                if no matching flow exists.
        """
        get_session = get_session or session_getter
        db_manager = get_db_manager()
        with get_session(db_manager) as session:
            if flow_id:
                flow = session.get(Flow, flow_id)
            elif flow_name:
                flow = session.query(Flow).filter(Flow.name == flow_name).first()
            else:
                raise ValueError("Either flow_name or flow_id must be provided")
            # Read the primary key while the session is still open.
            resolved_id = flow.id if flow else None

        if not flow:
            raise ValueError(f"Flow {flow_name or flow_id} not found")
        return self.load_flow(resolved_id, tweaks, get_session=get_session)

    def build(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError
|