feat: Needle Search Tool With Template (#6648)
* feat: Needle Search Tool With Templte * lint * lint * lint * lint * refactor: Use Langflow Agent instead of CrewAI Agent * techdebt: adjust Needle component to use tool mode and remove tool component * lint * lint * Update Invoice Summarizer.json * Update Invoice Summarizer.json * update to the component * refactor: Use Needle icon svg * make format * component updates * update with latest agent component * updated a missing connection when updating the agent component * update template --------- Co-authored-by: Edwin Jose <edwin.jose@datastax.com>
This commit is contained in:
parent
1249ea138c
commit
94bc8dbc7d
17 changed files with 2122 additions and 99 deletions
|
|
@ -1,16 +1,14 @@
|
|||
from langchain.chains import ConversationalRetrievalChain
|
||||
from langchain_community.retrievers.needle import NeedleRetriever
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from langflow.custom.custom_component.component import Component
|
||||
from langflow.io import DropdownInput, Output, SecretStrInput, StrInput
|
||||
from langflow.io import IntInput, MessageTextInput, Output, SecretStrInput
|
||||
from langflow.schema.message import Message
|
||||
from langflow.utils.constants import MESSAGE_SENDER_AI
|
||||
|
||||
|
||||
class NeedleComponent(Component):
|
||||
display_name = "Needle Retriever"
|
||||
description = "A retriever that uses the Needle API to search collections and generates responses using OpenAI."
|
||||
description = "A retriever that uses the Needle API to search collections."
|
||||
documentation = "https://docs.needle-ai.com"
|
||||
icon = "Needle"
|
||||
name = "needle"
|
||||
|
|
@ -22,30 +20,24 @@ class NeedleComponent(Component):
|
|||
info="Your Needle API key.",
|
||||
required=True,
|
||||
),
|
||||
SecretStrInput(
|
||||
name="openai_api_key",
|
||||
display_name="OpenAI API Key",
|
||||
info="Your OpenAI API key.",
|
||||
required=True,
|
||||
),
|
||||
StrInput(
|
||||
MessageTextInput(
|
||||
name="collection_id",
|
||||
display_name="Collection ID",
|
||||
info="The ID of the Needle collection.",
|
||||
required=True,
|
||||
),
|
||||
StrInput(
|
||||
MessageTextInput(
|
||||
name="query",
|
||||
display_name="User Query",
|
||||
info="Enter your question here.",
|
||||
info="Enter your question here. In tool mode, you can also specify top_k parameter (min: 20).",
|
||||
required=True,
|
||||
tool_mode=True,
|
||||
),
|
||||
DropdownInput(
|
||||
name="output_type",
|
||||
display_name="Output Type",
|
||||
info="Return either the answer or the chunks.",
|
||||
options=["answer", "chunks"],
|
||||
value="answer",
|
||||
IntInput(
|
||||
name="top_k",
|
||||
display_name="Top K Results",
|
||||
info="Number of search results to return (min: 20).",
|
||||
value=20,
|
||||
required=True,
|
||||
),
|
||||
]
|
||||
|
|
@ -53,74 +45,57 @@ class NeedleComponent(Component):
|
|||
outputs = [Output(display_name="Result", name="result", type_="Message", method="run")]
|
||||
|
||||
def run(self) -> Message:
|
||||
needle_api_key = self.needle_api_key or ""
|
||||
openai_api_key = self.openai_api_key or ""
|
||||
collection_id = self.collection_id
|
||||
query = self.query
|
||||
output_type = self.output_type
|
||||
# Extract query and top_k
|
||||
query_input = self.query
|
||||
actual_query = query_input.get("query", "") if isinstance(query_input, dict) else query_input
|
||||
|
||||
# Define error messages
|
||||
needle_api_key = "The Needle API key cannot be empty."
|
||||
openai_api_key = "The OpenAI API key cannot be empty."
|
||||
collection_id_error = "The Collection ID cannot be empty."
|
||||
query_error = "The query cannot be empty."
|
||||
# Parse top_k from tool input or use default, always enforcing minimum of 20
|
||||
try:
|
||||
if isinstance(query_input, dict) and "top_k" in query_input:
|
||||
agent_top_k = query_input.get("top_k")
|
||||
# Check if agent_top_k is not None before converting to int
|
||||
top_k = max(20, int(agent_top_k)) if agent_top_k is not None else max(20, self.top_k)
|
||||
else:
|
||||
top_k = max(20, self.top_k)
|
||||
except (ValueError, TypeError):
|
||||
top_k = max(20, self.top_k)
|
||||
|
||||
# Validate inputs
|
||||
if not needle_api_key.strip():
|
||||
raise ValueError(needle_api_key)
|
||||
if not openai_api_key.strip():
|
||||
raise ValueError(openai_api_key)
|
||||
if not collection_id.strip():
|
||||
raise ValueError(collection_id_error)
|
||||
if not query.strip():
|
||||
raise ValueError(query_error)
|
||||
|
||||
# Handle output_type if it's somehow a list
|
||||
if isinstance(output_type, list):
|
||||
output_type = output_type[0]
|
||||
# Validate required inputs
|
||||
if not self.needle_api_key or not self.needle_api_key.strip():
|
||||
error_msg = "The Needle API key cannot be empty."
|
||||
raise ValueError(error_msg)
|
||||
if not self.collection_id or not self.collection_id.strip():
|
||||
error_msg = "The Collection ID cannot be empty."
|
||||
raise ValueError(error_msg)
|
||||
if not actual_query or not actual_query.strip():
|
||||
error_msg = "The query cannot be empty."
|
||||
raise ValueError(error_msg)
|
||||
|
||||
try:
|
||||
# Initialize the retriever
|
||||
# Initialize the retriever and get documents
|
||||
retriever = NeedleRetriever(
|
||||
needle_api_key=needle_api_key,
|
||||
collection_id=collection_id,
|
||||
needle_api_key=self.needle_api_key,
|
||||
collection_id=self.collection_id,
|
||||
top_k=top_k,
|
||||
)
|
||||
|
||||
# Create the chain
|
||||
llm = ChatOpenAI(
|
||||
temperature=0.7,
|
||||
api_key=openai_api_key,
|
||||
)
|
||||
docs = retriever.get_relevant_documents(actual_query)
|
||||
|
||||
qa_chain = ConversationalRetrievalChain.from_llm(
|
||||
llm=llm,
|
||||
retriever=retriever,
|
||||
return_source_documents=True,
|
||||
)
|
||||
|
||||
# Process the query
|
||||
result = qa_chain({"question": query, "chat_history": []})
|
||||
|
||||
# Format content based on output type
|
||||
if str(output_type).lower().strip() == "chunks":
|
||||
# If chunks selected, include full context and answer
|
||||
docs = result["source_documents"]
|
||||
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
|
||||
text_content = f"Question: {query}\n\nContext:\n{context}\n\nAnswer: {result['answer']}"
|
||||
# Format the response
|
||||
if not docs:
|
||||
text_content = "No relevant documents found for the query."
|
||||
else:
|
||||
# If answer selected, only include the answer
|
||||
text_content = result["answer"]
|
||||
context = "\n\n".join([f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs)])
|
||||
text_content = f"Question: {actual_query}\n\nContext:\n{context}"
|
||||
|
||||
# Create a Message object following chat.py pattern
|
||||
# Return formatted message
|
||||
return Message(
|
||||
text=text_content,
|
||||
type="assistant",
|
||||
sender=MESSAGE_SENDER_AI,
|
||||
additional_kwargs={
|
||||
"source_documents": [
|
||||
{"page_content": doc.page_content, "metadata": doc.metadata}
|
||||
for doc in result["source_documents"]
|
||||
]
|
||||
"source_documents": [{"page_content": doc.page_content, "metadata": doc.metadata} for doc in docs],
|
||||
"top_k_used": top_k,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -6,6 +6,7 @@ from langflow.api.v1.schemas import ResultDataResponse, VertexBuildResponse
|
|||
from langflow.schema.schema import OutputValue
|
||||
from langflow.serialization import serialize
|
||||
from langflow.services.tracing.schema import Log
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Use a smaller test size for hypothesis
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from langflow.components.outputs.chat import ChatOutput
|
|||
from langflow.components.tools.calculator import CalculatorToolComponent
|
||||
from langflow.graph import Graph
|
||||
from langflow.schema.data import Data
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ from langflow.io import (
|
|||
TableInput,
|
||||
)
|
||||
from langflow.schema import Data
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import pytest
|
|||
from langflow.components.helpers.structured_output import StructuredOutputComponent
|
||||
from langflow.helpers.base_model import build_model_from_schema
|
||||
from langflow.inputs.inputs import TableInput
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pydantic import BaseModel
|
||||
from tests.base import ComponentTestBaseWithoutClient
|
||||
from tests.unit.mock_language_model import MockLanguageModel
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from langflow.graph import Graph
|
|||
from langflow.graph.graph.constants import Finish
|
||||
from langflow.graph.state.model import create_state_model
|
||||
from langflow.template.field.base import UNDEFINED
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@ from typing import Any
|
|||
|
||||
import pytest
|
||||
from langflow.helpers.base_model import build_model_from_schema
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class TestBuildModelFromSchema:
|
||||
# Successfully creates a Pydantic model from a valid schema
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from langflow.inputs.inputs import (
|
|||
)
|
||||
from langflow.inputs.utils import instantiate_input
|
||||
from langflow.schema.message import Message
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MockLanguageModel(BaseLanguageModel, BaseModel):
|
||||
"""A mock language model for testing purposes."""
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from hypothesis import strategies as st
|
|||
from langchain_core.documents import Document
|
||||
from langflow.serialization.constants import MAX_ITEMS_LENGTH, MAX_TEXT_LENGTH
|
||||
from langflow.serialization.serialization import serialize, serialize_or_str
|
||||
|
||||
from pydantic import BaseModel as PydanticBaseModel
|
||||
from pydantic.v1 import BaseModel as PydanticV1BaseModel
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from langflow.schema.data import Data
|
|||
from langflow.template import Input, Output
|
||||
from langflow.template.field.base import UNDEFINED
|
||||
from langflow.type_extraction.type_extraction import post_process_type
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import importlib
|
|||
|
||||
import pytest
|
||||
from langflow.utils.util import build_template_from_function, get_base_classes, get_default_factory
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue