45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
from typing import Optional
|
|
from langchain.prompts import (
|
|
ChatPromptTemplate,
|
|
MessagesPlaceholder,
|
|
SystemMessagePromptTemplate,
|
|
HumanMessagePromptTemplate,
|
|
)
|
|
from langchain.chains import ConversationChain
|
|
from langchain.chat_models import ChatOpenAI
|
|
from langchain.memory import ConversationBufferMemory
|
|
|
|
from vocode.turn_based.agent.base_agent import BaseAgent
|
|
|
|
|
|
class ChatGPTAgent(BaseAgent):
    """Turn-based agent backed by an OpenAI chat model via LangChain.

    Wires a system prompt, buffered conversation memory, and a
    ``ChatOpenAI`` model into a ``ConversationChain``; each call to
    :meth:`respond` feeds the human input through the chain and returns
    the model's reply as a string.
    """

    def __init__(
        self,
        system_prompt: str,
        initial_message: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: int = 100,
    ):
        super().__init__(initial_message=initial_message)

        # Prompt layout: system instructions first, then the running
        # conversation history, then the latest human utterance.
        message_templates = [
            SystemMessagePromptTemplate.from_template(system_prompt),
            MessagesPlaceholder(variable_name="history"),
            HumanMessagePromptTemplate.from_template("{input}"),
        ]
        self.prompt = ChatPromptTemplate.from_messages(message_templates)

        # return_messages=True keeps the history as message objects so the
        # MessagesPlaceholder above can consume it directly.
        self.memory = ConversationBufferMemory(return_messages=True)
        if initial_message:
            # Seed the transcript so the model sees its own opening line.
            self.memory.chat_memory.add_ai_message(initial_message)

        self.llm = ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        self.conversation = ConversationChain(
            llm=self.llm,
            prompt=self.prompt,
            memory=self.memory,
        )

    def respond(self, human_input: str):
        """Run one turn: send ``human_input`` through the chain and return the model's reply."""
        return self.conversation.predict(input=human_input)
|