feat: assistants agent improvements (#5581)

* assistants agent improvements

* remove alembic init file

* vector store / file upload support

* use sync file object (required by sdk)

* steps

* self.tools initialization

* improvements for edwin

* add name and switch to MultilineInput

* ci fixes
This commit is contained in:
Sebastián Estévez 2025-01-16 15:54:34 -05:00 committed by GitHub
commit 2acd434e09
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 327 additions and 57 deletions

View file

@ -75,7 +75,7 @@ dependencies = [
"langsmith==0.1.147",
"yfinance==0.2.50",
"wolframalpha==5.1.3",
"astra-assistants[tools]~=2.2.6",
"astra-assistants[tools]~=2.2.9",
"composio-langchain==0.6.13",
"composio-core==0.6.13",
"spider-client==0.1.24",

View file

@ -1,3 +1,4 @@
# noqa: INP001
from logging.config import fileConfig
from alembic import context

View file

@ -4,14 +4,20 @@ import json
import os
import pkgutil
import threading
import uuid
from json.decoder import JSONDecodeError
from pathlib import Path
from typing import Any
import astra_assistants.tools as astra_assistants_tools
import requests
from astra_assistants import OpenAIWithDefaultKey, patch
from astra_assistants.tools.tool_interface import ToolInterface
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from requests.exceptions import RequestException
from langflow.components.tools.mcp_stdio import create_input_schema_from_json_schema
from langflow.services.cache.utils import CacheMiss
client_lock = threading.Lock()
@ -64,3 +70,95 @@ def tools_from_package(your_package) -> None:
tools_from_package(astra_assistants_tools)
def wrap_base_tool_as_tool_interface(base_tool: BaseTool) -> ToolInterface:
    """Wrap a LangChain ``BaseTool`` in a class implementing ``ToolInterface``.

    Builds a dynamic Pydantic model from the tool's ``args_schema`` (if any)
    and returns a composition-based wrapper that delegates execution to the
    original tool.

    We only call ``args_schema()`` if it's truly a function/method, avoiding
    accidental calls on a Pydantic model class (which is also callable).

    Args:
        base_tool: The LangChain tool to wrap.

    Returns:
        An instance of a dynamically created ``ToolInterface`` subclass that
        delegates to ``base_tool``.

    Raises:
        TypeError: If ``args_schema`` resolves to something other than a
            Pydantic model class, a JSON-schema dict, or ``None``.
    """
    raw_args_schema = getattr(base_tool, "args_schema", None)

    # --- 1) Distinguish between a function/method vs. class/dict/None ---
    if inspect.isfunction(raw_args_schema) or inspect.ismethod(raw_args_schema):
        # It's actually a function -> call it once to get a class or dict
        raw_args_schema = raw_args_schema()
    # Otherwise, if it's a class or dict, do nothing here

    # Now `raw_args_schema` might be:
    #  - A Pydantic model class (subclass of BaseModel)
    #  - A dict (JSON schema)
    #  - None
    #  - Something unexpected => raise error

    # --- 2) Convert the schema or model class to a JSON schema dict ---
    if raw_args_schema is None:
        # No schema => minimal empty-object schema
        schema_dict = {"type": "object", "properties": {}}
    elif isinstance(raw_args_schema, dict):
        # Already a JSON schema
        schema_dict = raw_args_schema
    elif inspect.isclass(raw_args_schema) and issubclass(raw_args_schema, BaseModel):
        # It's a Pydantic model class -> convert to JSON schema
        # NOTE(review): `.schema()` is the Pydantic v1 API (deprecated in v2 in
        # favor of `model_json_schema()`) — confirm against the pinned pydantic.
        schema_dict = raw_args_schema.schema()
    else:
        msg = f"args_schema must be a Pydantic model class, a JSON schema dict, or None. Got: {raw_args_schema!r}"
        raise TypeError(msg)

    # --- 3) Build our dynamic Pydantic model from the JSON schema ---
    InputSchema: type[BaseModel] = create_input_schema_from_json_schema(schema_dict)  # noqa: N806

    # --- 4) Define a wrapper class that uses composition ---
    class WrappedDynamicTool(ToolInterface):
        """Delegates logic to the original base_tool.

        Uses composition so that `call(..., arguments: InputSchema)` receives
        a real Pydantic model instance rather than a raw dict.
        """

        def __init__(self, tool: BaseTool):
            # Keep a reference to the wrapped tool; all work is delegated.
            self._tool = tool

        def call(self, arguments: InputSchema) -> dict:  # type: ignore # noqa: PGH003
            # Invoke the underlying tool with the validated arguments.
            # assumes invoke() returns a sequence whose first element has a
            # `.data` dict containing "error" or "result" — TODO confirm
            # against the tool implementations this wraps.
            output = self._tool.invoke(arguments.dict())  # type: ignore # noqa: PGH003
            result = ""
            if "error" in output[0].data:
                result = output[0].data["error"]
            elif "result" in output[0].data:
                result = output[0].data["result"]
            # Fresh cache_id per call; result is "" if neither key was present.
            return {"cache_id": str(uuid.uuid4()), "output": result}

        def run(self, tool_input: Any) -> str:
            # Pass-through to the wrapped tool's run().
            return self._tool.run(tool_input)

        def name(self) -> str:
            """Return the base tool's name if it exists."""
            if hasattr(self._tool, "name"):
                return str(self._tool.name)
            return super().name()

        def to_function(self):
            """Build a function-tool spec, incorporating the base tool's description if present."""
            params = InputSchema.schema()
            description = getattr(self._tool, "description", "A dynamically wrapped tool")
            return {
                "type": "function",
                "function": {"name": self.name(), "description": description, "parameters": params},
            }

    # Return an instance of our newly minted class
    return WrappedDynamicTool(base_tool)
def sync_upload(file_path, client):
    """Upload the file at *file_path* via the client's files API.

    The SDK requires a synchronous file object, so the file is opened in
    binary mode and kept open for the duration of the create call.
    """
    upload_path = Path(file_path)
    with upload_path.open("rb") as handle:
        return client.files.create(
            file=handle,  # sync file handle, as required by the SDK
            purpose="assistants",
        )

View file

@ -1,70 +1,145 @@
import asyncio
from asyncio import to_thread
from typing import TYPE_CHECKING, Any, cast
from astra_assistants.astra_assistants_manager import AssistantManager
from langchain_core.agents import AgentFinish
from loguru import logger
from langflow.base.agents.events import ExceptionWithMessageError, process_agent_events
from langflow.base.astra_assistants.util import (
get_patched_openai_client,
litellm_model_names,
tool_names,
tools_and_names,
sync_upload,
wrap_base_tool_as_tool_interface,
)
from langflow.custom.custom_component.component_with_cache import ComponentWithCache
from langflow.inputs import DropdownInput, MultilineInput, StrInput
from langflow.inputs import DropdownInput, FileInput, HandleInput, MultilineInput
from langflow.memory import delete_message
from langflow.schema.content_block import ContentBlock
from langflow.schema.message import Message
from langflow.template import Output
from langflow.utils.constants import MESSAGE_SENDER_AI
if TYPE_CHECKING:
from langflow.schema.log import SendMessageFunctionType
class AstraAssistantManager(ComponentWithCache):
display_name = "Astra Assistant Manager"
display_name = "Astra Assistant Agent"
name = "Astra Assistant Agent"
description = "Manages Assistant Interactions"
icon = "AstraDB"
inputs = [
StrInput(
name="instructions",
display_name="Instructions",
info="Instructions for the assistant, think of these as the system prompt.",
),
DropdownInput(
name="model_name",
display_name="Model Name",
display_name="Model",
advanced=False,
options=litellm_model_names,
value="gpt-4o-mini",
),
DropdownInput(
display_name="Tool",
name="tool",
options=tool_names,
),
MultilineInput(
name="user_message",
display_name="User Message",
info="User message to pass to the run.",
name="instructions",
display_name="Agent Instructions",
info="Instructions for the assistant, think of these as the system prompt.",
),
HandleInput(
name="input_tools",
display_name="Tools",
input_types=["Tool"],
is_list=True,
required=False,
info="These are the tools that the agent can use to help with tasks.",
),
# DropdownInput(
# display_name="Tools",
# name="tool",
# options=tool_names,
# ),
MultilineInput(
name="user_message", display_name="User Message", info="User message to pass to the run.", tool_mode=True
),
FileInput(
name="file",
display_name="File(s) for retrieval",
list=True,
info="Files to be sent with the message.",
required=False,
show=True,
file_types=[
"txt",
"md",
"mdx",
"csv",
"json",
"yaml",
"yml",
"xml",
"html",
"htm",
"pdf",
"docx",
"py",
"sh",
"sql",
"js",
"ts",
"tsx",
"jpg",
"jpeg",
"png",
"bmp",
"image",
"zip",
"tar",
"tgz",
"bz2",
"gz",
"c",
"cpp",
"cs",
"css",
"go",
"java",
"php",
"rb",
"tex",
"doc",
"docx",
"ppt",
"pptx",
"xls",
"xlsx",
"jsonl",
],
),
MultilineInput(
name="input_thread_id",
display_name="Thread ID (optional)",
info="ID of the thread",
advanced=True,
),
MultilineInput(
name="input_assistant_id",
display_name="Assistant ID (optional)",
info="ID of the assistant",
advanced=True,
),
MultilineInput(
name="env_set",
display_name="Environment Set",
info="Dummy input to allow chaining with Dotenv Component.",
advanced=True,
),
]
outputs = [
Output(display_name="Assistant Response", name="assistant_response", method="get_assistant_response"),
Output(display_name="Tool output", name="tool_output", method="get_tool_output"),
Output(display_name="Thread Id", name="output_thread_id", method="get_thread_id"),
Output(display_name="Assistant Id", name="output_assistant_id", method="get_assistant_id"),
Output(display_name="Tool output", name="tool_output", method="get_tool_output", hidden=True),
Output(display_name="Thread Id", name="output_thread_id", method="get_thread_id", hidden=True),
Output(display_name="Assistant Id", name="output_assistant_id", method="get_assistant_id", hidden=True),
Output(display_name="Vector Store Id", name="output_vs_id", method="get_vs_id", hidden=True),
]
def __init__(self, **kwargs) -> None:
@ -75,22 +150,33 @@ class AstraAssistantManager(ComponentWithCache):
self._tool_output: Message = None # type: ignore[assignment]
self._thread_id: Message = None # type: ignore[assignment]
self._assistant_id: Message = None # type: ignore[assignment]
self._vs_id: Message = None # type: ignore[assignment]
self.client = get_patched_openai_client(self._shared_component_cache)
self.input_tools: list[Any]
async def get_assistant_response(self) -> Message:
await self.initialize()
self.status = self._assistant_response
return self._assistant_response
async def get_vs_id(self) -> Message:
await self.initialize()
self.status = self._vs_id
return self._vs_id
async def get_tool_output(self) -> Message:
await self.initialize()
self.status = self._tool_output
return self._tool_output
async def get_thread_id(self) -> Message:
await self.initialize()
self.status = self._thread_id
return self._thread_id
async def get_assistant_id(self) -> Message:
await self.initialize()
self.status = self._assistant_id
return self._assistant_id
async def initialize(self) -> None:
@ -101,19 +187,37 @@ class AstraAssistantManager(ComponentWithCache):
async def process_inputs(self) -> None:
logger.info(f"env_set is {self.env_set}")
logger.info(self.tool)
logger.info(self.input_tools)
tools = []
tool_obj = None
if self.tool:
tool_cls = tools_and_names[self.tool]
tool_obj = tool_cls()
if self.input_tools is None:
self.input_tools = []
for tool in self.input_tools:
tool_obj = wrap_base_tool_as_tool_interface(tool)
tools.append(tool_obj)
assistant_id = None
thread_id = None
if self.input_assistant_id:
assistant_id = self.input_assistant_id
if self.input_thread_id:
thread_id = self.input_thread_id
if hasattr(self, "graph"):
session_id = self.graph.session_id
elif hasattr(self, "_session_id"):
session_id = self._session_id
else:
session_id = None
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name=self.display_name or "Astra Assistant",
properties={"icon": "Bot", "state": "partial"},
content_blocks=[ContentBlock(title="Assistant Steps", contents=[])],
session_id=session_id,
)
assistant_manager = AssistantManager(
instructions=self.instructions,
model=self.model_name,
@ -124,12 +228,79 @@ class AstraAssistantManager(ComponentWithCache):
assistant_id=assistant_id,
)
content = self.user_message
result = await assistant_manager.run_thread(content=content, tool=tool_obj)
self._assistant_response = Message(text=result["text"])
if "decision" in result:
self._tool_output = Message(text=str(result["decision"].is_complete))
else:
self._tool_output = Message(text=result["text"])
self._thread_id = Message(text=assistant_manager.thread.id)
self._assistant_id = Message(text=assistant_manager.assistant.id)
if self.file:
file = await to_thread(sync_upload, self.file, assistant_manager.client)
vector_store = assistant_manager.client.beta.vector_stores.create(name="my_vs", file_ids=[file.id])
assistant_tools = assistant_manager.assistant.tools
assistant_tools += [{"type": "file_search"}]
assistant = assistant_manager.client.beta.assistants.update(
assistant_manager.assistant.id,
tools=assistant_tools,
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
)
assistant_manager.assistant = assistant
async def step_iterator():
# Initial event
yield {"event": "on_chain_start", "name": "AstraAssistant", "data": {"input": {"text": self.user_message}}}
content = self.user_message
result = await assistant_manager.run_thread(content=content, tool=tool_obj)
# Tool usage if present
if "output" in result and "arguments" in result:
yield {"event": "on_tool_start", "name": "tool", "data": {"input": {"text": str(result["arguments"])}}}
yield {"event": "on_tool_end", "name": "tool", "data": {"output": result["output"]}}
if "file_search" in result and result["file_search"] is not None:
yield {"event": "on_tool_start", "name": "tool", "data": {"input": {"text": self.user_message}}}
file_search_str = ""
for chunk in result["file_search"].to_dict().get("chunks", []):
file_search_str += f"## Chunk ID: `{chunk['chunk_id']}`\n"
file_search_str += f"**Content:**\n\n```\n{chunk['content']}\n```\n\n"
if "score" in chunk:
file_search_str += f"**Score:** {chunk['score']}\n\n"
if "file_id" in chunk:
file_search_str += f"**File ID:** `{chunk['file_id']}`\n\n"
if "file_name" in chunk:
file_search_str += f"**File Name:** `{chunk['file_name']}`\n\n"
if "bytes" in chunk:
file_search_str += f"**Bytes:** {chunk['bytes']}\n\n"
if "search_string" in chunk:
file_search_str += f"**Search String:** {chunk['search_string']}\n\n"
yield {"event": "on_tool_end", "name": "tool", "data": {"output": file_search_str}}
if "text" not in result:
msg = f"No text in result, {result}"
raise ValueError(msg)
self._assistant_response = Message(text=result["text"])
if "decision" in result:
self._tool_output = Message(text=str(result["decision"].is_complete))
else:
self._tool_output = Message(text=result["text"])
self._thread_id = Message(text=assistant_manager.thread.id)
self._assistant_id = Message(text=assistant_manager.assistant.id)
# Final event - format it like AgentFinish to match the expected format
yield {
"event": "on_chain_end",
"name": "AstraAssistant",
"data": {"output": AgentFinish(return_values={"output": result["text"]}, log="")},
}
try:
if hasattr(self, "send_message"):
processed_result = await process_agent_events(
step_iterator(),
agent_message,
cast("SendMessageFunctionType", self.send_message),
)
self.status = processed_result
except ExceptionWithMessageError as e:
msg_id = e.agent_message.id
await delete_message(id_=msg_id)
await self._send_message_event(e.agent_message, category="remove_message")
raise
except Exception:
raise

44
uv.lock generated
View file

@ -324,7 +324,7 @@ wheels = [
[[package]]
name = "astra-assistants"
version = "2.2.7"
version = "2.2.9"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
@ -339,9 +339,9 @@ dependencies = [
{ name = "tree-sitter" },
{ name = "tree-sitter-python" },
]
sdist = { url = "https://files.pythonhosted.org/packages/81/e2/c440ba3fe475088537c7c258b2f7689b1c9724bf46cd62d7de77e3ddc79f/astra_assistants-2.2.7.tar.gz", hash = "sha256:dd88adad9a74c9839c6faade1ccdfb47b827c82f2ed2a6da92731b4190506774", size = 67687 }
sdist = { url = "https://files.pythonhosted.org/packages/05/88/37b7ba47e7e639588a9068bfc90b4f3cbd964a8a2f4153e69b40e2165648/astra_assistants-2.2.9.tar.gz", hash = "sha256:b33e6a31d08155917e6b5413f986c278efcaa8e1c5a03ca1563e92ca0130a807", size = 67838 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/25/61/8a165e4ed492dae66d278d485eb505689273f2a266e34e2a21bac0bec4a6/astra_assistants-2.2.7-py3-none-any.whl", hash = "sha256:2d12999f97f57a45f24c3236af7b8792de317584e58cc7eaf771e5a440a26f2d", size = 78374 },
{ url = "https://files.pythonhosted.org/packages/f3/32/30a69010077a71ef5fd80c71296b2976946b140e8274ff37b44552c54ef4/astra_assistants-2.2.9-py3-none-any.whl", hash = "sha256:b5b2713cd32ac2050e4f28b1d748cb4701c021d802912837b943c65861699a4e", size = 78527 },
]
[package.optional-dependencies]
@ -532,7 +532,7 @@ name = "blessed"
version = "1.20.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinxed", marker = "platform_system == 'Windows'" },
{ name = "jinxed", marker = "sys_platform == 'win32'" },
{ name = "six" },
{ name = "wcwidth" },
]
@ -954,7 +954,7 @@ name = "click"
version = "8.1.8"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 }
wheels = [
@ -3156,7 +3156,7 @@ name = "ipykernel"
version = "6.29.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "appnope", marker = "platform_system == 'Darwin'" },
{ name = "appnope", marker = "sys_platform == 'darwin'" },
{ name = "comm" },
{ name = "debugpy" },
{ name = "ipython" },
@ -3247,7 +3247,7 @@ name = "jinxed"
version = "1.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "ansicon", marker = "platform_system == 'Windows'" },
{ name = "ansicon", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/20/d0/59b2b80e7a52d255f9e0ad040d2e826342d05580c4b1d7d7747cfb8db731/jinxed-1.3.0.tar.gz", hash = "sha256:1593124b18a41b7a3da3b078471442e51dbad3d77b4d4f2b0c26ab6f7d660dbf", size = 80981 }
wheels = [
@ -4094,7 +4094,7 @@ requires-dist = [
{ name = "aiofile", specifier = ">=3.9.0,<4.0.0" },
{ name = "arize-phoenix-otel", specifier = ">=0.6.1" },
{ name = "assemblyai", specifier = "==0.35.1" },
{ name = "astra-assistants", extras = ["tools"], specifier = "~=2.2.6" },
{ name = "astra-assistants", extras = ["tools"], specifier = "~=2.2.9" },
{ name = "atlassian-python-api", specifier = "==3.41.16" },
{ name = "beautifulsoup4", specifier = "==4.12.3" },
{ name = "boto3", specifier = "==1.34.162" },
@ -6164,7 +6164,7 @@ name = "portalocker"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pywin32", marker = "platform_system == 'Windows'" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
wheels = [
@ -8661,19 +8661,19 @@ dependencies = [
{ name = "fsspec" },
{ name = "jinja2" },
{ name = "networkx" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "sympy" },
{ name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "typing-extensions" },
]
wheels = [
@ -8714,7 +8714,7 @@ name = "tqdm"
version = "4.67.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 }
wheels = [