Adds OpenAIConversationalAgent (#746)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-08-10 17:41:43 -03:00 committed by GitHub
commit 337dba208e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 83 additions and 1 deletions

View file

@ -0,0 +1,76 @@
from langflow import CustomComponent
from typing import Optional
from langchain.prompts import SystemMessagePromptTemplate
from langchain.tools import Tool
from langchain.schema.memory import BaseMemory
from langchain.chat_models import ChatOpenAI
from langchain.agents.agent import AgentExecutor
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.memory.token_buffer import ConversationTokenBufferMemory
from langchain.prompts.chat import MessagesPlaceholder
from langchain.agents.agent_toolkits.conversational_retrieval.openai_functions import (
_get_default_system_message,
)
class ConversationalAgent(CustomComponent):
display_name: str = "OpenaAI Conversational Agent"
description: str = "Conversational Agent that can use OpenAI's function calling API"
def build_config(self):
openai_function_models = [
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0613",
"gpt-4-32k-0613",
]
return {
"tools": {"is_list": True, "display_name": "Tools"},
"memory": {"display_name": "Memory"},
"system_message": {"display_name": "System Message"},
"max_token_limit": {"display_name": "Max Token Limit"},
"model_name": {
"display_name": "Model Name",
"options": openai_function_models,
"value": openai_function_models[0],
},
"code": {"show": False},
}
def build(
self,
model_name: str,
tools: Tool,
memory: Optional[BaseMemory] = None,
system_message: Optional[SystemMessagePromptTemplate] = None,
max_token_limit: int = 2000,
) -> AgentExecutor:
llm = ChatOpenAI(model=model_name)
if not memory:
memory_key = "chat_history"
memory = ConversationTokenBufferMemory(
memory_key=memory_key,
return_messages=True,
output_key="output",
llm=llm,
max_token_limit=max_token_limit,
)
else:
memory_key = memory.memory_key # type: ignore
_system_message = system_message or _get_default_system_message()
prompt = OpenAIFunctionsAgent.create_prompt(
system_message=_system_message, # type: ignore
extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
)
agent = OpenAIFunctionsAgent(
llm=llm, tools=tools, prompt=prompt # type: ignore
)
return AgentExecutor(
agent=agent,
tools=tools, # type: ignore
memory=memory,
verbose=True,
return_intermediate_steps=True,
)

View file

@ -8,6 +8,8 @@ from langchain.text_splitter import TextSplitter
from langchain.tools import Tool
from langchain.vectorstores.base import VectorStore
from langchain.schema import BaseOutputParser
from langchain.schema.memory import BaseMemory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.agents.agent import AgentExecutor
LANGCHAIN_BASE_TYPES = {
@ -23,6 +25,8 @@ LANGCHAIN_BASE_TYPES = {
"Embeddings": Embeddings,
"BaseRetriever": BaseRetriever,
"BaseOutputParser": BaseOutputParser,
"BaseMemory": BaseMemory,
"BaseChatMemory": BaseChatMemory,
}
# Langchain base types plus Python base types

View file

@ -5,6 +5,7 @@ from langflow.api.v1.callback import (
)
from langflow.processing.process import fix_memory_inputs, format_actions
from langflow.utils.logger import logger
from langchain.agents.agent import AgentExecutor
async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwargs):
@ -20,7 +21,8 @@ async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwa
# to display intermediate steps
langchain_object.return_intermediate_steps = True
try:
fix_memory_inputs(langchain_object)
if not isinstance(langchain_object, AgentExecutor):
fix_memory_inputs(langchain_object)
except Exception as exc:
logger.error(f"Error fixing memory inputs: {exc}")