from enum import Enum
from typing import Optional, Union

from pydantic import validator

from vocode.models.message import BaseMessage

from .model import TypedModel, BaseModel
FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS = 0.5
|
|
LLM_AGENT_DEFAULT_TEMPERATURE = 1.0
|
|
LLM_AGENT_DEFAULT_MAX_TOKENS = 256
|
|
LLM_AGENT_DEFAULT_MODEL_NAME = "text-curie-001"
|
|
|
|
|
|
class AgentType(str, Enum):
|
|
BASE = "agent_base"
|
|
LLM = "agent_llm"
|
|
CHAT_GPT_ALPHA = "agent_chat_gpt_alpha"
|
|
CHAT_GPT = "agent_chat_gpt"
|
|
ECHO = "agent_echo"
|
|
INFORMATION_RETRIEVAL = "agent_information_retrieval"
|
|
RESTFUL_USER_IMPLEMENTED = "agent_restful_user_implemented"
|
|
WEBSOCKET_USER_IMPLEMENTED = "agent_websocket_user_implemented"
|
|
|
|
|
|
class FillerAudioConfig(BaseModel):
|
|
silence_threshold_seconds: float = FILLER_AUDIO_DEFAULT_SILENCE_THRESHOLD_SECONDS
|
|
use_phrases: bool = True
|
|
use_typing_noise: bool = False
|
|
|
|
@validator("use_typing_noise")
|
|
def typing_noise_excludes_phrases(cls, v, values):
|
|
if v and values.get("use_phrases"):
|
|
values["use_phrases"] = False
|
|
if not v and not values.get("use_phrases"):
|
|
raise ValueError("must use either typing noise or phrases for filler audio")
|
|
return v
|
|
|
|
|
|
class AgentConfig(TypedModel, type=AgentType.BASE):
|
|
initial_message: Optional[BaseMessage] = None
|
|
generate_responses: bool = True
|
|
allowed_idle_time_seconds: Optional[float] = None
|
|
end_conversation_on_goodbye: bool = False
|
|
send_filler_audio: Union[bool, FillerAudioConfig] = False
|
|
|
|
|
|
class CutOffResponse(BaseModel):
|
|
messages: list[BaseMessage] = [BaseMessage(text="Sorry?")]
|
|
|
|
|
|
class LLMAgentConfig(AgentConfig, type=AgentType.LLM):
|
|
prompt_preamble: str
|
|
expected_first_prompt: Optional[str] = None
|
|
model_name: str = LLM_AGENT_DEFAULT_MODEL_NAME
|
|
temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
|
|
max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS
|
|
cut_off_response: Optional[CutOffResponse] = None
|
|
|
|
|
|
class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT):
|
|
prompt_preamble: str
|
|
expected_first_prompt: Optional[str] = None
|
|
generate_responses: bool = False
|
|
temperature: float = LLM_AGENT_DEFAULT_TEMPERATURE
|
|
max_tokens: int = LLM_AGENT_DEFAULT_MAX_TOKENS
|
|
cut_off_response: Optional[CutOffResponse] = None
|
|
|
|
|
|
class InformationRetrievalAgentConfig(
|
|
AgentConfig, type=AgentType.INFORMATION_RETRIEVAL
|
|
):
|
|
recipient_descriptor: str
|
|
caller_descriptor: str
|
|
goal_description: str
|
|
fields: list[str]
|
|
# TODO: add fields for IVR, voicemail
|
|
|
|
|
|
class EchoAgentConfig(AgentConfig, type=AgentType.ECHO):
|
|
pass
|
|
|
|
|
|
class RESTfulUserImplementedAgentConfig(
|
|
AgentConfig, type=AgentType.RESTFUL_USER_IMPLEMENTED
|
|
):
|
|
class EndpointConfig(BaseModel):
|
|
url: str
|
|
method: str = "POST"
|
|
|
|
respond: EndpointConfig
|
|
generate_responses: bool = False
|
|
# generate_response: Optional[EndpointConfig]
|
|
# update_last_bot_message_on_cut_off: Optional[EndpointConfig]
|
|
|
|
|
|
class RESTfulAgentInput(BaseModel):
|
|
conversation_id: str
|
|
human_input: str
|
|
|
|
|
|
class RESTfulAgentOutputType(str, Enum):
|
|
BASE = "restful_agent_base"
|
|
TEXT = "restful_agent_text"
|
|
END = "restful_agent_end"
|
|
|
|
|
|
class RESTfulAgentOutput(TypedModel, type=RESTfulAgentOutputType.BASE):
|
|
pass
|
|
|
|
|
|
class RESTfulAgentText(RESTfulAgentOutput, type=RESTfulAgentOutputType.TEXT):
|
|
response: str
|
|
|
|
|
|
class RESTfulAgentEnd(RESTfulAgentOutput, type=RESTfulAgentOutputType.END):
|
|
pass
|
|
|
|
|
|
class WebSocketUserImplementedAgentConfig(
|
|
AgentConfig, type=AgentType.WEBSOCKET_USER_IMPLEMENTED
|
|
):
|
|
class RouteConfig(BaseModel):
|
|
url: str
|
|
|
|
respond: RouteConfig
|
|
generate_responses: bool = False
|
|
# generate_response: Optional[RouteConfig]
|
|
# send_message_on_cut_off: bool = False
|
|
|
|
|
|
class WebSocketAgentMessageType(str, Enum):
|
|
BASE = "websocket_agent_base"
|
|
START = "websocket_agent_start"
|
|
TEXT = "websocket_agent_text"
|
|
TEXT_END = "websocket_agent_text_end"
|
|
READY = "websocket_agent_ready"
|
|
STOP = "websocket_agent_stop"
|
|
|
|
|
|
class WebSocketAgentMessage(TypedModel, type=WebSocketAgentMessageType.BASE):
|
|
conversation_id: Optional[str] = None
|
|
|
|
|
|
class WebSocketAgentTextMessage(
|
|
WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT
|
|
):
|
|
class Payload(BaseModel):
|
|
text: str
|
|
|
|
data: Payload
|
|
|
|
@classmethod
|
|
def from_text(cls, text: str, conversation_id: Optional[str] = None):
|
|
return cls(data=cls.Payload(text=text), conversation_id=conversation_id)
|
|
|
|
|
|
class WebSocketAgentStartMessage(
|
|
WebSocketAgentMessage, type=WebSocketAgentMessageType.START
|
|
):
|
|
pass
|
|
|
|
|
|
class WebSocketAgentReadyMessage(
|
|
WebSocketAgentMessage, type=WebSocketAgentMessageType.READY
|
|
):
|
|
pass
|
|
|
|
|
|
class WebSocketAgentStopMessage(
|
|
WebSocketAgentMessage, type=WebSocketAgentMessageType.STOP
|
|
):
|
|
pass
|
|
|
|
|
|
class WebSocketAgentTextEndMessage(
|
|
WebSocketAgentMessage, type=WebSocketAgentMessageType.TEXT_END
|
|
):
|
|
pass
|