From 4307c24c6df62e944326cf42fc6dd576b271c861 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 21 Jul 2023 09:25:55 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix(custom=5Fcomponent.py):=20fi?= =?UTF-8?q?x=20session=20handling=20in=20load=5Fflow=20method=20to=20ensur?= =?UTF-8?q?e=20proper=20context=20management=20and=20avoid=20potential=20r?= =?UTF-8?q?esource=20leaks=20=E2=9C=A8=20feat(custom=5Fcomponent.py):=20ad?= =?UTF-8?q?d=20custom=5Frepr=20method=20to=20CustomComponent=20class=20to?= =?UTF-8?q?=20provide=20a=20custom=20representation=20value=20for=20the=20?= =?UTF-8?q?component?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langflow/interface/custom/custom_component.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index b03e8a560..f32c82ef9 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -6,21 +6,26 @@ from langflow.interface.custom.component import Component from langflow.utils import validate from uuid import UUID -from langflow.database.base import get_session +from langflow.database.base import session_getter from langflow.database.models.flow import Flow +from pydantic import Extra -class CustomComponent(Component): +class CustomComponent(Component, extra=Extra.allow): code: Optional[str] field_config: dict = {} code_class_base_inheritance = "CustomComponent" function_entrypoint_name = "build" function: Optional[Callable] = None return_type_valid_list = list(LANGCHAIN_BASE_TYPES.keys()) + repr_value: Optional[str] = "" def __init__(self, **data): super().__init__(**data) + def custom_repr(self): + return self.repr_value + def _class_template_validation(self, code: str) -> bool: if not code: raise HTTPException( @@ -133,8 +138,8 @@ class CustomComponent(Component): def load_flow(self, flow_id: UUID = None): from langflow.processing.process import build_sorted_vertices_with_caching - session = next(get_session()) - data_graph = flow.data if (flow := session.get(Flow, flow_id)) else None + with session_getter() as session: + data_graph = flow.data if (flow := session.get(Flow, flow_id)) else None if not data_graph: raise ValueError(f"Flow {flow_id} not found") return build_sorted_vertices_with_caching(data_graph)