feat: assistants agent improvements (#5581)

* assistants agent improvements

* remove alembic init file

* vector store / file upload support

* use sync file object (required by sdk)

* steps

* self.tools initialization

* improvements for edwin

* add name and switch to MultilineInput

* ci fixes
This commit is contained in:
Sebastián Estévez 2025-01-16 15:54:34 -05:00 committed by GitHub
commit 2acd434e09
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 327 additions and 57 deletions

View file

@ -1,3 +1,4 @@
# noqa: INP001
from logging.config import fileConfig
from alembic import context

View file

@ -4,14 +4,20 @@ import json
import os
import pkgutil
import threading
import uuid
from json.decoder import JSONDecodeError
from pathlib import Path
from typing import Any
import astra_assistants.tools as astra_assistants_tools
import requests
from astra_assistants import OpenAIWithDefaultKey, patch
from astra_assistants.tools.tool_interface import ToolInterface
from langchain_core.tools import BaseTool
from pydantic import BaseModel
from requests.exceptions import RequestException
from langflow.components.tools.mcp_stdio import create_input_schema_from_json_schema
from langflow.services.cache.utils import CacheMiss
client_lock = threading.Lock()
@ -64,3 +70,95 @@ def tools_from_package(your_package) -> None:
tools_from_package(astra_assistants_tools)
def wrap_base_tool_as_tool_interface(base_tool: BaseTool) -> ToolInterface:
    """Wrap a ``BaseTool`` instance in an object implementing ``ToolInterface``.

    A dynamic Pydantic model is built from the tool's ``args_schema`` (if any) and
    used as the annotation of the wrapper's ``call`` method, so the wrapper exposes
    a real model type for its arguments.

    ``args_schema`` is only called when it is truly a function/method; a Pydantic
    model class is also callable, and calling it would try to instantiate it.

    Args:
        base_tool: The tool to wrap. Its ``args_schema`` may be a Pydantic model
            class, a JSON schema dict, a function/method returning one of those,
            or ``None``.

    Returns:
        An instance of a newly defined ``ToolInterface`` subclass that delegates
        to ``base_tool``.

    Raises:
        TypeError: If the resolved ``args_schema`` is not a Pydantic model class,
            a dict, or ``None``.
    """
    raw_args_schema = getattr(base_tool, "args_schema", None)

    # --- 1) Distinguish between a function/method vs. class/dict/None ---
    if inspect.isfunction(raw_args_schema) or inspect.ismethod(raw_args_schema):
        # It's actually a function -> call it once to get a class or dict
        raw_args_schema = raw_args_schema()

    # --- 2) Convert the schema or model class to a JSON schema dict ---
    if raw_args_schema is None:
        # No schema => minimal object schema with no properties
        schema_dict = {"type": "object", "properties": {}}
    elif isinstance(raw_args_schema, dict):
        # Already a JSON schema
        schema_dict = raw_args_schema
    elif inspect.isclass(raw_args_schema) and issubclass(raw_args_schema, BaseModel):
        # It's a Pydantic model class -> convert to JSON schema
        schema_dict = raw_args_schema.schema()
    else:
        msg = f"args_schema must be a Pydantic model class, a JSON schema dict, or None. Got: {raw_args_schema!r}"
        raise TypeError(msg)

    # --- 3) Build our dynamic Pydantic model from the JSON schema ---
    InputSchema: type[BaseModel] = create_input_schema_from_json_schema(schema_dict)  # noqa: N806

    # --- 4) Define a wrapper class that uses composition ---
    class WrappedDynamicTool(ToolInterface):
        """Delegate to the original ``base_tool`` while typing ``call`` with ``InputSchema``."""

        def __init__(self, tool: BaseTool):
            self._tool = tool

        def call(self, arguments: InputSchema) -> dict:  # type: ignore # noqa: PGH003
            """Invoke the wrapped tool and return its error/result text with a fresh cache id."""
            output = self._tool.invoke(arguments.dict())  # type: ignore # noqa: PGH003
            result = ""
            # Guard the first-item lookup: an empty (or None) output previously
            # raised on ``output[0]`` instead of yielding an empty result.
            if output:
                data = output[0].data
                if "error" in data:
                    result = data["error"]
                elif "result" in data:
                    result = data["result"]
            return {"cache_id": str(uuid.uuid4()), "output": result}

        def run(self, tool_input: Any) -> str:
            """Forward ``tool_input`` to the wrapped tool's ``run``."""
            return self._tool.run(tool_input)

        def name(self) -> str:
            """Return the base tool's name if it exists."""
            if hasattr(self._tool, "name"):
                return str(self._tool.name)
            return super().name()

        def to_function(self):
            """Describe the tool as a function spec, using the base tool's description if present."""
            params = InputSchema.schema()
            description = getattr(self._tool, "description", "A dynamically wrapped tool")
            return {
                "type": "function",
                "function": {"name": self.name(), "description": description, "parameters": params},
            }

    # Return an instance of our newly minted class
    return WrappedDynamicTool(base_tool)
def sync_upload(file_path, client):
    """Upload the file at ``file_path`` via ``client.files.create``.

    The file is opened in binary mode as a regular (synchronous) file object and
    handed to the client with ``purpose="assistants"``. The handle is closed once
    the create call has returned.

    Args:
        file_path: Path of the local file to upload.
        client: Object exposing ``files.create(file=..., purpose=...)``.

    Returns:
        Whatever ``client.files.create`` returns.
    """
    source = Path(file_path)
    with source.open("rb") as handle:
        uploaded = client.files.create(file=handle, purpose="assistants")
    return uploaded

View file

@ -1,70 +1,145 @@
import asyncio
from asyncio import to_thread
from typing import TYPE_CHECKING, Any, cast
from astra_assistants.astra_assistants_manager import AssistantManager
from langchain_core.agents import AgentFinish
from loguru import logger
from langflow.base.agents.events import ExceptionWithMessageError, process_agent_events
from langflow.base.astra_assistants.util import (
get_patched_openai_client,
litellm_model_names,
tool_names,
tools_and_names,
sync_upload,
wrap_base_tool_as_tool_interface,
)
from langflow.custom.custom_component.component_with_cache import ComponentWithCache
from langflow.inputs import DropdownInput, MultilineInput, StrInput
from langflow.inputs import DropdownInput, FileInput, HandleInput, MultilineInput
from langflow.memory import delete_message
from langflow.schema.content_block import ContentBlock
from langflow.schema.message import Message
from langflow.template import Output
from langflow.utils.constants import MESSAGE_SENDER_AI
if TYPE_CHECKING:
from langflow.schema.log import SendMessageFunctionType
class AstraAssistantManager(ComponentWithCache):
display_name = "Astra Assistant Manager"
display_name = "Astra Assistant Agent"
name = "Astra Assistant Agent"
description = "Manages Assistant Interactions"
icon = "AstraDB"
inputs = [
StrInput(
name="instructions",
display_name="Instructions",
info="Instructions for the assistant, think of these as the system prompt.",
),
DropdownInput(
name="model_name",
display_name="Model Name",
display_name="Model",
advanced=False,
options=litellm_model_names,
value="gpt-4o-mini",
),
DropdownInput(
display_name="Tool",
name="tool",
options=tool_names,
),
MultilineInput(
name="user_message",
display_name="User Message",
info="User message to pass to the run.",
name="instructions",
display_name="Agent Instructions",
info="Instructions for the assistant, think of these as the system prompt.",
),
HandleInput(
name="input_tools",
display_name="Tools",
input_types=["Tool"],
is_list=True,
required=False,
info="These are the tools that the agent can use to help with tasks.",
),
# DropdownInput(
# display_name="Tools",
# name="tool",
# options=tool_names,
# ),
MultilineInput(
name="user_message", display_name="User Message", info="User message to pass to the run.", tool_mode=True
),
FileInput(
    name="file",
    display_name="File(s) for retrieval",
    list=True,
    info="Files to be sent with the message.",
    required=False,
    show=True,
    # Accepted extensions, each listed once ("docx" previously appeared twice).
    file_types=[
        "txt",
        "md",
        "mdx",
        "csv",
        "json",
        "yaml",
        "yml",
        "xml",
        "html",
        "htm",
        "pdf",
        "docx",
        "py",
        "sh",
        "sql",
        "js",
        "ts",
        "tsx",
        "jpg",
        "jpeg",
        "png",
        "bmp",
        "image",
        "zip",
        "tar",
        "tgz",
        "bz2",
        "gz",
        "c",
        "cpp",
        "cs",
        "css",
        "go",
        "java",
        "php",
        "rb",
        "tex",
        "doc",
        "ppt",
        "pptx",
        "xls",
        "xlsx",
        "jsonl",
    ],
),
MultilineInput(
name="input_thread_id",
display_name="Thread ID (optional)",
info="ID of the thread",
advanced=True,
),
MultilineInput(
name="input_assistant_id",
display_name="Assistant ID (optional)",
info="ID of the assistant",
advanced=True,
),
MultilineInput(
name="env_set",
display_name="Environment Set",
info="Dummy input to allow chaining with Dotenv Component.",
advanced=True,
),
]
outputs = [
Output(display_name="Assistant Response", name="assistant_response", method="get_assistant_response"),
Output(display_name="Tool output", name="tool_output", method="get_tool_output"),
Output(display_name="Thread Id", name="output_thread_id", method="get_thread_id"),
Output(display_name="Assistant Id", name="output_assistant_id", method="get_assistant_id"),
Output(display_name="Tool output", name="tool_output", method="get_tool_output", hidden=True),
Output(display_name="Thread Id", name="output_thread_id", method="get_thread_id", hidden=True),
Output(display_name="Assistant Id", name="output_assistant_id", method="get_assistant_id", hidden=True),
Output(display_name="Vector Store Id", name="output_vs_id", method="get_vs_id", hidden=True),
]
def __init__(self, **kwargs) -> None:
@ -75,22 +150,33 @@ class AstraAssistantManager(ComponentWithCache):
self._tool_output: Message = None # type: ignore[assignment]
self._thread_id: Message = None # type: ignore[assignment]
self._assistant_id: Message = None # type: ignore[assignment]
self._vs_id: Message = None # type: ignore[assignment]
self.client = get_patched_openai_client(self._shared_component_cache)
self.input_tools: list[Any]
async def get_assistant_response(self) -> Message:
    """Return the stored assistant response after awaiting ``initialize()``.

    The same value is also written to ``self.status``.
    """
    await self.initialize()
    response = self._assistant_response
    self.status = response
    return response
async def get_vs_id(self) -> Message:
    """Return the stored vector store id after awaiting ``initialize()``.

    The same value is also written to ``self.status``.
    """
    await self.initialize()
    vs_id = self._vs_id
    self.status = vs_id
    return vs_id
async def get_tool_output(self) -> Message:
    """Return the stored tool output after awaiting ``initialize()``.

    The same value is also written to ``self.status``.
    """
    await self.initialize()
    tool_output = self._tool_output
    self.status = tool_output
    return tool_output
async def get_thread_id(self) -> Message:
    """Return the stored thread id after awaiting ``initialize()``.

    The same value is also written to ``self.status``.
    """
    await self.initialize()
    thread_id = self._thread_id
    self.status = thread_id
    return thread_id
async def get_assistant_id(self) -> Message:
    """Return the stored assistant id after awaiting ``initialize()``.

    The same value is also written to ``self.status``.
    """
    await self.initialize()
    assistant_id = self._assistant_id
    self.status = assistant_id
    return assistant_id
async def initialize(self) -> None:
@ -101,19 +187,37 @@ class AstraAssistantManager(ComponentWithCache):
async def process_inputs(self) -> None:
logger.info(f"env_set is {self.env_set}")
logger.info(self.tool)
logger.info(self.input_tools)
tools = []
tool_obj = None
if self.tool:
tool_cls = tools_and_names[self.tool]
tool_obj = tool_cls()
if self.input_tools is None:
self.input_tools = []
for tool in self.input_tools:
tool_obj = wrap_base_tool_as_tool_interface(tool)
tools.append(tool_obj)
assistant_id = None
thread_id = None
if self.input_assistant_id:
assistant_id = self.input_assistant_id
if self.input_thread_id:
thread_id = self.input_thread_id
if hasattr(self, "graph"):
session_id = self.graph.session_id
elif hasattr(self, "_session_id"):
session_id = self._session_id
else:
session_id = None
agent_message = Message(
sender=MESSAGE_SENDER_AI,
sender_name=self.display_name or "Astra Assistant",
properties={"icon": "Bot", "state": "partial"},
content_blocks=[ContentBlock(title="Assistant Steps", contents=[])],
session_id=session_id,
)
assistant_manager = AssistantManager(
instructions=self.instructions,
model=self.model_name,
@ -124,12 +228,79 @@ class AstraAssistantManager(ComponentWithCache):
assistant_id=assistant_id,
)
content = self.user_message
result = await assistant_manager.run_thread(content=content, tool=tool_obj)
self._assistant_response = Message(text=result["text"])
if "decision" in result:
self._tool_output = Message(text=str(result["decision"].is_complete))
else:
self._tool_output = Message(text=result["text"])
self._thread_id = Message(text=assistant_manager.thread.id)
self._assistant_id = Message(text=assistant_manager.assistant.id)
if self.file:
file = await to_thread(sync_upload, self.file, assistant_manager.client)
vector_store = assistant_manager.client.beta.vector_stores.create(name="my_vs", file_ids=[file.id])
assistant_tools = assistant_manager.assistant.tools
assistant_tools += [{"type": "file_search"}]
assistant = assistant_manager.client.beta.assistants.update(
assistant_manager.assistant.id,
tools=assistant_tools,
tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}},
)
assistant_manager.assistant = assistant
async def step_iterator():
    """Run the assistant thread once and yield event dicts describing it.

    Yields an ``on_chain_start`` event, then ``on_tool_start``/``on_tool_end``
    pairs for a tool call and/or file search found in the result, and finally
    an ``on_chain_end`` event wrapping the text in an ``AgentFinish``. Before the
    final event it stores the response, tool output, thread id and assistant id
    on ``self``.

    Raises:
        ValueError: If the run result has no ``"text"`` entry.
    """

    def render_chunk(chunk) -> str:
        # Markdown for one file-search chunk; optional fields only when present.
        pieces = [
            f"## Chunk ID: `{chunk['chunk_id']}`\n",
            f"**Content:**\n\n```\n{chunk['content']}\n```\n\n",
        ]
        if "score" in chunk:
            pieces.append(f"**Score:** {chunk['score']}\n\n")
        if "file_id" in chunk:
            pieces.append(f"**File ID:** `{chunk['file_id']}`\n\n")
        if "file_name" in chunk:
            pieces.append(f"**File Name:** `{chunk['file_name']}`\n\n")
        if "bytes" in chunk:
            pieces.append(f"**Bytes:** {chunk['bytes']}\n\n")
        if "search_string" in chunk:
            pieces.append(f"**Search String:** {chunk['search_string']}\n\n")
        return "".join(pieces)

    # Initial event
    yield {"event": "on_chain_start", "name": "AstraAssistant", "data": {"input": {"text": self.user_message}}}

    result = await assistant_manager.run_thread(content=self.user_message, tool=tool_obj)

    # Tool usage if present
    if all(key in result for key in ("output", "arguments")):
        yield {"event": "on_tool_start", "name": "tool", "data": {"input": {"text": str(result["arguments"])}}}
        yield {"event": "on_tool_end", "name": "tool", "data": {"output": result["output"]}}

    # File search usage if present
    file_search = result["file_search"] if "file_search" in result else None
    if file_search is not None:
        yield {"event": "on_tool_start", "name": "tool", "data": {"input": {"text": self.user_message}}}
        found_chunks = file_search.to_dict().get("chunks", [])
        file_search_str = "".join(render_chunk(chunk) for chunk in found_chunks)
        yield {"event": "on_tool_end", "name": "tool", "data": {"output": file_search_str}}

    if "text" not in result:
        msg = f"No text in result, {result}"
        raise ValueError(msg)

    final_text = result["text"]
    self._assistant_response = Message(text=final_text)
    if "decision" in result:
        tool_text = str(result["decision"].is_complete)
    else:
        tool_text = result["text"]
    self._tool_output = Message(text=tool_text)
    self._thread_id = Message(text=assistant_manager.thread.id)
    self._assistant_id = Message(text=assistant_manager.assistant.id)

    # Final event - format it like AgentFinish to match the expected format
    finish = AgentFinish(return_values={"output": result["text"]}, log="")
    yield {"event": "on_chain_end", "name": "AstraAssistant", "data": {"output": finish}}
try:
if hasattr(self, "send_message"):
processed_result = await process_agent_events(
step_iterator(),
agent_message,
cast("SendMessageFunctionType", self.send_message),
)
self.status = processed_result
except ExceptionWithMessageError as e:
msg_id = e.agent_message.id
await delete_message(id_=msg_id)
await self._send_message_event(e.agent_message, category="remove_message")
raise
except Exception:
raise