Refactor chat component and add ClearMessageHistory component
This commit is contained in:
parent
bfe35b4e46
commit
0c1703de9a
7 changed files with 67 additions and 12 deletions
|
|
@ -45,7 +45,9 @@ class ChatComponent(CustomComponent):
|
|||
return []
|
||||
|
||||
if not session_id or not sender or not sender_name:
|
||||
raise ValueError("All of session_id, sender, and sender_name must be provided.")
|
||||
raise ValueError(
|
||||
"All of session_id, sender, and sender_name must be provided."
|
||||
)
|
||||
if isinstance(message, Record):
|
||||
record = message
|
||||
record.data.update(
|
||||
|
|
@ -57,8 +59,8 @@ class ChatComponent(CustomComponent):
|
|||
)
|
||||
else:
|
||||
record = Record(
|
||||
text=message,
|
||||
data={
|
||||
"text": message,
|
||||
"session_id": session_id,
|
||||
"sender": sender,
|
||||
"sender_name": sender_name,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
from langflow import CustomComponent
|
||||
from langflow.memory import delete_messages, get_messages
|
||||
|
||||
|
||||
class ClearMessageHistoryComponent(CustomComponent):
|
||||
display_name = "Clear Message History"
|
||||
description = "A component to clear the message history."
|
||||
|
||||
def build_config(self):
|
||||
return {
|
||||
"session_id": {
|
||||
"display_name": "Session ID",
|
||||
"info": "The session ID to clear the message history.",
|
||||
}
|
||||
}
|
||||
|
||||
def build(
|
||||
self,
|
||||
session_id: str,
|
||||
) -> None:
|
||||
delete_messages(session_id=session_id)
|
||||
records = get_messages(session_id=session_id)
|
||||
self.records = records
|
||||
return records
|
||||
|
|
@ -30,5 +30,5 @@ def records_to_text(template: str, records: list[Record]) -> list[str]:
|
|||
records = [records]
|
||||
# Check if there are any format strings in the template
|
||||
|
||||
formated_records = [template.format(text=record.text, data=record.data, **record.data) for record in records]
|
||||
formated_records = [template.format(**record.data) for record in records]
|
||||
return "\n".join(formated_records)
|
||||
|
|
|
|||
|
|
@ -40,8 +40,8 @@ def get_messages(
|
|||
|
||||
for row in messages_df.itertuples():
|
||||
record = Record(
|
||||
text=row.message,
|
||||
data={
|
||||
"text": row.message,
|
||||
"sender": row.sender,
|
||||
"sender_name": row.sender_name,
|
||||
"session_id": row.session_id,
|
||||
|
|
@ -81,3 +81,14 @@ def add_messages(records: Union[list[Record], Record]):
|
|||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
|
||||
def delete_messages(session_id: str):
|
||||
"""
|
||||
Delete messages from the monitor service based on the provided session ID.
|
||||
|
||||
Args:
|
||||
session_id (str): The session ID associated with the messages to delete.
|
||||
"""
|
||||
monitor_service = get_monitor_service()
|
||||
monitor_service.delete_messages(session_id)
|
||||
|
|
|
|||
|
|
@ -108,3 +108,7 @@ class Record(BaseModel):
|
|||
suffix = ")"
|
||||
text = ", ".join([f"{k}={v}" for k, v in self.data.items()])
|
||||
return prefix + text + suffix
|
||||
|
||||
# check which attributes the Record has by checking the keys in the data dictionary
|
||||
def __dir__(self):
|
||||
return super().__dir__() + list(self.data.keys())
|
||||
|
|
|
|||
|
|
@ -10,7 +10,9 @@ if TYPE_CHECKING:
|
|||
|
||||
class TransactionModel(BaseModel):
|
||||
id: Optional[int] = Field(default=None, alias="id")
|
||||
timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp")
|
||||
timestamp: Optional[datetime] = Field(
|
||||
default_factory=datetime.now, alias="timestamp"
|
||||
)
|
||||
source: str
|
||||
target: str
|
||||
target_args: dict
|
||||
|
|
@ -51,14 +53,18 @@ class MessageModel(BaseModel):
|
|||
@classmethod
|
||||
def from_record(cls, record: "Record"):
|
||||
# first check if the record has all the required fields
|
||||
if not record.data or ("sender" not in record.data and "sender_name" not in record.data):
|
||||
raise ValueError("The record does not have the required fields 'sender' and 'sender_name' in the data.")
|
||||
if not record.data or (
|
||||
"sender" not in record.data and "sender_name" not in record.data
|
||||
):
|
||||
raise ValueError(
|
||||
"The record does not have the required fields 'sender' and 'sender_name' in the data."
|
||||
)
|
||||
return cls(
|
||||
sender=record.data["sender"],
|
||||
sender_name=record.data["sender_name"],
|
||||
sender=record.sender,
|
||||
sender_name=record.sender_name,
|
||||
message=record.text,
|
||||
session_id=record.data.get("session_id", ""),
|
||||
artifacts=record.data.get("artifacts", {}),
|
||||
session_id=record.session_id,
|
||||
artifacts=record.artifacts or {},
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -44,7 +44,9 @@ class MonitorService(Service):
|
|||
|
||||
def ensure_tables_exist(self):
|
||||
for table_name, model in self.table_map.items():
|
||||
drop_and_create_table_if_schema_mismatch(str(self.db_path), table_name, model)
|
||||
drop_and_create_table_if_schema_mismatch(
|
||||
str(self.db_path), table_name, model
|
||||
)
|
||||
|
||||
def add_row(
|
||||
self,
|
||||
|
|
@ -105,6 +107,12 @@ class MonitorService(Service):
|
|||
with duckdb.connect(str(self.db_path)) as conn:
|
||||
conn.execute(query)
|
||||
|
||||
def delete_messages(self, session_id: str):
|
||||
query = f"DELETE FROM messages WHERE session_id = '{session_id}'"
|
||||
|
||||
with duckdb.connect(str(self.db_path)) as conn:
|
||||
conn.execute(query)
|
||||
|
||||
def add_message(self, message: MessageModel):
|
||||
self.add_row("messages", message)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue