feat: Needle Search Tool With Template (#6648)

* feat: Needle Search Tool With Templte

* lint

* lint

* lint

* lint

* refactor: Use Langflow Agent instead of CrewAI Agent

* techdebt: adjust Needle component to use tool mode and remove tool component

* lint

* lint

* Update Invoice Summarizer.json

* Update Invoice Summarizer.json

* update to the component

* refactor: Use Needle icon svg

* make format

* component updates

* update with latest agent component

* updated a missing connection when updating the agent component

* update template

---------

Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
Jan Heimes 2025-03-14 10:37:54 +01:00 committed by GitHub
commit 94bc8dbc7d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 2122 additions and 99 deletions

View file

@ -1,16 +1,14 @@
from langchain.chains import ConversationalRetrievalChain
from langchain_community.retrievers.needle import NeedleRetriever
from langchain_openai import ChatOpenAI
from langflow.custom.custom_component.component import Component
from langflow.io import DropdownInput, Output, SecretStrInput, StrInput
from langflow.io import IntInput, MessageTextInput, Output, SecretStrInput
from langflow.schema.message import Message
from langflow.utils.constants import MESSAGE_SENDER_AI
class NeedleComponent(Component):
display_name = "Needle Retriever"
description = "A retriever that uses the Needle API to search collections and generates responses using OpenAI."
description = "A retriever that uses the Needle API to search collections."
documentation = "https://docs.needle-ai.com"
icon = "Needle"
name = "needle"
@ -22,30 +20,24 @@ class NeedleComponent(Component):
info="Your Needle API key.",
required=True,
),
SecretStrInput(
name="openai_api_key",
display_name="OpenAI API Key",
info="Your OpenAI API key.",
required=True,
),
StrInput(
MessageTextInput(
name="collection_id",
display_name="Collection ID",
info="The ID of the Needle collection.",
required=True,
),
StrInput(
MessageTextInput(
name="query",
display_name="User Query",
info="Enter your question here.",
info="Enter your question here. In tool mode, you can also specify top_k parameter (min: 20).",
required=True,
tool_mode=True,
),
DropdownInput(
name="output_type",
display_name="Output Type",
info="Return either the answer or the chunks.",
options=["answer", "chunks"],
value="answer",
IntInput(
name="top_k",
display_name="Top K Results",
info="Number of search results to return (min: 20).",
value=20,
required=True,
),
]
@ -53,74 +45,57 @@ class NeedleComponent(Component):
outputs = [Output(display_name="Result", name="result", type_="Message", method="run")]
def run(self) -> Message:
needle_api_key = self.needle_api_key or ""
openai_api_key = self.openai_api_key or ""
collection_id = self.collection_id
query = self.query
output_type = self.output_type
# Extract query and top_k
query_input = self.query
actual_query = query_input.get("query", "") if isinstance(query_input, dict) else query_input
# Define error messages
needle_api_key = "The Needle API key cannot be empty."
openai_api_key = "The OpenAI API key cannot be empty."
collection_id_error = "The Collection ID cannot be empty."
query_error = "The query cannot be empty."
# Parse top_k from tool input or use default, always enforcing minimum of 20
try:
if isinstance(query_input, dict) and "top_k" in query_input:
agent_top_k = query_input.get("top_k")
# Check if agent_top_k is not None before converting to int
top_k = max(20, int(agent_top_k)) if agent_top_k is not None else max(20, self.top_k)
else:
top_k = max(20, self.top_k)
except (ValueError, TypeError):
top_k = max(20, self.top_k)
# Validate inputs
if not needle_api_key.strip():
raise ValueError(needle_api_key)
if not openai_api_key.strip():
raise ValueError(openai_api_key)
if not collection_id.strip():
raise ValueError(collection_id_error)
if not query.strip():
raise ValueError(query_error)
# Handle output_type if it's somehow a list
if isinstance(output_type, list):
output_type = output_type[0]
# Validate required inputs
if not self.needle_api_key or not self.needle_api_key.strip():
error_msg = "The Needle API key cannot be empty."
raise ValueError(error_msg)
if not self.collection_id or not self.collection_id.strip():
error_msg = "The Collection ID cannot be empty."
raise ValueError(error_msg)
if not actual_query or not actual_query.strip():
error_msg = "The query cannot be empty."
raise ValueError(error_msg)
try:
# Initialize the retriever
# Initialize the retriever and get documents
retriever = NeedleRetriever(
needle_api_key=needle_api_key,
collection_id=collection_id,
needle_api_key=self.needle_api_key,
collection_id=self.collection_id,
top_k=top_k,
)
# Create the chain
llm = ChatOpenAI(
temperature=0.7,
api_key=openai_api_key,
)
docs = retriever.get_relevant_documents(actual_query)
qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=retriever,
return_source_documents=True,
)
# Process the query
result = qa_chain({"question": query, "chat_history": []})
# Format content based on output type
if str(output_type).lower().strip() == "chunks":
# If chunks selected, include full context and answer
docs = result["source_documents"]
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
text_content = f"Question: {query}\n\nContext:\n{context}\n\nAnswer: {result['answer']}"
# Format the response
if not docs:
text_content = "No relevant documents found for the query."
else:
# If answer selected, only include the answer
text_content = result["answer"]
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
text_content = f"Question: {actual_query}\n\nContext:\n{context}"
# Create a Message object following chat.py pattern
# Return formatted message
return Message(
text=text_content,
type="assistant",
sender=MESSAGE_SENDER_AI,
additional_kwargs={
"source_documents": [
{"page_content": doc.page_content, "metadata": doc.metadata}
for doc in result["source_documents"]
]
"source_documents": [{"page_content": doc.page_content, "metadata": doc.metadata} for doc in docs],
"top_k_used": top_k,
},
)

File diff suppressed because one or more lines are too long

View file

@ -6,6 +6,7 @@ from langflow.api.v1.schemas import ResultDataResponse, VertexBuildResponse
from langflow.schema.schema import OutputValue
from langflow.serialization import serialize
from langflow.services.tracing.schema import Log
from pydantic import BaseModel
# Use a smaller test size for hypothesis

View file

@ -8,6 +8,7 @@ from langflow.components.outputs.chat import ChatOutput
from langflow.components.tools.calculator import CalculatorToolComponent
from langflow.graph import Graph
from langflow.schema.data import Data
from pydantic import BaseModel

View file

@ -21,6 +21,7 @@ from langflow.io import (
TableInput,
)
from langflow.schema import Data
from pydantic import BaseModel

View file

@ -5,8 +5,8 @@ import pytest
from langflow.components.helpers.structured_output import StructuredOutputComponent
from langflow.helpers.base_model import build_model_from_schema
from langflow.inputs.inputs import TableInput
from pydantic import BaseModel
from pydantic import BaseModel
from tests.base import ComponentTestBaseWithoutClient
from tests.unit.mock_language_model import MockLanguageModel

View file

@ -5,6 +5,7 @@ from langflow.graph import Graph
from langflow.graph.graph.constants import Finish
from langflow.graph.state.model import create_state_model
from langflow.template.field.base import UNDEFINED
from pydantic import Field

View file

@ -4,9 +4,10 @@ from typing import Any
import pytest
from langflow.helpers.base_model import build_model_from_schema
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from pydantic import BaseModel
class TestBuildModelFromSchema:
# Successfully creates a Pydantic model from a valid schema

View file

@ -23,6 +23,7 @@ from langflow.inputs.inputs import (
)
from langflow.inputs.utils import instantiate_input
from langflow.schema.message import Message
from pydantic import ValidationError

View file

@ -1,9 +1,10 @@
from unittest.mock import MagicMock
from langchain_core.language_models import BaseLanguageModel
from pydantic import BaseModel, Field
from typing_extensions import override
from pydantic import BaseModel, Field
class MockLanguageModel(BaseLanguageModel, BaseModel):
"""A mock language model for testing purposes."""

View file

@ -9,6 +9,7 @@ from hypothesis import strategies as st
from langchain_core.documents import Document
from langflow.serialization.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH
from langflow.serialization.serialization import serialize, serialize_or_str
from pydantic import BaseModel as PydanticBaseModel
from pydantic.v1 import BaseModel as PydanticV1BaseModel

View file

@ -7,6 +7,7 @@ from langflow.schema.data import Data
from langflow.template import Input, Output
from langflow.template.field.base import UNDEFINED
from langflow.type_extraction.type_extraction import post_process_type
from pydantic import ValidationError

View file

@ -2,6 +2,7 @@ import importlib
import pytest
from langflow.utils.util import build_template_from_function, get_base_classes, get_default_factory
from pydantic import BaseModel

View file

@ -1,18 +0,0 @@
const SvgNeedleIcon = (props) => (
<svg
xmlns="http://www.w3.org/2000/svg"
width="1em"
height="1em"
viewBox="0 0 32 32"
{...props}
>
<circle cx="16" cy="16" r="15" stroke="currentColor" strokeWidth="2" />
<path
fill="currentColor"
d="M6.06661 23.4341L19.2183 4.1263C19.5277 3.67209 20.3458 3.99976 20.3225 4.54884C20.1123 9.51475 23.0448 12.1637 28.0237 12.3643C28.5874 12.387 28.8315 13.2879 28.3252 13.5368L6.70039 24.1643C6.23948 24.3909 5.7775 23.8586 6.06661 23.4341Z"
/>
<circle cx="24.5" cy="8.5" r="1.5" fill="currentColor" />
</svg>
);
export default SvgNeedleIcon;

View file

@ -1,9 +1,9 @@
import React, { forwardRef } from "react";
import SvgNeedleIcon from "./NeedleIcon";
import NeedleSvg from "./needle-icon.svg?react";
export const NeedleIcon = forwardRef<
SVGSVGElement,
React.PropsWithChildren<{}>
>((props, ref) => {
return <SvgNeedleIcon ref={ref} {...props} />;
return <NeedleSvg ref={ref} {...props} />;
});

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 503 B

After

Width:  |  Height:  |  Size: 45 KiB

Before After
Before After

View file

@ -0,0 +1,102 @@
import { expect, test } from "@playwright/test";
import * as dotenv from "dotenv";
import path from "path";
import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
import { initialGPTsetup } from "../../utils/initialGPTsetup";
test(
"Invoice Summarizer",
{ tag: ["@release", "@starter-projects"] },
async ({ page }) => {
test.skip(
!process?.env?.OPENAI_API_KEY ||
!process?.env?.NEEDLE_API_KEY ||
!process?.env?.NEEDLE_COLLECTION_ID,
"OPENAI_API_KEY, NEEDLE_API_KEY, and NEEDLE_COLLECTION_ID required to run this test",
);
if (!process.env.CI) {
dotenv.config({ path: path.resolve(__dirname, "../../.env") });
}
await awaitBootstrapTest(page);
await page.getByTestId("side_nav_options_all-templates").click();
await page.getByRole("heading", { name: "Invoice Summarizer" }).click();
await initialGPTsetup(page);
// Configure Needle Search Knowledge Base
await page
.getByTestId("input_str_needle_api_key")
.fill(process.env.NEEDLE_API_KEY || "");
await page
.getByTestId("input_str_collection_id")
.fill(process.env.NEEDLE_COLLECTION_ID || "");
// Run the flow
await page.getByTestId("button_run_chat output").click();
// Wait for the flow to build successfully
await page.waitForSelector("text=built successfully", { timeout: 30000 });
await page.getByText("built successfully").last().click({
timeout: 30000,
});
// Switch to Playground
await page.getByText("Playground", { exact: true }).last().click();
// Wait for the playground to be ready
const inputPlaceholder = page
.getByPlaceholder(
"No chat input variables found. Click to run your flow.",
{ exact: true },
)
.last();
await expect(inputPlaceholder).toBeVisible({ timeout: 10000 });
// Verify initial response is displayed
await expect(page.getByText("Search Results Summary")).toBeVisible({
timeout: 15000,
});
// Verify that specific invoice-related data appears in the results
const keyTerms = ["expenses", "invoice", "vendor"];
for (const term of keyTerms) {
await expect(page.getByText(term, { exact: false })).toBeVisible({
timeout: 5000,
});
}
// Test interaction with the flow by adding a specific query
// Click the input field and type a query
await inputPlaceholder.click();
await page.keyboard.type("Summarize the total expenses from last month");
await page.keyboard.press("Enter");
// Wait for response to the specific query
await expect(page.getByText("Search Results Summary")).toBeVisible({
timeout: 20000,
});
// Verify that expense summary information appears in the response
await expect(page.getByText("expenses", { exact: false })).toBeVisible({
timeout: 10000,
});
// Test error handling - invalid query
await page.keyboard.type("xyz123$%^NonSensicalQuery");
await page.keyboard.press("Enter");
// Wait for the response, which should still show search results or appropriate message
await expect(
page
.getByText("Search Results", { exact: false })
.or(page.getByText("no relevant", { exact: false })),
).toBeVisible({ timeout: 20000 });
// Cleanup - Reset the session
await page.getByTestId("side_nav_options_all-templates").click();
},
);