From 8416fb25a7afb790c10dd9f7974ad2613adad498 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Mon, 4 Mar 2024 09:56:06 -0300 Subject: [PATCH] Refactor component and custom_component classes, and add to_record method to Flow model --- .../custom/custom_component/component.py | 7 +++++-- .../custom_component/custom_component.py | 12 ++++++++---- .../services/database/models/flow/model.py | 18 +++++++++++++++++- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/src/backend/langflow/interface/custom/custom_component/component.py b/src/backend/langflow/interface/custom/custom_component/component.py index 13f185ed9..a889fa7b9 100644 --- a/src/backend/langflow/interface/custom/custom_component/component.py +++ b/src/backend/langflow/interface/custom/custom_component/component.py @@ -21,7 +21,9 @@ class ComponentFunctionEntrypointNameNullError(HTTPException): class Component: ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided." - ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided." + ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = ( + "The name of the entrypoint function must be provided." + ) code: Optional[str] = None _function_entrypoint_name: str = "build" @@ -39,7 +41,8 @@ class Component: def __setattr__(self, key, value): if key == "_user_id" and hasattr(self, "_user_id"): warnings.warn("user_id is immutable and cannot be changed.") - super().__setattr__(key, value) + else: + super().__setattr__(key, value) @cachedmethod(cache=operator.attrgetter("cache")) def get_code_tree(self, code: str): diff --git a/src/backend/langflow/interface/custom/custom_component/custom_component.py b/src/backend/langflow/interface/custom/custom_component/custom_component.py index 9364265c2..3c483eb25 100644 --- a/src/backend/langflow/interface/custom/custom_component/custom_component.py +++ b/src/backend/langflow/interface/custom/custom_component/custom_component.py @@ -341,7 +341,7 @@ class CustomComponent(Component): input_value_dict = {"input_value": input_value} return await graph.run(input_value_dict, stream=False) - def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]: + def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Record]: if not self._user_id: raise ValueError("Session is invalid") try: @@ -349,11 +349,15 @@ class CustomComponent(Component): db_service = get_db_service() with get_session(db_service) as session: flows = session.exec( - select(Flow).where(Flow.user_id == self._user_id) + select(Flow) + .where(Flow.user_id == self._user_id) + .where(Flow.is_component == False) ).all() - return flows + + flows_records = [flow.to_record() for flow in flows] + return flows_records except Exception as e: - raise ValueError("Session is invalid") from e + raise ValueError(f"Error listing flows: {e}") def build(self, *args: Any, **kwargs: Any) -> Any: raise NotImplementedError diff --git a/src/backend/langflow/services/database/models/flow/model.py b/src/backend/langflow/services/database/models/flow/model.py index d942fa93c..1211c40ed 100644 --- a/src/backend/langflow/services/database/models/flow/model.py +++ b/src/backend/langflow/services/database/models/flow/model.py @@ -7,6 +7,8 @@ from uuid import UUID, uuid4 from pydantic import field_serializer, field_validator from sqlmodel import JSON, Column, Field, Relationship, SQLModel +from langflow.schema.schema import Record + if TYPE_CHECKING: from langflow.services.database.models.user import User @@ -16,7 +18,9 @@ class FlowBase(SQLModel): description: Optional[str] = Field(index=True, nullable=True, default=None) data: Optional[Dict] = Field(default=None, nullable=True) is_component: Optional[bool] = Field(default=False, nullable=True) - updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, nullable=True) + updated_at: Optional[datetime] = Field( + default_factory=datetime.utcnow, nullable=True + ) folder: Optional[str] = Field(default=None, nullable=True) @field_validator("data") @@ -57,6 +61,18 @@ class Flow(FlowBase, table=True): user_id: UUID = Field(index=True, foreign_key="user.id", nullable=True) user: "User" = Relationship(back_populates="flows") + def to_record(self): + serialized = self.model_dump() + data = { + "id": serialized.pop("id"), + "data": serialized.pop("data"), + "name": serialized.pop("name"), + "description": serialized.pop("description"), + "updated_at": serialized.pop("updated_at"), + } + record = Record(text=data.get("name"), data=data) + return record + class FlowCreate(FlowBase): user_id: Optional[UUID] = None