langflow/src/backend/langflow/interface/custom/custom_component.py
Gabriel Luiz Freitas Almeida 3d8d44538f Merge branch 'release' into dev
2023-08-10 10:10:28 -03:00

214 lines
7.3 KiB
Python

from typing import Any, Callable, List, Optional
from fastapi import HTTPException
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
from langflow.interface.custom.component import Component
from langflow.interface.custom.directory_reader import DirectoryReader
from langflow.services.utils import get_db_manager
from langflow.utils import validate
from langflow.services.database.utils import session_getter
from langflow.services.database.models.flow import Flow
from pydantic import Extra
import yaml
class CustomComponent(Component, extra=Extra.allow):
    """Component whose behaviour is defined by user-supplied Python source in ``code``.

    The source is parsed with ``get_code_tree`` to locate the class inheriting
    from ``code_class_base_inheritance`` and its ``function_entrypoint_name``
    method. The entrypoint arguments, the supported return types and the
    template configuration are derived from them. Helpers are also provided to
    list and load stored ``Flow`` records from the database.
    """

    code: Optional[str]
    field_config: dict = {}

    # Name of the base class a user component must inherit from.
    code_class_base_inheritance = "CustomComponent"
    # Name of the method treated as the component's entrypoint.
    function_entrypoint_name = "build"
    function: Optional[Callable] = None
    # Return type names accepted from the entrypoint's annotation.
    return_type_valid_list = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
    repr_value: Optional[str] = ""

    def __init__(self, **data):
        super().__init__(**data)

    def custom_repr(self):
        """Return ``repr_value`` rendered as text.

        Dicts are dumped as YAML, strings are returned unchanged, and any other
        value is converted with ``str()``.
        """
        if isinstance(self.repr_value, dict):
            return yaml.dump(self.repr_value)
        if isinstance(self.repr_value, str):
            return self.repr_value
        return str(self.repr_value)

    def build_config(self):
        """Return the field configuration (``field_config``) of this component."""
        return self.field_config

    def _class_template_validation(self, code: str) -> bool:
        """Check that every known type hint used in ``code`` is also imported.

        Args:
            code: The component source code.

        Returns:
            ``True`` when validation passes.

        Raises:
            HTTPException: 400 if ``code`` is empty, or if one of the checked
                type hints is used in the arguments but not imported.
        """
        TYPE_HINT_LIST = ["Optional", "Prompt", "PromptTemplate", "LLMChain"]

        if not code:
            raise HTTPException(
                status_code=400,
                detail={
                    # NOTE(review): ERROR_CODE_NULL is not defined in this class;
                    # presumably inherited from Component — confirm.
                    "error": self.ERROR_CODE_NULL,
                    "traceback": "",
                },
            )

        reader = DirectoryReader("", False)

        for type_hint in TYPE_HINT_LIST:
            # Each hint is checked on its own; the first offending one aborts.
            if reader._is_type_hint_used_in_args(
                type_hint, code
            ) and not reader._is_type_hint_imported(type_hint, code):
                error_detail = {
                    "error": "Type hint Error",
                    "traceback": f"Type hint '{type_hint}' is used but not imported in the code.",
                }
                raise HTTPException(status_code=400, detail=error_detail)
        return True

    def is_check_valid(self) -> bool:
        """Return ``True`` if ``code`` passes validation, ``False`` if there is no code.

        Raises:
            HTTPException: Propagated from ``_class_template_validation``.
        """
        return self._class_template_validation(self.code) if self.code else False

    def get_code_tree(self, code: str):
        """Return the parsed code tree of ``code`` (delegates to the base class)."""
        return super().get_code_tree(code)

    def _get_entrypoint_method(self) -> Optional[dict]:
        """Return the parsed entrypoint method of the first component class.

        The component class is the first class in ``code`` whose bases include
        ``code_class_base_inheritance``. Its first method named
        ``function_entrypoint_name`` is returned. ``None`` is returned when
        there is no code, no such class, or no such method.
        """
        if not self.code:
            return None
        tree = self.get_code_tree(self.code)
        # Assume the first Component class is the one we're interested in.
        component_class = next(
            (
                cls
                for cls in tree["classes"]
                if self.code_class_base_inheritance in cls["bases"]
            ),
            None,
        )
        if component_class is None:
            return None
        return next(
            (
                method
                for method in component_class["methods"]
                if method["name"] == self.function_entrypoint_name
            ),
            None,
        )

    @property
    def get_function_entrypoint_args(self) -> Any:
        """Return the parsed ``args`` entry of the entrypoint method.

        Returns ``""`` when the entrypoint cannot be found. The structure of
        ``args`` comes from ``get_code_tree`` (presumably a list of argument
        descriptors — confirm).
        """
        build_method = self._get_entrypoint_method()
        if build_method is None:
            return ""
        return build_method["args"]

    @property
    def get_function_entrypoint_return_type(self) -> List[str]:
        """Return the supported return type names of the entrypoint method.

        A plain annotation yields a one-element list if it is in
        ``return_type_valid_list``. A ``Union[...]`` yields each member that is
        in that list. Missing entrypoints or annotations yield ``[]``.
        """
        build_method = self._get_entrypoint_method()
        if build_method is None:
            return []
        return_type = build_method["return_type"]
        if not return_type:
            return []

        # If the return type is not a Union, then we just return it as a list.
        if "Union" not in return_type:
            return [return_type] if return_type in self.return_type_valid_list else []

        # Flatten ``Union[A, B]`` by stripping the wrapper and all brackets.
        # Nested generics lose their brackets too (``List[X]`` -> ``ListX``)
        # and are then dropped unless that name is in the valid list.
        members = return_type.replace("Union", "").replace("[", "").replace("]", "")
        candidates = [item.strip() for item in members.split(",")]
        return [item for item in candidates if item in self.return_type_valid_list]

    @property
    def get_main_class_name(self) -> str:
        """Return the name of the first component class that defines the entrypoint.

        Returns ``""`` when no class both inherits from
        ``code_class_base_inheritance`` and defines ``function_entrypoint_name``.
        """
        tree = self.get_code_tree(self.code)
        for cls in tree.get("classes") or []:
            if self.code_class_base_inheritance not in cls["bases"]:
                continue
            if any(
                method["name"] == self.function_entrypoint_name
                for method in cls["methods"]
            ):
                return cls["name"]
        return ""

    @property
    def build_template_config(self):
        """Build the template config from the main class's attributes.

        The attributes of the class named by ``get_main_class_name`` are passed
        to the base ``build_template_config``. ``[]`` is passed if there is no
        such class.
        """
        tree = self.get_code_tree(self.code)
        # Hoisted: the property re-parses the code on every access.
        main_class_name = self.get_main_class_name
        attributes = next(
            (
                main_class["attributes"]
                for main_class in tree.get("classes") or []
                if main_class["name"] == main_class_name
            ),
            [],
        )
        return super().build_template_config(attributes)

    @property
    def get_function(self):
        """Return the entrypoint callable built from ``code`` by ``validate.create_function``."""
        return validate.create_function(self.code, self.function_entrypoint_name)

    def load_flow(
        self,
        flow_id: str,
        tweaks: Optional[dict] = None,
        *,
        get_session: Optional[Callable] = None,
    ) -> Any:
        """Load a stored flow and build its sorted vertices.

        Args:
            flow_id: Primary key of the ``Flow`` record.
            tweaks: Optional tweaks applied to the graph data via ``process_tweaks``.
            get_session: Optional context-manager factory taking the db manager;
                defaults to ``session_getter``.

        Returns:
            The result of ``build_sorted_vertices_with_caching`` for the graph data.

        Raises:
            ValueError: If the flow does not exist or has no data.
        """
        from langflow.processing.process import build_sorted_vertices_with_caching
        from langflow.processing.process import process_tweaks

        get_session = get_session or session_getter
        db_manager = get_db_manager()
        with get_session(db_manager) as session:
            flow = session.get(Flow, flow_id)
            # Read the payload while the instance is still bound to the session.
            graph_data = flow.data if flow else None
        if not graph_data:
            raise ValueError(f"Flow {flow_id} not found")
        if tweaks:
            graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
        return build_sorted_vertices_with_caching(graph_data)

    def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
        """Return all stored ``Flow`` records.

        Args:
            get_session: Optional context-manager factory taking the db manager;
                defaults to ``session_getter``.
        """
        get_session = get_session or session_getter
        db_manager = get_db_manager()
        with get_session(db_manager) as session:
            flows = session.query(Flow).all()
        # NOTE(review): the instances outlive the session. If the session getter
        # expires objects on commit, attribute access may fail — confirm.
        return flows

    def get_flow(
        self,
        *,
        flow_name: Optional[str] = None,
        flow_id: Optional[str] = None,
        tweaks: Optional[dict] = None,
        get_session: Optional[Callable] = None,
    ) -> Any:
        """Find a flow by id (preferred) or name, then load it with ``load_flow``.

        Args:
            flow_name: Name of the flow; used only when ``flow_id`` is not given.
            flow_id: Primary key of the flow.
            tweaks: Optional tweaks forwarded to ``load_flow``.
            get_session: Optional context-manager factory taking the db manager;
                defaults to ``session_getter`` and is also used by ``load_flow``.

        Returns:
            Whatever ``load_flow`` returns for the resolved flow.

        Raises:
            ValueError: If neither identifier is given, or no flow matches.
        """
        if not flow_id and not flow_name:
            raise ValueError("Either flow_name or flow_id must be provided")

        get_session = get_session or session_getter
        db_manager = get_db_manager()
        with get_session(db_manager) as session:
            if flow_id:
                flow = session.get(Flow, flow_id)
            else:
                flow = session.query(Flow).filter(Flow.name == flow_name).first()
            if not flow:
                # Report the identifier that was actually used for the lookup.
                raise ValueError(f"Flow {flow_id or flow_name} not found")
            # Capture the key while the instance is still bound to the session.
            resolved_id = flow.id
        return self.load_flow(resolved_id, tweaks, get_session=get_session)

    def build(self, *args: Any, **kwargs: Any) -> Any:
        """Component entrypoint; the base implementation is not implemented."""
        raise NotImplementedError