feat: Needle Search Tool With Template (#6648)

* feat: Needle Search Tool With Template

* lint

* lint

* lint

* lint

* refactor: Use Langflow Agent instead of CrewAI Agent

* techdebt: adjust Needle component to use tool mode and remove tool component

* lint

* lint

* Update Invoice Summarizer.json

* Update Invoice Summarizer.json

* update to the component

* refactor: Use Needle icon svg

* make format

* component updates

* update with latest agent component

* updated a missing connection when updating the agent component

* update template

---------

Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
Jan Heimes 2025-03-14 10:37:54 +01:00 committed by GitHub
commit 94bc8dbc7d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 2122 additions and 99 deletions

View file

@ -1,16 +1,14 @@
from langchain.chains import ConversationalRetrievalChain
from langchain_community.retrievers.needle import NeedleRetriever
from langchain_openai import ChatOpenAI
from langflow.custom.custom_component.component import Component
from langflow.io import DropdownInput, Output, SecretStrInput, StrInput
from langflow.io import IntInput, MessageTextInput, Output, SecretStrInput
from langflow.schema.message import Message
from langflow.utils.constants import MESSAGE_SENDER_AI
class NeedleComponent(Component):
display_name = "Needle Retriever"
description = "A retriever that uses the Needle API to search collections and generates responses using OpenAI."
description = "A retriever that uses the Needle API to search collections."
documentation = "https://docs.needle-ai.com"
icon = "Needle"
name = "needle"
@ -22,30 +20,24 @@ class NeedleComponent(Component):
info="Your Needle API key.",
required=True,
),
SecretStrInput(
name="openai_api_key",
display_name="OpenAI API Key",
info="Your OpenAI API key.",
required=True,
),
StrInput(
MessageTextInput(
name="collection_id",
display_name="Collection ID",
info="The ID of the Needle collection.",
required=True,
),
StrInput(
MessageTextInput(
name="query",
display_name="User Query",
info="Enter your question here.",
info="Enter your question here. In tool mode, you can also specify top_k parameter (min: 20).",
required=True,
tool_mode=True,
),
DropdownInput(
name="output_type",
display_name="Output Type",
info="Return either the answer or the chunks.",
options=["answer", "chunks"],
value="answer",
IntInput(
name="top_k",
display_name="Top K Results",
info="Number of search results to return (min: 20).",
value=20,
required=True,
),
]
@ -53,74 +45,57 @@ class NeedleComponent(Component):
outputs = [Output(display_name="Result", name="result", type_="Message", method="run")]
def run(self) -> Message:
needle_api_key = self.needle_api_key or ""
openai_api_key = self.openai_api_key or ""
collection_id = self.collection_id
query = self.query
output_type = self.output_type
# Extract query and top_k
query_input = self.query
actual_query = query_input.get("query", "") if isinstance(query_input, dict) else query_input
# Define error messages
needle_api_key = "The Needle API key cannot be empty."
openai_api_key = "The OpenAI API key cannot be empty."
collection_id_error = "The Collection ID cannot be empty."
query_error = "The query cannot be empty."
# Parse top_k from tool input or use default, always enforcing minimum of 20
try:
if isinstance(query_input, dict) and "top_k" in query_input:
agent_top_k = query_input.get("top_k")
# Check if agent_top_k is not None before converting to int
top_k = max(20, int(agent_top_k)) if agent_top_k is not None else max(20, self.top_k)
else:
top_k = max(20, self.top_k)
except (ValueError, TypeError):
top_k = max(20, self.top_k)
# Validate inputs
if not needle_api_key.strip():
raise ValueError(needle_api_key)
if not openai_api_key.strip():
raise ValueError(openai_api_key)
if not collection_id.strip():
raise ValueError(collection_id_error)
if not query.strip():
raise ValueError(query_error)
# Handle output_type if it's somehow a list
if isinstance(output_type, list):
output_type = output_type[0]
# Validate required inputs
if not self.needle_api_key or not self.needle_api_key.strip():
error_msg = "The Needle API key cannot be empty."
raise ValueError(error_msg)
if not self.collection_id or not self.collection_id.strip():
error_msg = "The Collection ID cannot be empty."
raise ValueError(error_msg)
if not actual_query or not actual_query.strip():
error_msg = "The query cannot be empty."
raise ValueError(error_msg)
try:
# Initialize the retriever
# Initialize the retriever and get documents
retriever = NeedleRetriever(
needle_api_key=needle_api_key,
collection_id=collection_id,
needle_api_key=self.needle_api_key,
collection_id=self.collection_id,
top_k=top_k,
)
# Create the chain
llm = ChatOpenAI(
temperature=0.7,
api_key=openai_api_key,
)
docs = retriever.get_relevant_documents(actual_query)
qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=retriever,
return_source_documents=True,
)
# Process the query
result = qa_chain({"question": query, "chat_history": []})
# Format content based on output type
if str(output_type).lower().strip() == "chunks":
# If chunks selected, include full context and answer
docs = result["source_documents"]
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
text_content = f"Question: {query}\n\nContext:\n{context}\n\nAnswer: {result['answer']}"
# Format the response
if not docs:
text_content = "No relevant documents found for the query."
else:
# If answer selected, only include the answer
text_content = result["answer"]
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
text_content = f"Question: {actual_query}\n\nContext:\n{context}"
# Create a Message object following chat.py pattern
# Return formatted message
return Message(
text=text_content,
type="assistant",
sender=MESSAGE_SENDER_AI,
additional_kwargs={
"source_documents": [
{"page_content": doc.page_content, "metadata": doc.metadata}
for doc in result["source_documents"]
]
"source_documents": [{"page_content": doc.page_content, "metadata": doc.metadata} for doc in docs],
"top_k_used": top_k,
},
)

File diff suppressed because one or more lines are too long

View file

@ -6,6 +6,7 @@ from langflow.api.v1.schemas import ResultDataResponse, VertexBuildResponse
from langflow.schema.schema import OutputValue
from langflow.serialization import serialize
from langflow.services.tracing.schema import Log
from pydantic import BaseModel
# Use a smaller test size for hypothesis

View file

@ -8,6 +8,7 @@ from langflow.components.outputs.chat import ChatOutput
from langflow.components.tools.calculator import CalculatorToolComponent
from langflow.graph import Graph
from langflow.schema.data import Data
from pydantic import BaseModel

View file

@ -21,6 +21,7 @@ from langflow.io import (
TableInput,
)
from langflow.schema import Data
from pydantic import BaseModel

View file

@ -5,8 +5,8 @@ import pytest
from langflow.components.helpers.structured_output import StructuredOutputComponent
from langflow.helpers.base_model import build_model_from_schema
from langflow.inputs.inputs import TableInput
from pydantic import BaseModel
from pydantic import BaseModel
from tests.base import ComponentTestBaseWithoutClient
from tests.unit.mock_language_model import MockLanguageModel

View file

@ -5,6 +5,7 @@ from langflow.graph import Graph
from langflow.graph.graph.constants import Finish
from langflow.graph.state.model import create_state_model
from langflow.template.field.base import UNDEFINED
from pydantic import Field

View file

@ -4,9 +4,10 @@ from typing import Any
import pytest
from langflow.helpers.base_model import build_model_from_schema
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from pydantic import BaseModel
class TestBuildModelFromSchema:
# Successfully creates a Pydantic model from a valid schema

View file

@ -23,6 +23,7 @@ from langflow.inputs.inputs import (
)
from langflow.inputs.utils import instantiate_input
from langflow.schema.message import Message
from pydantic import ValidationError

View file

@ -1,9 +1,10 @@
from unittest.mock import MagicMock
from langchain_core.language_models import BaseLanguageModel
from pydantic import BaseModel, Field
from typing_extensions import override
from pydantic import BaseModel, Field
class MockLanguageModel(BaseLanguageModel, BaseModel):
"""A mock language model for testing purposes."""

View file

@ -9,6 +9,7 @@ from hypothesis import strategies as st
from langchain_core.documents import Document
from langflow.serialization.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH
from langflow.serialization.serialization import serialize, serialize_or_str
from pydantic import BaseModel as PydanticBaseModel
from pydantic.v1 import BaseModel as PydanticV1BaseModel

View file

@ -7,6 +7,7 @@ from langflow.schema.data import Data
from langflow.template import Input, Output
from langflow.template.field.base import UNDEFINED
from langflow.type_extraction.type_extraction import post_process_type
from pydantic import ValidationError

View file

@ -2,6 +2,7 @@ import importlib
import pytest
from langflow.utils.util import build_template_from_function, get_base_classes, get_default_factory
from pydantic import BaseModel