vocode-python/vocode/streaming/agent/llm_agent.py
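
"""Streaming LLM agent for vocode.

Wraps a langchain OpenAI completion model behind the BaseAgent interface,
keeping a rolling Human/AI transcript in memory and producing responses
either whole (respond) or as a stream of sentence chunks (generate_response).
"""
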
import logging
import re
from typing import Generator, Optional

from dotenv import load_dotenv
from langchain import OpenAI
from langchain.llms import OpenAIChat

from vocode.streaming.agent.base_agent import BaseAgent
from vocode.streaming.agent.utils import stream_llm_response
from vocode.streaming.models.agent import LLMAgentConfig

load_dotenv()


class LLMAgent(BaseAgent):
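    """Completion-model agent that renders the running Human/AI transcript
    into a prompt and stops generation at the recipient's speaker label."""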

    SENTENCE_ENDINGS = [".", "!", "?"]
    DEFAULT_PROMPT_TEMPLATE = "{history}\nHuman: {human_input}\nAI:"

    def __init__(
        self,
        agent_config: LLMAgentConfig,
        logger: Optional[logging.Logger] = None,
        sender: str = "AI",
        recipient: str = "Human",
    ):
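        """Builds the prompt template from the configured preamble, seeds the
        conversation memory, and optionally precomputes a cached response to
        the expected first prompt."""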
        super().__init__(agent_config)
        self.agent_config = agent_config
        self.prompt_template = (
            f"{agent_config.prompt_preamble}\n\n{self.DEFAULT_PROMPT_TEMPLATE}"
        )
        self.initial_bot_message = (
            agent_config.initial_message.text if agent_config.initial_message else None
        )
        self.logger = logger or logging.getLogger(__name__)
        self.sender = sender
        self.recipient = recipient
        # Seed the transcript with the initial bot message, if one is configured.
        self.memory = (
            [f"AI: {agent_config.initial_message.text}"]
            if agent_config.initial_message
            else []
        )
        self.llm = OpenAI(
            model_name=self.agent_config.model_name,
            temperature=self.agent_config.temperature,
            max_tokens=self.agent_config.max_tokens,
        )
        # Stop generation as soon as the model starts speaking for the other party.
        self.stop_tokens = [f"{recipient}:"]
        self.first_response = (
            self.llm(
                self.prompt_template.format(
                    history="", human_input=agent_config.expected_first_prompt
                ),
                stop=self.stop_tokens,
            ).strip()
            if agent_config.expected_first_prompt
            else None
        )
        self.is_first_response = True

    def create_prompt(self, human_input):
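        """Renders the prompt from the last five transcript entries.

        For example, with a preamble of "You are having a pleasant chat." and
        one prior exchange in memory, the rendered prompt looks like:

            You are having a pleasant chat.

            Human: Hi there
            AI: Hello!
            Human: How are you?
            AI:
        """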
        history = "\n".join(self.memory[-5:])
        return self.prompt_template.format(history=history, human_input=human_input)

    def get_memory_entry(self, human_input, response):
        return f"{self.recipient}: {human_input}\n{self.sender}: {response}"

    def respond(self, human_input, is_interrupt: bool = False) -> tuple[str, bool]:
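        """Returns the full response text; the second element of the tuple is
        always False here. An interrupt with a configured cut_off_response
        skips the LLM call entirely."""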
        if is_interrupt and self.agent_config.cut_off_response:
            cut_off_response = self.get_cut_off_response()
            self.memory.append(self.get_memory_entry(human_input, cut_off_response))
            return cut_off_response, False
        self.logger.debug("LLM responding to human input")
        if self.is_first_response and self.first_response:
            self.logger.debug("First response is cached")
            self.is_first_response = False
            response = self.first_response
        else:
            response = self.llm(self.create_prompt(human_input), stop=self.stop_tokens)
        # Strip the speaker label if the model echoes it back.
        response = response.replace(f"{self.sender}:", "")
        self.memory.append(self.get_memory_entry(human_input, response))
        self.logger.debug(f"LLM response: {response}")
        return response, False

    def generate_response(self, human_input, is_interrupt: bool = False) -> Generator:
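        """Yields the response in sentence-sized chunks, updating the last
        transcript entry after every chunk so an interruption leaves memory
        consistent with what was actually produced."""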
        self.logger.debug("LLM generating response to human input")
        if is_interrupt and self.agent_config.cut_off_response:
            cut_off_response = self.get_cut_off_response()
            self.memory.append(self.get_memory_entry(human_input, cut_off_response))
            yield cut_off_response
            return
        # Reserve a transcript entry up front; it is filled in as chunks arrive.
        self.memory.append(self.get_memory_entry(human_input, ""))
        if self.is_first_response and self.first_response:
            self.logger.debug("First response is cached")
            self.is_first_response = False
            sentences = [self.first_response]
        else:
            self.logger.debug("Creating LLM prompt")
            prompt = self.create_prompt(human_input)
            self.logger.debug("Streaming LLM response")
            sentences = stream_llm_response(
                map(
                    lambda resp: resp.to_dict(),
                    self.llm.stream(prompt, stop=self.stop_tokens),
                )
            )
        response_buffer = ""
        for sentence in sentences:
            sentence = sentence.replace(f"{self.sender}:", "")
            # Collapse leading whitespace to a single space so chunks join cleanly.
            sentence = re.sub(r"^\s+(.*)", r" \1", sentence)
            response_buffer += sentence
            self.memory[-1] = self.get_memory_entry(human_input, response_buffer)
            yield sentence

    def update_last_bot_message_on_cut_off(self, message: str):
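        """Rewrites the bot half of the last transcript entry with the text
        that was actually spoken before the cut-off."""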
        last_message = self.memory[-1]
        new_last_message = (
            last_message.split("\n", 1)[0] + f"\n{self.sender}: {message}"
        )
        self.memory[-1] = new_last_message


if __name__ == "__main__":
    # The constructor appends DEFAULT_PROMPT_TEMPLATE itself, so the preamble
    # must not repeat the {history}/{human_input} placeholders.
    chat_responder = LLMAgent(
        LLMAgentConfig(
            prompt_preamble="The AI is having a pleasant conversation about life. "
            "If the human hasn't completed their thought, the AI responds with 'PASS'",
        )
    )
    while True:
        # response = chat_responder.respond(input("Human: "))[0]
        for response in chat_responder.generate_response(input("Human: ")):
            print(f"AI: {response}")