feat: Needle Search Tool With Template (#6648)
* feat: Needle Search Tool With Templte * lint * lint * lint * lint * refactor: Use Langflow Agent instead of CrewAI Agent * techdebt: adjust Needle component to use tool mode and remove tool component * lint * lint * Update Invoice Summarizer.json * Update Invoice Summarizer.json * update to the component * refactor: Use Needle icon svg * make format * component updates * update with latest agent component * updated a missing connection when updating the agent component * update template --------- Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
parent
1249ea138c
commit
94bc8dbc7d
17 changed files with 2122 additions and 99 deletions
|
|
@ -1,16 +1,14 @@
|
|||
from langchain.chains import ConversationalRetrievalChain
|
||||
from langchain_community.retrievers.needle import NeedleRetriever
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.io import DropdownInput, Output, SecretStrInput, StrInput
|
||||
from langflow.io import IntInput, MessageTextInput, Output, SecretStrInput
|
||||
from langflow.schema.message import Message
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI
|
||||
|
||||
|
||||
class NeedleComponent(Component):
|
||||
display_name = "Needle Retriever"
|
||||
description = "A retriever that uses the Needle API to search collections and generates responses using OpenAI."
|
||||
description = "A retriever that uses the Needle API to search collections."
|
||||
documentation = "https://docs.needle-ai.com"
|
||||
icon = "Needle"
|
||||
name = "needle"
|
||||
|
|
@ -22,30 +20,24 @@ class NeedleComponent(Component):
|
|||
info="Your Needle API key.",
|
||||
required=True,
|
||||
),
|
||||
SecretStrInput(
|
||||
name="openai_api_key",
|
||||
display_name="OpenAI API Key",
|
||||
info="Your OpenAI API key.",
|
||||
required=True,
|
||||
),
|
||||
StrInput(
|
||||
MessageTextInput(
|
||||
name="collection_id",
|
||||
display_name="Collection ID",
|
||||
info="The ID of the Needle collection.",
|
||||
required=True,
|
||||
),
|
||||
StrInput(
|
||||
MessageTextInput(
|
||||
name="query",
|
||||
display_name="User Query",
|
||||
info="Enter your question here.",
|
||||
info="Enter your question here. In tool mode, you can also specify top_k parameter (min: 20).",
|
||||
required=True,
|
||||
tool_mode=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="output_type",
|
||||
display_name="Output Type",
|
||||
info="Return either the answer or the chunks.",
|
||||
options=["answer", "chunks"],
|
||||
value="answer",
|
||||
IntInput(
|
||||
name="top_k",
|
||||
display_name="Top K Results",
|
||||
info="Number of search results to return (min: 20).",
|
||||
value=20,
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
|
@ -53,74 +45,57 @@ class NeedleComponent(Component):
|
|||
outputs = [Output(display_name="Result", name="result", type_="Message", method="run")]
|
||||
|
||||
def run(self) -> Message:
|
||||
needle_api_key = self.needle_api_key or ""
|
||||
openai_api_key = self.openai_api_key or ""
|
||||
collection_id = self.collection_id
|
||||
query = self.query
|
||||
output_type = self.output_type
|
||||
# Extract query and top_k
|
||||
query_input = self.query
|
||||
actual_query = query_input.get("query", "") if isinstance(query_input, dict) else query_input
|
||||
|
||||
# Define error messages
|
||||
needle_api_key = "The Needle API key cannot be empty."
|
||||
openai_api_key = "The OpenAI API key cannot be empty."
|
||||
collection_id_error = "The Collection ID cannot be empty."
|
||||
query_error = "The query cannot be empty."
|
||||
# Parse top_k from tool input or use default, always enforcing minimum of 20
|
||||
try:
|
||||
if isinstance(query_input, dict) and "top_k" in query_input:
|
||||
agent_top_k = query_input.get("top_k")
|
||||
# Check if agent_top_k is not None before converting to int
|
||||
top_k = max(20, int(agent_top_k)) if agent_top_k is not None else max(20, self.top_k)
|
||||
else:
|
||||
top_k = max(20, self.top_k)
|
||||
except (ValueError, TypeError):
|
||||
top_k = max(20, self.top_k)
|
||||
|
||||
# Validate inputs
|
||||
if not needle_api_key.strip():
|
||||
raise ValueError(needle_api_key)
|
||||
if not openai_api_key.strip():
|
||||
raise ValueError(openai_api_key)
|
||||
if not collection_id.strip():
|
||||
raise ValueError(collection_id_error)
|
||||
if not query.strip():
|
||||
raise ValueError(query_error)
|
||||
|
||||
# Handle output_type if it's somehow a list
|
||||
if isinstance(output_type, list):
|
||||
output_type = output_type[0]
|
||||
# Validate required inputs
|
||||
if not self.needle_api_key or not self.needle_api_key.strip():
|
||||
error_msg = "The Needle API key cannot be empty."
|
||||
raise ValueError(error_msg)
|
||||
if not self.collection_id or not self.collection_id.strip():
|
||||
error_msg = "The Collection ID cannot be empty."
|
||||
raise ValueError(error_msg)
|
||||
if not actual_query or not actual_query.strip():
|
||||
error_msg = "The query cannot be empty."
|
||||
raise ValueError(error_msg)
|
||||
|
||||
try:
|
||||
# Initialize the retriever
|
||||
# Initialize the retriever and get documents
|
||||
retriever = NeedleRetriever(
|
||||
needle_api_key=needle_api_key,
|
||||
collection_id=collection_id,
|
||||
needle_api_key=self.needle_api_key,
|
||||
collection_id=self.collection_id,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
# Create the chain
|
||||
llm = ChatOpenAI(
|
||||
temperature=0.7,
|
||||
api_key=openai_api_key,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(actual_query)
|
||||
|
||||
qa_chain = ConversationalRetrievalChain.from_llm(
|
||||
llm=llm,
|
||||
retriever=retriever,
|
||||
return_source_documents=True,
|
||||
)
|
||||
|
||||
# Process the query
|
||||
result = qa_chain({"question": query, "chat_history": []})
|
||||
|
||||
# Format content based on output type
|
||||
if str(output_type).lower().strip() == "chunks":
|
||||
# If chunks selected, include full context and answer
|
||||
docs = result["source_documents"]
|
||||
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
|
||||
text_content = f"Question: {query}\n\nContext:\n{context}\n\nAnswer: {result['answer']}"
|
||||
# Format the response
|
||||
if not docs:
|
||||
text_content = "No relevant documents found for the query."
|
||||
else:
|
||||
# If answer selected, only include the answer
|
||||
text_content = result["answer"]
|
||||
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
|
||||
text_content = f"Question: {actual_query}\n\nContext:\n{context}"
|
||||
|
||||
# Create a Message object following chat.py pattern
|
||||
# Return formatted message
|
||||
return Message(
|
||||
text=text_content,
|
||||
type="assistant",
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
additional_kwargs={
|
||||
"source_documents": [
|
||||
{"page_content": doc.page_content, "metadata": doc.metadata}
|
||||
for doc in result["source_documents"]
|
||||
]
|
||||
"source_documents": [{"page_content": doc.page_content, "metadata": doc.metadata} for doc in docs],
|
||||
"top_k_used": top_k,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -6,6 +6,7 @@ from langflow.api.v1.schemas import ResultDataResponse, VertexBuildResponse
|
|||
from langflow.schema.schema import OutputValue
|
||||
from langflow.serialization import serialize
|
||||
from langflow.services.tracing.schema import Log
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Use a smaller test size for hypothesis
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from langflow.components.outputs.chat import ChatOutput
|
|||
from langflow.components.tools.calculator import CalculatorToolComponent
|
||||
from langflow.graph import Graph
|
||||
from langflow.schema.data import Data
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from langflow.io import (
|
|||
TableInput,
|
||||
)
|
||||
from langflow.schema import Data
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import pytest
|
|||
from langflow.components.helpers.structured_output import StructuredOutputComponent
|
||||
from langflow.helpers.base_model import build_model_from_schema
|
||||
from langflow.inputs.inputs import TableInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tests.base import ComponentTestBaseWithoutClient
|
||||
from tests.unit.mock_language_model import MockLanguageModel
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from langflow.graph import Graph
|
|||
from langflow.graph.graph.constants import Finish
|
||||
from langflow.graph.state.model import create_state_model
|
||||
from langflow.template.field.base import UNDEFINED
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ from typing import Any
|
|||
|
||||
import pytest
|
||||
from langflow.helpers.base_model import build_model_from_schema
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TestBuildModelFromSchema:
|
||||
# Successfully creates a Pydantic model from a valid schema
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from langflow.inputs.inputs import (
|
|||
)
|
||||
from langflow.inputs.utils import instantiate_input
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MockLanguageModel(BaseLanguageModel, BaseModel):
|
||||
"""A mock language model for testing purposes."""
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from hypothesis import strategies as st
|
|||
from langchain_core.documents import Document
|
||||
from langflow.serialization.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH
|
||||
from langflow.serialization.serialization import serialize, serialize_or_str
|
||||
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from langflow.schema.data import Data
|
|||
from langflow.template import Input, Output
|
||||
from langflow.template.field.base import UNDEFINED
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import importlib
|
|||
|
||||
import pytest
|
||||
from langflow.utils.util import build_template_from_function, get_base_classes, get_default_factory
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +0,0 @@
|
|||
const SvgNeedleIcon = (props) => (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="1em"
|
||||
height="1em"
|
||||
viewBox="0 0 32 32"
|
||||
{...props}
|
||||
>
|
||||
<circle cx="16" cy="16" r="15" stroke="currentColor" strokeWidth="2" />
|
||||
<path
|
||||
fill="currentColor"
|
||||
d="M6.06661 23.4341L19.2183 4.1263C19.5277 3.67209 20.3458 3.99976 20.3225 4.54884C20.1123 9.51475 23.0448 12.1637 28.0237 12.3643C28.5874 12.387 28.8315 13.2879 28.3252 13.5368L6.70039 24.1643C6.23948 24.3909 5.7775 23.8586 6.06661 23.4341Z"
|
||||
/>
|
||||
<circle cx="24.5" cy="8.5" r="1.5" fill="currentColor" />
|
||||
</svg>
|
||||
);
|
||||
|
||||
export default SvgNeedleIcon;
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
import React, { forwardRef } from "react";
|
||||
import SvgNeedleIcon from "./NeedleIcon";
|
||||
import NeedleSvg from "./needle-icon.svg?react";
|
||||
|
||||
export const NeedleIcon = forwardRef<
|
||||
SVGSVGElement,
|
||||
React.PropsWithChildren<{}>
|
||||
>((props, ref) => {
|
||||
return <SvgNeedleIcon ref={ref} {...props} />;
|
||||
return <NeedleSvg ref={ref} {...props} />;
|
||||
});
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
Before Width: | Height: | Size: 503 B After Width: | Height: | Size: 45 KiB |
102
src/frontend/tests/core/integrations/Invoice Summarizer.spec.ts
Normal file
102
src/frontend/tests/core/integrations/Invoice Summarizer.spec.ts
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import { expect, test } from "@playwright/test";
|
||||
import * as dotenv from "dotenv";
|
||||
import path from "path";
|
||||
import { awaitBootstrapTest } from "../../utils/await-bootstrap-test";
|
||||
import { initialGPTsetup } from "../../utils/initialGPTsetup";
|
||||
|
||||
test(
|
||||
"Invoice Summarizer",
|
||||
{ tag: ["@release", "@starter-projects"] },
|
||||
async ({ page }) => {
|
||||
test.skip(
|
||||
!process?.env?.OPENAI_API_KEY ||
|
||||
!process?.env?.NEEDLE_API_KEY ||
|
||||
!process?.env?.NEEDLE_COLLECTION_ID,
|
||||
"OPENAI_API_KEY, NEEDLE_API_KEY, and NEEDLE_COLLECTION_ID required to run this test",
|
||||
);
|
||||
|
||||
if (!process.env.CI) {
|
||||
dotenv.config({ path: path.resolve(__dirname, "../../.env") });
|
||||
}
|
||||
|
||||
await awaitBootstrapTest(page);
|
||||
|
||||
await page.getByTestId("side_nav_options_all-templates").click();
|
||||
await page.getByRole("heading", { name: "Invoice Summarizer" }).click();
|
||||
|
||||
await initialGPTsetup(page);
|
||||
|
||||
// Configure Needle Search Knowledge Base
|
||||
await page
|
||||
.getByTestId("input_str_needle_api_key")
|
||||
.fill(process.env.NEEDLE_API_KEY || "");
|
||||
await page
|
||||
.getByTestId("input_str_collection_id")
|
||||
.fill(process.env.NEEDLE_COLLECTION_ID || "");
|
||||
|
||||
// Run the flow
|
||||
await page.getByTestId("button_run_chat output").click();
|
||||
|
||||
// Wait for the flow to build successfully
|
||||
await page.waitForSelector("text=built successfully", { timeout: 30000 });
|
||||
await page.getByText("built successfully").last().click({
|
||||
timeout: 30000,
|
||||
});
|
||||
|
||||
// Switch to Playground
|
||||
await page.getByText("Playground", { exact: true }).last().click();
|
||||
|
||||
// Wait for the playground to be ready
|
||||
const inputPlaceholder = page
|
||||
.getByPlaceholder(
|
||||
"No chat input variables found. Click to run your flow.",
|
||||
{ exact: true },
|
||||
)
|
||||
.last();
|
||||
|
||||
await expect(inputPlaceholder).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Verify initial response is displayed
|
||||
await expect(page.getByText("Search Results Summary")).toBeVisible({
|
||||
timeout: 15000,
|
||||
});
|
||||
|
||||
// Verify that specific invoice-related data appears in the results
|
||||
const keyTerms = ["expenses", "invoice", "vendor"];
|
||||
for (const term of keyTerms) {
|
||||
await expect(page.getByText(term, { exact: false })).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
}
|
||||
|
||||
// Test interaction with the flow by adding a specific query
|
||||
// Click the input field and type a query
|
||||
await inputPlaceholder.click();
|
||||
await page.keyboard.type("Summarize the total expenses from last month");
|
||||
await page.keyboard.press("Enter");
|
||||
|
||||
// Wait for response to the specific query
|
||||
await expect(page.getByText("Search Results Summary")).toBeVisible({
|
||||
timeout: 20000,
|
||||
});
|
||||
|
||||
// Verify that expense summary information appears in the response
|
||||
await expect(page.getByText("expenses", { exact: false })).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
// Test error handling - invalid query
|
||||
await page.keyboard.type("xyz123$%^NonSensicalQuery");
|
||||
await page.keyboard.press("Enter");
|
||||
|
||||
// Wait for the response, which should still show search results or appropriate message
|
||||
await expect(
|
||||
page
|
||||
.getByText("Search Results", { exact: false })
|
||||
.or(page.getByText("no relevant", { exact: false })),
|
||||
).toBeVisible({ timeout: 20000 });
|
||||
|
||||
// Cleanup - Reset the session
|
||||
await page.getByTestId("side_nav_options_all-templates").click();
|
||||
},
|
||||
);
|
||||
Loading…
Add table
Add a link
Reference in a new issue