diff --git a/src/backend/langflow/helpers/flow.py b/src/backend/langflow/helpers/flow.py new file mode 100644 index 000000000..f81e68915 --- /dev/null +++ b/src/backend/langflow/helpers/flow.py @@ -0,0 +1,75 @@ +from typing import TYPE_CHECKING, Any, List, Optional, Union + +from sqlmodel import select + +from langflow.schema.schema import INPUT_FIELD_NAME, Record +from langflow.services.database.models.flow.model import Flow +from langflow.services.deps import session_scope + +if TYPE_CHECKING: + from langflow.graph.graph.base import Graph + + +def list_flows(*, user_id: Optional[str] = None) -> List[Record]: + if not user_id: + raise ValueError("Session is invalid") + try: + with session_scope() as session: + flows = session.exec( + select(Flow).where(Flow.user_id == user_id).where(Flow.is_component == False) # noqa + ).all() + + flows_records = [flow.to_record() for flow in flows] + return flows_records + except Exception as e: + raise ValueError(f"Error listing flows: {e}") + + +async def load_flow(flow_id: str, tweaks: Optional[dict] = None) -> "Graph": + from langflow.graph.graph.base import Graph + from langflow.processing.process import process_tweaks + + with session_scope() as session: + graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None + if not graph_data: + raise ValueError(f"Flow {flow_id} not found") + if tweaks: + graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks) + graph = Graph.from_payload(graph_data, flow_id=flow_id) + return graph + + +async def run_flow( + inputs: Union[dict, List[dict]] = None, + flow_id: Optional[str] = None, + flow_name: Optional[str] = None, + tweaks: Optional[dict] = None, + flows_records: Optional[List[Record]] = None, +) -> Any: + if not flow_id and not flow_name: + raise ValueError("Flow ID or Flow Name is required") + if not flows_records: + flows_records = list_flows() + if not flow_id and flows_records: + flow_ids = [flow.data["id"] for flow in flows_records if flow.data["name"] == flow_name] + if not flow_ids: + raise ValueError(f"Flow {flow_name} not found") + elif len(flow_ids) > 1: + raise ValueError(f"Multiple flows found with the name {flow_name}") + flow_id = flow_ids[0] + + if not flow_id: + raise ValueError(f"Flow {flow_name} not found") + graph = await load_flow(flow_id, tweaks) + + if inputs is None: + inputs = [] + inputs_list = [] + inputs_components = [] + types = [] + for input_dict in inputs: + inputs_list.append({INPUT_FIELD_NAME: input_dict.get("input_value")}) + inputs_components.append(input_dict.get("components", [])) + types.append(input_dict.get("type", [])) + + return await graph.arun(inputs_list, inputs_components=inputs_components, types=types) diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index 7d90263ef..42d837157 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -7,24 +7,22 @@ import yaml from cachetools import TTLCache, cachedmethod from langchain_core.documents import Document from pydantic import BaseModel -from sqlmodel import select +from langflow.helpers.flow import list_flows, load_flow, run_flow from langflow.interface.custom.code_parser.utils import ( extract_inner_type_from_generic_alias, extract_union_types_from_generic_alias, ) from langflow.interface.custom.custom_component.component import Component -from langflow.schema import Record -from langflow.schema.dotdict import dotdict -from langflow.services.database.models.flow import Flow -from langflow.services.database.utils import session_getter -from langflow.services.deps import get_credential_service, get_db_service, get_storage_service -from langflow.services.storage.service import StorageService +from langflow.schema import dotdict +from langflow.schema.schema import Record +from langflow.services.deps import get_credential_service, get_storage_service, session_scope from langflow.utils import validate if TYPE_CHECKING: from langflow.graph.graph.base import Graph from langflow.graph.vertex.base import Vertex + from langflow.services.storage.service import StorageService class CustomComponent(Component): @@ -293,8 +291,8 @@ class CustomComponent(Component): raise ValueError(f"User id is not set for {self.__class__.__name__}") credential_service = get_credential_service() # Get service instance # Retrieve and decrypt the credential by name for the current user - db_service = get_db_service() - with session_getter(db_service) as session: + + with session_scope() as session: return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session) return get_credential @@ -303,8 +301,8 @@ class CustomComponent(Component): if hasattr(self, "_user_id") and not self._user_id: raise ValueError(f"User id is not set for {self.__class__.__name__}") credential_service = get_credential_service() - db_service = get_db_service() - with session_getter(db_service) as session: + + with session_scope() as session: return credential_service.list_credentials(user_id=self._user_id, session=session) def index(self, value: int = 0): @@ -319,60 +317,22 @@ class CustomComponent(Component): return validate.create_function(self.code, self.function_entrypoint_name) async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph": - from langflow.graph.graph.base import Graph - from langflow.processing.process import process_tweaks - - db_service = get_db_service() - with session_getter(db_service) as session: - graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None - if not graph_data: - raise ValueError(f"Flow {flow_id} not found") - if tweaks: - graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks) - graph = Graph.from_payload(graph_data, flow_id=flow_id) - return graph + return await load_flow(flow_id, tweaks) async def run_flow( self, - input_value: Union[str, list[str]], + inputs: Union[dict, List[dict]] = None, flow_id: Optional[str] = None, flow_name: Optional[str] = None, tweaks: Optional[dict] = None, ) -> Any: - if not flow_id and not flow_name: - raise ValueError("Flow ID or Flow Name is required") - if not self._flows_records: - self.list_flows() - if not flow_id and self._flows_records: - flow_ids = [flow.data["id"] for flow in self._flows_records if flow.data["name"] == flow_name] - if not flow_ids: - raise ValueError(f"Flow {flow_name} not found") - elif len(flow_ids) > 1: - raise ValueError(f"Multiple flows found with the name {flow_name}") - flow_id = flow_ids[0] + return await run_flow(inputs=inputs, flow_id=flow_id, flow_name=flow_name, tweaks=tweaks) - if not flow_id: - raise ValueError(f"Flow {flow_name} not found") - if isinstance(input_value, str): - input_value = [input_value] - graph = await self.load_flow(flow_id, tweaks) - input_value_dict = [{"input_value": input_val} for input_val in input_value] - return await graph.arun(input_value_dict, stream=False) - - def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Record]: + def list_flows(self) -> List[Record]: if not self._user_id: raise ValueError("Session is invalid") try: - get_session = get_session or session_getter - db_service = get_db_service() - with get_session(db_service) as session: - flows = session.exec( - select(Flow).where(Flow.user_id == self._user_id).where(Flow.is_component == False) # noqa - ).all() - - flows_records = [flow.to_record() for flow in flows] - self._flows_records = flows_records - return flows_records + return list_flows(user_id=self._user_id) except Exception as e: raise ValueError(f"Error listing flows: {e}")