Add flow.py module with helper functions

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-23 00:22:11 -03:00
commit 87fd095233
2 changed files with 89 additions and 54 deletions

View file

@ -0,0 +1,75 @@
from typing import TYPE_CHECKING, Any, List, Optional, Union
from sqlmodel import select
from langflow.schema.schema import INPUT_FIELD_NAME, Record
from langflow.services.database.models.flow.model import Flow
from langflow.services.deps import session_scope
if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
def list_flows(*, user_id: Optional[str] = None) -> List[Record]:
if not user_id:
raise ValueError("Session is invalid")
try:
with session_scope() as session:
flows = session.exec(
select(Flow).where(Flow.user_id == user_id).where(Flow.is_component == False) # noqa
).all()
flows_records = [flow.to_record() for flow in flows]
return flows_records
except Exception as e:
raise ValueError(f"Error listing flows: {e}")
async def load_flow(flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
from langflow.graph.graph.base import Graph
from langflow.processing.process import process_tweaks
with session_scope() as session:
graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None
if not graph_data:
raise ValueError(f"Flow {flow_id} not found")
if tweaks:
graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
graph = Graph.from_payload(graph_data, flow_id=flow_id)
return graph
async def run_flow(
inputs: Union[dict, List[dict]] = None,
flow_id: Optional[str] = None,
flow_name: Optional[str] = None,
tweaks: Optional[dict] = None,
flows_records: Optional[List[Record]] = None,
) -> Any:
if not flow_id and not flow_name:
raise ValueError("Flow ID or Flow Name is required")
if not flows_records:
flows_records = list_flows()
if not flow_id and flows_records:
flow_ids = [flow.data["id"] for flow in flows_records if flow.data["name"] == flow_name]
if not flow_ids:
raise ValueError(f"Flow {flow_name} not found")
elif len(flow_ids) > 1:
raise ValueError(f"Multiple flows found with the name {flow_name}")
flow_id = flow_ids[0]
if not flow_id:
raise ValueError(f"Flow {flow_name} not found")
graph = await load_flow(flow_id, tweaks)
if inputs is None:
inputs = []
inputs_list = []
inputs_components = []
types = []
for input_dict in inputs:
inputs_list.append({INPUT_FIELD_NAME: input_dict.get("input_value")})
inputs_components.append(input_dict.get("components", []))
types.append(input_dict.get("type", []))
return await graph.arun(inputs_list, inputs_components=inputs_components, types=types)

View file

@ -7,24 +7,22 @@ import yaml
from cachetools import TTLCache, cachedmethod
from langchain_core.documents import Document
from pydantic import BaseModel
from sqlmodel import select
from langflow.helpers.flow import list_flows, load_flow, run_flow
from langflow.interface.custom.code_parser.utils import (
extract_inner_type_from_generic_alias,
extract_union_types_from_generic_alias,
)
from langflow.interface.custom.custom_component.component import Component
from langflow.schema import Record
from langflow.schema.dotdict import dotdict
from langflow.services.database.models.flow import Flow
from langflow.services.database.utils import session_getter
from langflow.services.deps import get_credential_service, get_db_service, get_storage_service
from langflow.services.storage.service import StorageService
from langflow.schema import dotdict
from langflow.schema.schema import Record
from langflow.services.deps import get_credential_service, get_storage_service, session_scope
from langflow.utils import validate
if TYPE_CHECKING:
from langflow.graph.graph.base import Graph
from langflow.graph.vertex.base import Vertex
from langflow.services.storage.service import StorageService
class CustomComponent(Component):
@ -293,8 +291,8 @@ class CustomComponent(Component):
raise ValueError(f"User id is not set for {self.__class__.__name__}")
credential_service = get_credential_service() # Get service instance
# Retrieve and decrypt the credential by name for the current user
db_service = get_db_service()
with session_getter(db_service) as session:
with session_scope() as session:
return credential_service.get_credential(user_id=self._user_id or "", name=name, session=session)
return get_credential
@ -303,8 +301,8 @@ class CustomComponent(Component):
if hasattr(self, "_user_id") and not self._user_id:
raise ValueError(f"User id is not set for {self.__class__.__name__}")
credential_service = get_credential_service()
db_service = get_db_service()
with session_getter(db_service) as session:
with session_scope() as session:
return credential_service.list_credentials(user_id=self._user_id, session=session)
def index(self, value: int = 0):
@ -319,60 +317,22 @@ class CustomComponent(Component):
return validate.create_function(self.code, self.function_entrypoint_name)
async def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> "Graph":
from langflow.graph.graph.base import Graph
from langflow.processing.process import process_tweaks
db_service = get_db_service()
with session_getter(db_service) as session:
graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None
if not graph_data:
raise ValueError(f"Flow {flow_id} not found")
if tweaks:
graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
graph = Graph.from_payload(graph_data, flow_id=flow_id)
return graph
return await load_flow(flow_id, tweaks)
async def run_flow(
self,
input_value: Union[str, list[str]],
inputs: Union[dict, List[dict]] = None,
flow_id: Optional[str] = None,
flow_name: Optional[str] = None,
tweaks: Optional[dict] = None,
) -> Any:
if not flow_id and not flow_name:
raise ValueError("Flow ID or Flow Name is required")
if not self._flows_records:
self.list_flows()
if not flow_id and self._flows_records:
flow_ids = [flow.data["id"] for flow in self._flows_records if flow.data["name"] == flow_name]
if not flow_ids:
raise ValueError(f"Flow {flow_name} not found")
elif len(flow_ids) > 1:
raise ValueError(f"Multiple flows found with the name {flow_name}")
flow_id = flow_ids[0]
return await run_flow(inputs=inputs, flow_id=flow_id, flow_name=flow_name, tweaks=tweaks)
if not flow_id:
raise ValueError(f"Flow {flow_name} not found")
if isinstance(input_value, str):
input_value = [input_value]
graph = await self.load_flow(flow_id, tweaks)
input_value_dict = [{"input_value": input_val} for input_val in input_value]
return await graph.arun(input_value_dict, stream=False)
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Record]:
def list_flows(self) -> List[Record]:
if not self._user_id:
raise ValueError("Session is invalid")
try:
get_session = get_session or session_getter
db_service = get_db_service()
with get_session(db_service) as session:
flows = session.exec(
select(Flow).where(Flow.user_id == self._user_id).where(Flow.is_component == False) # noqa
).all()
flows_records = [flow.to_record() for flow in flows]
self._flows_records = flows_records
return flows_records
return list_flows(user_id=self._user_id)
except Exception as e:
raise ValueError(f"Error listing flows: {e}")