Refactor component and custom_component classes, and add to_record method to Flow model

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-04 09:56:06 -03:00
commit 8416fb25a7
3 changed files with 30 additions and 7 deletions

View file

@ -21,7 +21,9 @@ class ComponentFunctionEntrypointNameNullError(HTTPException):
class Component:
ERROR_CODE_NULL: ClassVar[str] = "Python code must be provided."
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = "The name of the entrypoint function must be provided."
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL: ClassVar[str] = (
"The name of the entrypoint function must be provided."
)
code: Optional[str] = None
_function_entrypoint_name: str = "build"
@ -39,7 +41,8 @@ class Component:
def __setattr__(self, key, value):
if key == "_user_id" and hasattr(self, "_user_id"):
warnings.warn("user_id is immutable and cannot be changed.")
super().__setattr__(key, value)
else:
super().__setattr__(key, value)
@cachedmethod(cache=operator.attrgetter("cache"))
def get_code_tree(self, code: str):

View file

@ -341,7 +341,7 @@ class CustomComponent(Component):
input_value_dict = {"input_value": input_value}
return await graph.run(input_value_dict, stream=False)
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Record]:
if not self._user_id:
raise ValueError("Session is invalid")
try:
@ -349,11 +349,15 @@ class CustomComponent(Component):
db_service = get_db_service()
with get_session(db_service) as session:
flows = session.exec(
select(Flow).where(Flow.user_id == self._user_id)
select(Flow)
.where(Flow.user_id == self._user_id)
.where(Flow.is_component == False)
).all()
return flows
flows_records = [flow.to_record() for flow in flows]
return flows_records
except Exception as e:
raise ValueError("Session is invalid") from e
raise ValueError(f"Error listing flows: {e}")
def build(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError

View file

@ -7,6 +7,8 @@ from uuid import UUID, uuid4
from pydantic import field_serializer, field_validator
from sqlmodel import JSON, Column, Field, Relationship, SQLModel
from langflow.schema.schema import Record
if TYPE_CHECKING:
from langflow.services.database.models.user import User
@ -16,7 +18,9 @@ class FlowBase(SQLModel):
description: Optional[str] = Field(index=True, nullable=True, default=None)
data: Optional[Dict] = Field(default=None, nullable=True)
is_component: Optional[bool] = Field(default=False, nullable=True)
updated_at: Optional[datetime] = Field(default_factory=datetime.utcnow, nullable=True)
updated_at: Optional[datetime] = Field(
default_factory=datetime.utcnow, nullable=True
)
folder: Optional[str] = Field(default=None, nullable=True)
@field_validator("data")
@ -57,6 +61,18 @@ class Flow(FlowBase, table=True):
user_id: UUID = Field(index=True, foreign_key="user.id", nullable=True)
user: "User" = Relationship(back_populates="flows")
def to_record(self):
serialized = self.model_dump()
data = {
"id": serialized.pop("id"),
"data": serialized.pop("data"),
"name": serialized.pop("name"),
"description": serialized.pop("description"),
"updated_at": serialized.pop("updated_at"),
}
record = Record(text=data.get("name"), data=data)
return record
class FlowCreate(FlowBase):
user_id: Optional[UUID] = None