refactor(agent): standardize memory handling and update chat history logic (#8715)
* update chat history * update to agents * Update Simple Agent.json * update to templates * ruff errors * Update agent.py * Update test_agent_component.py * [autofix.ci] apply automated fixes * update templates * test fix --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Mike Fortman <michael.fortman@datastax.com>
This commit is contained in:
parent
0d74474e12
commit
38d5885fa3
25 changed files with 8672 additions and 17583 deletions
|
|
@ -139,7 +139,10 @@ class LCAgentComponent(Component):
|
|||
if hasattr(self, "system_prompt"):
|
||||
input_dict["system_prompt"] = self.system_prompt
|
||||
if hasattr(self, "chat_history") and self.chat_history:
|
||||
input_dict["chat_history"] = data_to_messages(self.chat_history)
|
||||
if isinstance(self.chat_history, Data):
|
||||
input_dict["chat_history"] = data_to_messages(self.chat_history)
|
||||
if all(isinstance(m, Message) for m in self.chat_history):
|
||||
input_dict["chat_history"] = data_to_messages([m.to_data() for m in self.chat_history])
|
||||
|
||||
if hasattr(self, "graph"):
|
||||
session_id = self.graph.session_id
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from langflow.components.langchain_utilities.tool_calling import ToolCallingAgen
|
|||
from langflow.custom.custom_component.component import _get_component_toolkit
|
||||
from langflow.custom.utils import update_component_build_config
|
||||
from langflow.field_typing import Tool
|
||||
from langflow.io import BoolInput, DropdownInput, MultilineInput, Output
|
||||
from langflow.io import BoolInput, DropdownInput, IntInput, MultilineInput, Output
|
||||
from langflow.logging import logger
|
||||
from langflow.schema.dotdict import dotdict
|
||||
from langflow.schema.message import Message
|
||||
|
|
@ -27,6 +27,9 @@ def set_advanced_true(component_input):
|
|||
return component_input
|
||||
|
||||
|
||||
MODEL_PROVIDERS_LIST = ["Anthropic", "Google Generative AI", "Groq", "OpenAI"]
|
||||
|
||||
|
||||
class AgentComponent(ToolCallingAgentComponent):
|
||||
display_name: str = "Agent"
|
||||
description: str = "Define the agent's instructions, then enter a task to complete using tools."
|
||||
|
|
@ -41,11 +44,11 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
name="agent_llm",
|
||||
display_name="Model Provider",
|
||||
info="The provider of the language model that the agent will use to generate responses.",
|
||||
options=[*sorted(MODEL_PROVIDERS), "Custom"],
|
||||
options=[*MODEL_PROVIDERS_LIST, "Custom"],
|
||||
value="OpenAI",
|
||||
real_time_refresh=True,
|
||||
input_types=[],
|
||||
options_metadata=[MODELS_METADATA[key] for key in sorted(MODEL_PROVIDERS)] + [{"icon": "brain"}],
|
||||
options_metadata=[MODELS_METADATA[key] for key in MODEL_PROVIDERS_LIST] + [{"icon": "brain"}],
|
||||
),
|
||||
*MODEL_PROVIDERS_DICT["OpenAI"]["inputs"],
|
||||
MultilineInput(
|
||||
|
|
@ -55,8 +58,17 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
value="You are a helpful assistant that can use tools to answer questions and perform tasks.",
|
||||
advanced=False,
|
||||
),
|
||||
IntInput(
|
||||
name="n_messages",
|
||||
display_name="Number of Chat History Messages",
|
||||
value=100,
|
||||
info="Number of chat history messages to retrieve.",
|
||||
advanced=True,
|
||||
show=True,
|
||||
),
|
||||
*LCToolsAgentComponent._base_inputs,
|
||||
*memory_inputs,
|
||||
# removed memory inputs from agent component
|
||||
# *memory_inputs,
|
||||
BoolInput(
|
||||
name="add_current_date_tool",
|
||||
display_name="Current Date",
|
||||
|
|
@ -78,6 +90,8 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
|
||||
# Get memory data
|
||||
self.chat_history = await self.get_memory_data()
|
||||
if isinstance(self.chat_history, Message):
|
||||
self.chat_history = [self.chat_history]
|
||||
|
||||
# Add current date tool if enabled
|
||||
if self.add_current_date_tool:
|
||||
|
|
@ -112,13 +126,11 @@ class AgentComponent(ToolCallingAgentComponent):
|
|||
raise
|
||||
|
||||
async def get_memory_data(self):
|
||||
memory_kwargs = {
|
||||
component_input.name: getattr(self, f"{component_input.name}") for component_input in self.memory_inputs
|
||||
}
|
||||
# filter out empty values
|
||||
memory_kwargs = {k: v for k, v in memory_kwargs.items() if v is not None}
|
||||
|
||||
return await MemoryComponent(**self.get_base_args()).set(**memory_kwargs).retrieve_messages()
|
||||
return (
|
||||
await MemoryComponent(**self.get_base_args())
|
||||
.set(session_id=self.graph.session_id, order="Ascending", n_messages=self.n_messages)
|
||||
.retrieve_messages()
|
||||
)
|
||||
|
||||
def get_llm(self):
|
||||
if not isinstance(self.agent_llm, str):
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ class MemoryComponent(Component):
|
|||
value=100,
|
||||
info="Number of messages to retrieve.",
|
||||
advanced=True,
|
||||
show=False,
|
||||
show=True,
|
||||
),
|
||||
MessageTextInput(
|
||||
name="session_id",
|
||||
|
|
@ -197,28 +197,33 @@ class MemoryComponent(Component):
|
|||
|
||||
stored = await self.memory.aget_messages()
|
||||
# langchain memories are supposed to return messages in ascending order
|
||||
|
||||
if order == "DESC":
|
||||
stored = stored[::-1]
|
||||
if n_messages:
|
||||
stored = stored[:n_messages]
|
||||
stored = stored[-n_messages:] if order == "ASC" else stored[:n_messages]
|
||||
stored = [Message.from_lc_message(m) for m in stored]
|
||||
if sender_type:
|
||||
expected_type = MESSAGE_SENDER_AI if sender_type == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER
|
||||
stored = [m for m in stored if m.type == expected_type]
|
||||
else:
|
||||
# For internal memory, we always fetch the last N messages by ordering by DESC
|
||||
stored = await aget_messages(
|
||||
sender=sender_type,
|
||||
sender_name=sender_name,
|
||||
session_id=session_id,
|
||||
limit=n_messages,
|
||||
limit=10000,
|
||||
order=order,
|
||||
)
|
||||
self.status = stored
|
||||
if n_messages:
|
||||
stored = stored[-n_messages:] if order == "ASC" else stored[:n_messages]
|
||||
|
||||
# self.status = stored
|
||||
return cast(Data, stored)
|
||||
|
||||
async def retrieve_messages_as_text(self) -> Message:
|
||||
stored_text = data_to_text(self.template, await self.retrieve_messages())
|
||||
self.status = stored_text
|
||||
# self.status = stored_text
|
||||
return Message(text=stored_text)
|
||||
|
||||
async def retrieve_messages_dataframe(self) -> DataFrame:
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ class Component(CustomComponent):
|
|||
"""
|
||||
return {
|
||||
"_user_id": self.user_id,
|
||||
"_session_id": self.session_id,
|
||||
"_session_id": self.graph.session_id,
|
||||
"_tracing_service": self._tracing_service,
|
||||
}
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -14,7 +14,6 @@ from langflow.base.models.openai_constants import (
|
|||
from langflow.components.agents.agent import AgentComponent
|
||||
from langflow.components.tools.calculator import CalculatorToolComponent
|
||||
from langflow.custom import Component
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI
|
||||
|
||||
from tests.base import ComponentTestBaseWithClient, ComponentTestBaseWithoutClient
|
||||
from tests.unit.mock_language_model import MockLanguageModel
|
||||
|
|
@ -50,9 +49,6 @@ class TestAgentComponent(ComponentTestBaseWithoutClient):
|
|||
"system_prompt": "You are a helpful assistant.",
|
||||
"tools": [],
|
||||
"verbose": True,
|
||||
"session_id": str(uuid4()),
|
||||
"sender": MESSAGE_SENDER_AI,
|
||||
"sender_name": MESSAGE_SENDER_NAME_AI,
|
||||
}
|
||||
|
||||
async def test_build_config_update(self, component_class, default_kwargs):
|
||||
|
|
@ -129,7 +125,7 @@ class TestAgentComponentWithClient(ComponentTestBaseWithClient):
|
|||
input_value=input_value,
|
||||
api_key=api_key,
|
||||
model_name="gpt-4o",
|
||||
llm_type="OpenAI",
|
||||
agent_llm="OpenAI",
|
||||
temperature=temperature,
|
||||
_session_id=str(uuid4()),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,78 +0,0 @@
|
|||
import { expect, test } from "@playwright/test";
|
||||
import * as dotenv from "dotenv";
|
||||
import path from "path";
|
||||
import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
|
||||
import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
|
||||
|
||||
withEventDeliveryModes(
|
||||
"Financial Agent",
|
||||
{ tag: ["@release", "@starter-projects"] },
|
||||
async ({ page }) => {
|
||||
test.skip(
|
||||
!process?.env?.TAVILY_API_KEY,
|
||||
"TAVILY_API_KEY required to run this test",
|
||||
);
|
||||
|
||||
test.skip(
|
||||
!process?.env?.SAMBANOVA_API_KEY,
|
||||
"SAMBANOVA_API_KEY required to run this test",
|
||||
);
|
||||
|
||||
await page.goto("/");
|
||||
await awaitBootstrapTest(page);
|
||||
|
||||
await page.getByTestId("side_nav_options_all-templates").click();
|
||||
await page.getByRole("heading", { name: "Financial Agent" }).click();
|
||||
|
||||
await page
|
||||
.getByTestId("popover-anchor-input-api_key")
|
||||
.nth(0)
|
||||
.fill(process.env.TAVILY_API_KEY ?? "");
|
||||
|
||||
for (let i = 0; i < 2; i++) {
|
||||
await page.getByTestId("dropdown_str_agent_llm").nth(i).click();
|
||||
await page.waitForTimeout(500);
|
||||
await page.getByRole("option", { name: "SambaNova" }).click();
|
||||
}
|
||||
|
||||
for (let i = 0; i < 3; i++) {
|
||||
await page
|
||||
.getByTestId("value-dropdown-dropdown_str_model_name")
|
||||
.nth(i)
|
||||
.click();
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.getByRole("option").first().click();
|
||||
}
|
||||
|
||||
for (let i = 1; i <= 3; i++) {
|
||||
await page
|
||||
.getByTestId("popover-anchor-input-api_key")
|
||||
.nth(i)
|
||||
.fill(process.env.SAMBANOVA_API_KEY ?? "");
|
||||
}
|
||||
|
||||
await page.getByTestId("playground-btn-flow-io").click();
|
||||
|
||||
await page
|
||||
.getByTestId("input-chat-playground")
|
||||
.last()
|
||||
.fill("Why did Nvidia stock drop in January?");
|
||||
|
||||
await page.getByTestId("button-send").last().click();
|
||||
|
||||
const stopButton = page.getByRole("button", { name: "Stop" });
|
||||
await stopButton.waitFor({ state: "visible", timeout: 30000 });
|
||||
|
||||
if (await stopButton.isVisible()) {
|
||||
await expect(stopButton).toBeHidden({ timeout: 120000 });
|
||||
}
|
||||
|
||||
const output = await page
|
||||
.getByTestId("div-chat-message")
|
||||
.last()
|
||||
.innerText();
|
||||
expect(output.toLowerCase()).toContain("nvidia");
|
||||
expect(output.length).toBeGreaterThan(100);
|
||||
},
|
||||
);
|
||||
|
|
@ -41,17 +41,6 @@ withEventDeliveryModes(
|
|||
.nth(0)
|
||||
.fill(process.env.TAVILY_API_KEY ?? "");
|
||||
|
||||
//* TODO: Remove these 5 steps once the template is updated *//
|
||||
await page.getByTestId("dropdown-output-openaimodel").click();
|
||||
|
||||
await page
|
||||
.getByTestId("dropdown-item-output-openaimodel-language model")
|
||||
.click();
|
||||
|
||||
await page
|
||||
.getByTestId("handle-structuredoutput-shownode-structured output-right")
|
||||
.click();
|
||||
|
||||
await page
|
||||
.getByTestId("handle-parser-shownode-data or dataframe-left")
|
||||
.click();
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue