refactor(agent): standardize memory handling and update chat history logic (#8715)

* update chat history

* update to agents

* Update Simple Agent.json

* update to templates

* ruff errors

* Update agent.py

* Update test_agent_component.py

* [autofix.ci] apply automated fixes

* update templates

* test fix

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Mike Fortman <michael.fortman@datastax.com>
This commit is contained in:
Edwin Jose 2025-06-25 11:33:57 -05:00 committed by GitHub
commit 38d5885fa3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 8672 additions and 17583 deletions

View file

@@ -139,7 +139,10 @@ class LCAgentComponent(Component):
if hasattr(self, "system_prompt"):
input_dict["system_prompt"] = self.system_prompt
if hasattr(self, "chat_history") and self.chat_history:
input_dict["chat_history"] = data_to_messages(self.chat_history)
if isinstance(self.chat_history, Data):
input_dict["chat_history"] = data_to_messages(self.chat_history)
if all(isinstance(m, Message) for m in self.chat_history):
input_dict["chat_history"] = data_to_messages([m.to_data() for m in self.chat_history])
if hasattr(self, "graph"):
session_id = self.graph.session_id

View file

@@ -16,7 +16,7 @@ from langflow.components.langchain_utilities.tool_calling import ToolCallingAgen
from langflow.custom.custom_component.component import _get_component_toolkit
from langflow.custom.utils import update_component_build_config
from langflow.field_typing import Tool
from langflow.io import BoolInput, DropdownInput, MultilineInput, Output
from langflow.io import BoolInput, DropdownInput, IntInput, MultilineInput, Output
from langflow.logging import logger
from langflow.schema.dotdict import dotdict
from langflow.schema.message import Message
@@ -27,6 +27,9 @@ def set_advanced_true(component_input):
return component_input
MODEL_PROVIDERS_LIST = ["Anthropic", "Google Generative AI", "Groq", "OpenAI"]
class AgentComponent(ToolCallingAgentComponent):
display_name: str = "Agent"
description: str = "Define the agent's instructions, then enter a task to complete using tools."
@@ -41,11 +44,11 @@ class AgentComponent(ToolCallingAgentComponent):
name="agent_llm",
display_name="Model Provider",
info="The provider of the language model that the agent will use to generate responses.",
options=[*sorted(MODEL_PROVIDERS), "Custom"],
options=[*MODEL_PROVIDERS_LIST, "Custom"],
value="OpenAI",
real_time_refresh=True,
input_types=[],
options_metadata=[MODELS_METADATA[key] for key in sorted(MODEL_PROVIDERS)] + [{"icon": "brain"}],
options_metadata=[MODELS_METADATA[key] for key in MODEL_PROVIDERS_LIST] + [{"icon": "brain"}],
),
*MODEL_PROVIDERS_DICT["OpenAI"]["inputs"],
MultilineInput(
@@ -55,8 +58,17 @@ class AgentComponent(ToolCallingAgentComponent):
value="You are a helpful assistant that can use tools to answer questions and perform tasks.",
advanced=False,
),
IntInput(
name="n_messages",
display_name="Number of Chat History Messages",
value=100,
info="Number of chat history messages to retrieve.",
advanced=True,
show=True,
),
*LCToolsAgentComponent._base_inputs,
*memory_inputs,
# removed memory inputs from agent component
# *memory_inputs,
BoolInput(
name="add_current_date_tool",
display_name="Current Date",
@@ -78,6 +90,8 @@
# Get memory data
self.chat_history = await self.get_memory_data()
if isinstance(self.chat_history, Message):
self.chat_history = [self.chat_history]
# Add current date tool if enabled
if self.add_current_date_tool:
@@ -112,13 +126,11 @@
raise
async def get_memory_data(self):
memory_kwargs = {
component_input.name: getattr(self, f"{component_input.name}") for component_input in self.memory_inputs
}
# filter out empty values
memory_kwargs = {k: v for k, v in memory_kwargs.items() if v is not None}
return await MemoryComponent(**self.get_base_args()).set(**memory_kwargs).retrieve_messages()
return (
await MemoryComponent(**self.get_base_args())
.set(session_id=self.graph.session_id, order="Ascending", n_messages=self.n_messages)
.retrieve_messages()
)
def get_llm(self):
if not isinstance(self.agent_llm, str):

View file

@@ -76,7 +76,7 @@ class MemoryComponent(Component):
value=100,
info="Number of messages to retrieve.",
advanced=True,
show=False,
show=True,
),
MessageTextInput(
name="session_id",
@@ -197,28 +197,33 @@
stored = await self.memory.aget_messages()
# langchain memories are supposed to return messages in ascending order
if order == "DESC":
stored = stored[::-1]
if n_messages:
stored = stored[:n_messages]
stored = stored[-n_messages:] if order == "ASC" else stored[:n_messages]
stored = [Message.from_lc_message(m) for m in stored]
if sender_type:
expected_type = MESSAGE_SENDER_AI if sender_type == MESSAGE_SENDER_AI else MESSAGE_SENDER_USER
stored = [m for m in stored if m.type == expected_type]
else:
# For internal memory, we always fetch the last N messages by ordering by DESC
stored = await aget_messages(
sender=sender_type,
sender_name=sender_name,
session_id=session_id,
limit=n_messages,
limit=10000,
order=order,
)
self.status = stored
if n_messages:
stored = stored[-n_messages:] if order == "ASC" else stored[:n_messages]
# self.status = stored
return cast(Data, stored)
async def retrieve_messages_as_text(self) -> Message:
stored_text = data_to_text(self.template, await self.retrieve_messages())
self.status = stored_text
# self.status = stored_text
return Message(text=stored_text)
async def retrieve_messages_dataframe(self) -> DataFrame:

View file

@@ -215,7 +215,7 @@ class Component(CustomComponent):
"""
return {
"_user_id": self.user_id,
"_session_id": self.session_id,
"_session_id": self.graph.session_id,
"_tracing_service": self._tracing_service,
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@@ -14,7 +14,6 @@ from langflow.base.models.openai_constants import (
from langflow.components.agents.agent import AgentComponent
from langflow.components.tools.calculator import CalculatorToolComponent
from langflow.custom import Component
from langflow.utils.constants import MESSAGE_SENDER_AI, MESSAGE_SENDER_NAME_AI
from tests.base import ComponentTestBaseWithClient, ComponentTestBaseWithoutClient
from tests.unit.mock_language_model import MockLanguageModel
@@ -50,9 +49,6 @@ class TestAgentComponent(ComponentTestBaseWithoutClient):
"system_prompt": "You are a helpful assistant.",
"tools": [],
"verbose": True,
"session_id": str(uuid4()),
"sender": MESSAGE_SENDER_AI,
"sender_name": MESSAGE_SENDER_NAME_AI,
}
async def test_build_config_update(self, component_class, default_kwargs):
@@ -129,7 +125,7 @@ class TestAgentComponentWithClient(ComponentTestBaseWithClient):
input_value=input_value,
api_key=api_key,
model_name="gpt-4o",
llm_type="OpenAI",
agent_llm="OpenAI",
temperature=temperature,
_session_id=str(uuid4()),
)

View file

@@ -1,78 +0,0 @@
import { expect, test } from "@playwright/test";
import * as dotenv from "dotenv";
import path from "path";
import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
import { withEventDeliveryModes } from "../../utils/withEventDeliveryModes";
withEventDeliveryModes(
"Financial Agent",
{ tag: ["@release", "@starter-projects"] },
async ({ page }) => {
test.skip(
!process?.env?.TAVILY_API_KEY,
"TAVILY_API_KEY required to run this test",
);
test.skip(
!process?.env?.SAMBANOVA_API_KEY,
"SAMBANOVA_API_KEY required to run this test",
);
await page.goto("/");
await awaitBootstrapTest(page);
await page.getByTestId("side_nav_options_all-templates").click();
await page.getByRole("heading", { name: "Financial Agent" }).click();
await page
.getByTestId("popover-anchor-input-api_key")
.nth(0)
.fill(process.env.TAVILY_API_KEY ?? "");
for (let i = 0; i < 2; i++) {
await page.getByTestId("dropdown_str_agent_llm").nth(i).click();
await page.waitForTimeout(500);
await page.getByRole("option", { name: "SambaNova" }).click();
}
for (let i = 0; i < 3; i++) {
await page
.getByTestId("value-dropdown-dropdown_str_model_name")
.nth(i)
.click();
await page.waitForTimeout(500);
await page.getByRole("option").first().click();
}
for (let i = 1; i <= 3; i++) {
await page
.getByTestId("popover-anchor-input-api_key")
.nth(i)
.fill(process.env.SAMBANOVA_API_KEY ?? "");
}
await page.getByTestId("playground-btn-flow-io").click();
await page
.getByTestId("input-chat-playground")
.last()
.fill("Why did Nvidia stock drop in January?");
await page.getByTestId("button-send").last().click();
const stopButton = page.getByRole("button", { name: "Stop" });
await stopButton.waitFor({ state: "visible", timeout: 30000 });
if (await stopButton.isVisible()) {
await expect(stopButton).toBeHidden({ timeout: 120000 });
}
const output = await page
.getByTestId("div-chat-message")
.last()
.innerText();
expect(output.toLowerCase()).toContain("nvidia");
expect(output.length).toBeGreaterThan(100);
},
);

View file

@@ -41,17 +41,6 @@ withEventDeliveryModes(
.nth(0)
.fill(process.env.TAVILY_API_KEY ?? "");
//* TODO: Remove these 5 steps once the template is updated *//
await page.getByTestId("dropdown-output-openaimodel").click();
await page
.getByTestId("dropdown-item-output-openaimodel-language model")
.click();
await page
.getByTestId("handle-structuredoutput-shownode-structured output-right")
.click();
await page
.getByTestId("handle-parser-shownode-data or dataframe-left")
.click();