Add model validation for message in ChatOutputResponse to fix markdown

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-02-21 18:57:39 -03:00
commit ff70ec0ff2

View file

@ -1,7 +1,7 @@
from typing import Dict, List, Optional, Union
from langchain_core.messages import BaseMessage
from pydantic import BaseModel
from pydantic import BaseModel, model_validator
class ChatOutputResponse(BaseModel):
@ -12,7 +12,30 @@ class ChatOutputResponse(BaseModel):
sender_name: Optional[str] = "AI"
@classmethod
def from_message(cls, message: BaseMessage, sender: Optional[str] = "Machine", sender_name: Optional[str] = "AI"):
def from_message(
cls,
message: BaseMessage,
sender: Optional[str] = "Machine",
sender_name: Optional[str] = "AI",
):
"""Build chat output response from message."""
content = message.content
return cls(message=content, sender=sender, sender_name=sender_name)
@model_validator(mode="after")
def validate_message(self):
"""Validate message."""
# The idea here is ensure the \n in message
# is compliant with markdown if sender is machine
# so, for example:
# \n\n -> \n\n
# \n -> \n\n
if self.sender != "Machine":
return self
# We need to make sure we don't duplicate \n
# in the message
message = self.message.replace("\n\n", "\n")
self.message = message.replace("\n", "\n\n")
return self