fix: tool calling and openai tools do not include chat history (#3030)
* fix: tool calling and openai tools do not include chat history * fix import * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent
32824276c8
commit
53063957e4
3 changed files with 26 additions and 5 deletions
|
|
@ -1,5 +1,5 @@
|
|||
from abc import abstractmethod
|
||||
from typing import List, Union, cast
|
||||
from typing import List, Optional, Union, cast
|
||||
|
||||
from langchain.agents import AgentExecutor, BaseMultiActionAgent, BaseSingleActionAgent
|
||||
from langchain.agents.agent import RunnableAgent
|
||||
|
|
@ -10,8 +10,9 @@ from langflow.base.agents.callback import AgentAsyncHandler
|
|||
from langflow.base.agents.utils import data_to_messages
|
||||
from langflow.custom import Component
|
||||
from langflow.field_typing import Text
|
||||
from langflow.inputs.inputs import DataInput, InputTypes
|
||||
from langflow.inputs.inputs import InputTypes
|
||||
from langflow.io import BoolInput, HandleInput, IntInput, MessageTextInput
|
||||
from langflow.schema import Data
|
||||
from langflow.schema.message import Message
|
||||
from langflow.template import Output
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI
|
||||
|
|
@ -39,7 +40,6 @@ class LCAgentComponent(Component):
|
|||
value=15,
|
||||
advanced=True,
|
||||
),
|
||||
DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True),
|
||||
]
|
||||
|
||||
outputs = [
|
||||
|
|
@ -89,8 +89,13 @@ class LCAgentComponent(Component):
|
|||
}
|
||||
return {**base, "agent_executor_kwargs": agent_kwargs}
|
||||
|
||||
def get_chat_history_data(self) -> Optional[List[Data]]:
|
||||
# might be overridden in subclasses
|
||||
return None
|
||||
|
||||
async def run_agent(self, agent: AgentExecutor) -> Text:
|
||||
input_dict: dict[str, str | list[BaseMessage]] = {"input": self.input_value}
|
||||
self.chat_history = self.get_chat_history_data()
|
||||
if self.chat_history:
|
||||
input_dict["chat_history"] = data_to_messages(self.chat_history)
|
||||
result = await agent.ainvoke(input_dict, config={"callbacks": [AgentAsyncHandler(self.log)]})
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
from typing import Optional, List
|
||||
|
||||
from langchain.agents import create_openai_tools_agent
|
||||
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMessagePromptTemplate
|
||||
|
||||
from langflow.base.agents.agent import LCToolsAgentComponent
|
||||
from langflow.inputs import MultilineInput
|
||||
from langflow.inputs.inputs import HandleInput
|
||||
from langflow.inputs.inputs import HandleInput, DataInput
|
||||
from langflow.schema import Data
|
||||
|
||||
|
||||
class OpenAIToolsAgentComponent(LCToolsAgentComponent):
|
||||
|
|
@ -29,13 +32,18 @@ class OpenAIToolsAgentComponent(LCToolsAgentComponent):
|
|||
MultilineInput(
|
||||
name="user_prompt", display_name="Prompt", info="This prompt must contain 'input' key.", value="{input}"
|
||||
),
|
||||
DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True),
|
||||
]
|
||||
|
||||
def get_chat_history_data(self) -> Optional[List[Data]]:
|
||||
return self.chat_history
|
||||
|
||||
def create_agent_runnable(self):
|
||||
if "input" not in self.user_prompt:
|
||||
raise ValueError("Prompt must contain 'input' key.")
|
||||
messages = [
|
||||
("system", self.system_prompt),
|
||||
("placeholder", "{chat_history}"),
|
||||
HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=["input"], template=self.user_prompt)),
|
||||
("placeholder", "{agent_scratchpad}"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
from typing import Optional, List
|
||||
|
||||
from langchain.agents import create_tool_calling_agent
|
||||
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, HumanMessagePromptTemplate
|
||||
from langflow.base.agents.agent import LCToolsAgentComponent
|
||||
from langflow.inputs import MultilineInput
|
||||
from langflow.inputs.inputs import HandleInput
|
||||
from langflow.inputs.inputs import HandleInput, DataInput
|
||||
from langflow.schema import Data
|
||||
|
||||
|
||||
class ToolCallingAgentComponent(LCToolsAgentComponent):
|
||||
|
|
@ -23,13 +26,18 @@ class ToolCallingAgentComponent(LCToolsAgentComponent):
|
|||
MultilineInput(
|
||||
name="user_prompt", display_name="Prompt", info="This prompt must contain 'input' key.", value="{input}"
|
||||
),
|
||||
DataInput(name="chat_history", display_name="Chat History", is_list=True, advanced=True),
|
||||
]
|
||||
|
||||
def get_chat_history_data(self) -> Optional[List[Data]]:
|
||||
return self.chat_history
|
||||
|
||||
def create_agent_runnable(self):
|
||||
if "input" not in self.user_prompt:
|
||||
raise ValueError("Prompt must contain 'input' key.")
|
||||
messages = [
|
||||
("system", self.system_prompt),
|
||||
("placeholder", "{chat_history}"),
|
||||
HumanMessagePromptTemplate(prompt=PromptTemplate(input_variables=["input"], template=self.user_prompt)),
|
||||
("placeholder", "{agent_scratchpad}"),
|
||||
]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue