# vocode-python/vocode/streaming/agent/chat_gpt_agent.py
import random
import time
from langchain.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.chains import ConversationChain
from langchain.chat_models import ChatOpenAI
from langchain.llms import OpenAIChat
from langchain.memory import ConversationBufferMemory
from langchain.schema import ChatMessage, AIMessage
import openai
import json
from typing import Generator, Optional
from typing import Generator
import logging
from vocode import getenv
from vocode.streaming.agent.base_agent import BaseAgent
from vocode.streaming.models.agent import ChatGPTAgentConfig
from vocode.streaming.utils.sse_client import SSEClient
from vocode.streaming.agent.utils import stream_llm_response
class ChatGPTAgent(BaseAgent):
    """Conversational agent backed by OpenAI chat-completion models.

    Two response paths are supported:

    * ``respond`` — blocking, single-shot completion via a LangChain
      ``ConversationChain``.
    * ``generate_response`` — token streaming directly against the OpenAI
      ``/v1/chat/completions`` endpoint over SSE, yielding text chunks.

    Conversation history is held in a ``ConversationBufferMemory`` shared by
    both paths.
    """

    def __init__(
        self,
        agent_config: ChatGPTAgentConfig,
        logger: Optional[logging.Logger] = None,
        openai_api_key: Optional[str] = None,
    ):
        """Configure the LLM, prompt template, and conversation memory.

        Args:
            agent_config: Model/prompt settings for this agent.
            logger: Destination for debug logging; a module logger is created
                if omitted.
            openai_api_key: Overrides the ``OPENAI_API_KEY`` environment
                variable when provided.

        Raises:
            ValueError: If no OpenAI API key is available.
        """
        super().__init__(agent_config)
        openai.api_key = openai_api_key or getenv("OPENAI_API_KEY")
        if not openai.api_key:
            raise ValueError("OPENAI_API_KEY must be set in environment or passed in")
        self.agent_config = agent_config
        self.logger = logger or logging.getLogger(__name__)
        self.logger.setLevel(logging.DEBUG)
        self.prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessagePromptTemplate.from_template(agent_config.prompt_preamble),
                MessagesPlaceholder(variable_name="history"),
                HumanMessagePromptTemplate.from_template("{input}"),
            ]
        )
        self.memory = ConversationBufferMemory(return_messages=True)
        if agent_config.initial_message:
            if (
                agent_config.generate_responses
            ):  # we use ChatMessages for memory when we generate responses
                self.memory.chat_memory.messages.append(
                    ChatMessage(
                        content=agent_config.initial_message.text, role="assistant"
                    )
                )
            else:
                self.memory.chat_memory.add_ai_message(
                    agent_config.initial_message.text
                )
        self.llm = ChatOpenAI(
            model_name=self.agent_config.model_name,
            temperature=self.agent_config.temperature,
            max_tokens=self.agent_config.max_tokens,
            openai_api_key=openai.api_key,
        )
        self.conversation = ConversationChain(
            memory=self.memory, prompt=self.prompt, llm=self.llm
        )
        # Optionally pre-compute the reply to an expected first prompt so the
        # first real turn incurs no LLM latency.
        self.first_response = (
            self.create_first_response(agent_config.expected_first_prompt)
            if agent_config.expected_first_prompt
            else None
        )
        self.is_first_response = True

    def create_first_response(self, first_prompt: str) -> str:
        """Run the chain once to pre-compute a reply to the expected first prompt."""
        return self.conversation.predict(input=first_prompt)

    def respond(self, human_input, is_interrupt: bool = False) -> tuple[str, bool]:
        """Return a complete (non-streamed) reply to ``human_input``.

        Args:
            human_input: The user's latest utterance.
            is_interrupt: True when the user interrupted the bot mid-reply; if
                a cut-off response is configured it is returned instead of
                calling the LLM.

        Returns:
            ``(text, should_stop)`` — ``should_stop`` is always ``False``.
        """
        if is_interrupt and self.agent_config.cut_off_response:
            cut_off_response = self.get_cut_off_response()
            self.memory.chat_memory.add_user_message(human_input)
            self.memory.chat_memory.add_ai_message(cut_off_response)
            return cut_off_response, False
        self.logger.debug("LLM responding to human input")
        if self.is_first_response and self.first_response:
            self.logger.debug("First response is cached")
            self.is_first_response = False
            text = self.first_response
        else:
            text = self.conversation.predict(input=human_input)
        self.logger.debug(f"LLM response: {text}")
        return text, False

    def generate_response(
        self, human_input, is_interrupt: bool = False
    ) -> Generator[str, None, None]:
        """Stream a reply to ``human_input``, yielding text chunks as they arrive.

        The user message and the (incrementally built) assistant message are
        both appended to memory so later turns see the full history.

        Args:
            human_input: The user's latest utterance.
            is_interrupt: True when the user interrupted the bot mid-reply; if
                a cut-off response is configured it is yielded instead of
                calling the LLM.

        Yields:
            Successive chunks of the assistant's reply.
        """
        self.memory.chat_memory.messages.append(
            ChatMessage(role="user", content=human_input)
        )
        if is_interrupt and self.agent_config.cut_off_response:
            cut_off_response = self.get_cut_off_response()
            self.memory.chat_memory.messages.append(
                ChatMessage(role="assistant", content=cut_off_response)
            )
            yield cut_off_response
            return
        prompt_messages = [
            ChatMessage(role="system", content=self.agent_config.prompt_preamble)
        ] + self.memory.chat_memory.messages
        messages = SSEClient(
            "POST",
            "https://api.openai.com/v1/chat/completions",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {getenv('OPENAI_API_KEY')}",
            },
            json={
                "model": self.agent_config.model_name,
                "messages": [
                    prompt_message.dict(include={"content": True, "role": True})
                    for prompt_message in prompt_messages
                ],
                # Use the configured generation parameters so the streaming
                # path behaves consistently with respond() (previously this
                # hard-coded max_tokens=256 and temperature=1.0).
                "max_tokens": self.agent_config.max_tokens,
                "temperature": self.agent_config.temperature,
                "stream": True,
            },
        )
        bot_memory_message = ChatMessage(role="assistant", content="")
        self.memory.chat_memory.messages.append(bot_memory_message)
        for message in stream_llm_response(
            map(lambda event: json.loads(event.data), messages),
            get_text=lambda choice: choice.get("delta", {}).get("content"),
        ):
            # Chunks from stream_llm_response are sentence-level, so they are
            # joined with a space when accumulated into memory.
            bot_memory_message.content = f"{bot_memory_message.content} {message}"
            yield message

    def update_last_bot_message_on_cut_off(self, message: str):
        """Overwrite the newest assistant message in memory with ``message``.

        Called when the bot was interrupted so memory reflects only what was
        actually spoken. Scans history from newest to oldest and rewrites the
        first assistant-authored entry found.
        """
        for memory_message in self.memory.chat_memory.messages[::-1]:
            if (
                isinstance(memory_message, ChatMessage)
                and memory_message.role == "assistant"
            ) or isinstance(memory_message, AIMessage):
                memory_message.content = message
                return
if __name__ == "__main__":
    from dotenv import load_dotenv

    load_dotenv()

    # Minimal REPL for manually exercising the blocking respond() path.
    chat_agent = ChatGPTAgent(
        ChatGPTAgentConfig(
            prompt_preamble=(
                "The assistant is having a pleasant conversation about life. If the user hasn't completed their thought, the assistant responds with 'PASS'"
            ),
        )
    )
    while True:
        reply, _ = chat_agent.respond(input("Human: "))
        print(f"AI: {reply}")