Add model validation for message in ChatOutputResponse to fix markdown
This commit is contained in:
parent
0ae6596dbd
commit
ff70ec0ff2
1 changed files with 25 additions and 2 deletions
|
|
@ -1,7 +1,7 @@
|
|||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class ChatOutputResponse(BaseModel):
|
||||
|
|
@ -12,7 +12,30 @@ class ChatOutputResponse(BaseModel):
|
|||
sender_name: Optional[str] = "AI"
|
||||
|
||||
@classmethod
|
||||
def from_message(cls, message: BaseMessage, sender: Optional[str] = "Machine", sender_name: Optional[str] = "AI"):
|
||||
def from_message(
|
||||
cls,
|
||||
message: BaseMessage,
|
||||
sender: Optional[str] = "Machine",
|
||||
sender_name: Optional[str] = "AI",
|
||||
):
|
||||
"""Build chat output response from message."""
|
||||
content = message.content
|
||||
return cls(message=content, sender=sender, sender_name=sender_name)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_message(self):
|
||||
"""Validate message."""
|
||||
# The idea here is ensure the \n in message
|
||||
# is compliant with markdown if sender is machine
|
||||
# so, for example:
|
||||
# \n\n -> \n\n
|
||||
# \n -> \n\n
|
||||
|
||||
if self.sender != "Machine":
|
||||
return self
|
||||
|
||||
# We need to make sure we don't duplicate \n
|
||||
# in the message
|
||||
message = self.message.replace("\n\n", "\n")
|
||||
self.message = message.replace("\n", "\n\n")
|
||||
return self
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue