From 2e7b35ddd5f6652314eca8061b1b0741d8577860 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 31 Jul 2023 17:28:29 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20fix(custom=5Fcomponent.py):=20ch?= =?UTF-8?q?ange=20list=5Fflows=20and=20get=5Fflow=20methods=20to=20accept?= =?UTF-8?q?=20an=20optional=20get=5Fsession=20parameter=20for=20better=20f?= =?UTF-8?q?lexibility=20and=20testability?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langflow/interface/custom/custom_component.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/backend/langflow/interface/custom/custom_component.py b/src/backend/langflow/interface/custom/custom_component.py index 9218a3e1b..0b6f0a732 100644 --- a/src/backend/langflow/interface/custom/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component.py @@ -153,8 +153,9 @@ class CustomComponent(Component, extra=Extra.allow): graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks) return build_sorted_vertices_with_caching(graph_data) - def list_flows(self) -> List[Flow]: - with session_getter() as session: + def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]: + get_session = get_session or session_getter + with get_session() as session: flows = session.query(Flow).all() return flows @@ -164,8 +165,11 @@ class CustomComponent(Component, extra=Extra.allow): flow_name: Optional[str] = None, flow_id: Optional[str] = None, tweaks: Optional[dict] = None, + get_session: Optional[Callable] = None, ) -> Flow: - with session_getter() as session: + get_session = get_session or session_getter + + with get_session() as session: if flow_id: flow = session.query(Flow).get(flow_id) elif flow_name: