Refactor chat component and add ClearMessageHistory component

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-06 18:45:39 -03:00
commit 0c1703de9a
7 changed files with 67 additions and 12 deletions

View file

@ -45,7 +45,9 @@ class ChatComponent(CustomComponent):
return []
if not session_id or not sender or not sender_name:
raise ValueError("All of session_id, sender, and sender_name must be provided.")
raise ValueError(
"All of session_id, sender, and sender_name must be provided."
)
if isinstance(message, Record):
record = message
record.data.update(
@ -57,8 +59,8 @@ class ChatComponent(CustomComponent):
)
else:
record = Record(
text=message,
data={
"text": message,
"session_id": session_id,
"sender": sender,
"sender_name": sender_name,

View file

@ -0,0 +1,24 @@
from langflow import CustomComponent
from langflow.memory import delete_messages, get_messages
class ClearMessageHistoryComponent(CustomComponent):
display_name = "Clear Message History"
description = "A component to clear the message history."
def build_config(self):
return {
"session_id": {
"display_name": "Session ID",
"info": "The session ID to clear the message history.",
}
}
def build(
self,
session_id: str,
) -> None:
delete_messages(session_id=session_id)
records = get_messages(session_id=session_id)
self.records = records
return records

View file

@ -30,5 +30,5 @@ def records_to_text(template: str, records: list[Record]) -> list[str]:
records = [records]
# Check if there are any format strings in the template
formated_records = [template.format(text=record.text, data=record.data, **record.data) for record in records]
formated_records = [template.format(**record.data) for record in records]
return "\n".join(formated_records)

View file

@ -40,8 +40,8 @@ def get_messages(
for row in messages_df.itertuples():
record = Record(
text=row.message,
data={
"text": row.message,
"sender": row.sender,
"sender_name": row.sender_name,
"session_id": row.session_id,
@ -81,3 +81,14 @@ def add_messages(records: Union[list[Record], Record]):
except Exception as e:
logger.exception(e)
raise e
def delete_messages(session_id: str):
"""
Delete messages from the monitor service based on the provided session ID.
Args:
session_id (str): The session ID associated with the messages to delete.
"""
monitor_service = get_monitor_service()
monitor_service.delete_messages(session_id)

View file

@ -108,3 +108,7 @@ class Record(BaseModel):
suffix = ")"
text = ", ".join([f"{k}={v}" for k, v in self.data.items()])
return prefix + text + suffix
# check which attributes the Record has by checking the keys in the data dictionary
def __dir__(self):
return super().__dir__() + list(self.data.keys())

View file

@ -10,7 +10,9 @@ if TYPE_CHECKING:
class TransactionModel(BaseModel):
id: Optional[int] = Field(default=None, alias="id")
timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp")
timestamp: Optional[datetime] = Field(
default_factory=datetime.now, alias="timestamp"
)
source: str
target: str
target_args: dict
@ -51,14 +53,18 @@ class MessageModel(BaseModel):
@classmethod
def from_record(cls, record: "Record"):
# first check if the record has all the required fields
if not record.data or ("sender" not in record.data and "sender_name" not in record.data):
raise ValueError("The record does not have the required fields 'sender' and 'sender_name' in the data.")
if not record.data or (
"sender" not in record.data and "sender_name" not in record.data
):
raise ValueError(
"The record does not have the required fields 'sender' and 'sender_name' in the data."
)
return cls(
sender=record.data["sender"],
sender_name=record.data["sender_name"],
sender=record.sender,
sender_name=record.sender_name,
message=record.text,
session_id=record.data.get("session_id", ""),
artifacts=record.data.get("artifacts", {}),
session_id=record.session_id,
artifacts=record.artifacts or {},
)

View file

@ -44,7 +44,9 @@ class MonitorService(Service):
def ensure_tables_exist(self):
for table_name, model in self.table_map.items():
drop_and_create_table_if_schema_mismatch(str(self.db_path), table_name, model)
drop_and_create_table_if_schema_mismatch(
str(self.db_path), table_name, model
)
def add_row(
self,
@ -105,6 +107,12 @@ class MonitorService(Service):
with duckdb.connect(str(self.db_path)) as conn:
conn.execute(query)
def delete_messages(self, session_id: str):
query = f"DELETE FROM messages WHERE session_id = '{session_id}'"
with duckdb.connect(str(self.db_path)) as conn:
conn.execute(query)
def add_message(self, message: MessageModel):
self.add_row("messages", message)