From cf5b5951c4d5b84b30271444044e7337eb86cba6 Mon Sep 17 00:00:00 2001 From: Eric Hare Date: Wed, 25 Sep 2024 11:33:16 -0700 Subject: [PATCH] bugfix: Properly output a Tool from Glean Search (#3851) * bugfix: Properly output a Tool from Glean Search * [autofix.ci] apply automated fixes * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../AssemblyAIFormatTranscript.py | 3 +- .../AssemblyAIListTranscripts.py | 4 +- .../components/prompts/LangChainHubPrompt.py | 2 - .../components/tools/GleanSearchAPI.py | 191 +++++++++++------- 4 files changed, 117 insertions(+), 83 deletions(-) diff --git a/src/backend/base/langflow/components/documentloaders/AssemblyAIFormatTranscript.py b/src/backend/base/langflow/components/documentloaders/AssemblyAIFormatTranscript.py index 605cab6a5..0b28e73ba 100644 --- a/src/backend/base/langflow/components/documentloaders/AssemblyAIFormatTranscript.py +++ b/src/backend/base/langflow/components/documentloaders/AssemblyAIFormatTranscript.py @@ -1,5 +1,4 @@ import datetime -from typing import Dict, List from langflow.custom import Component from langflow.io import DataInput, Output @@ -49,7 +48,7 @@ class AssemblyAITranscriptionParser(Component): self.status = error_message return Data(data={"error": error_message}) - def parse_with_speakers(self, utterances: List[Dict]) -> str: + def parse_with_speakers(self, utterances: list[dict]) -> str: parsed_result = [] for utterance in utterances: speaker = utterance["speaker"] diff --git a/src/backend/base/langflow/components/documentloaders/AssemblyAIListTranscripts.py b/src/backend/base/langflow/components/documentloaders/AssemblyAIListTranscripts.py index ae4154f74..28f241671 100644 --- a/src/backend/base/langflow/components/documentloaders/AssemblyAIListTranscripts.py +++ b/src/backend/base/langflow/components/documentloaders/AssemblyAIListTranscripts.py @@ -1,5 +1,3 @@ -from typing import List - import assemblyai as aai from langflow.custom import Component @@ -48,7 +46,7 @@ class AssemblyAIListTranscripts(Component): Output(display_name="Transcript List", name="transcript_list", method="list_transcripts"), ] - def list_transcripts(self) -> List[Data]: + def list_transcripts(self) -> list[Data]: aai.settings.api_key = self.api_key params = aai.ListTranscriptParameters() diff --git a/src/backend/base/langflow/components/prompts/LangChainHubPrompt.py b/src/backend/base/langflow/components/prompts/LangChainHubPrompt.py index 0609d06bb..09ed7ac0f 100644 --- a/src/backend/base/langflow/components/prompts/LangChainHubPrompt.py +++ b/src/backend/base/langflow/components/prompts/LangChainHubPrompt.py @@ -2,8 +2,6 @@ import re from langchain_core.prompts import HumanMessagePromptTemplate -from langchain_core.prompts import HumanMessagePromptTemplate - from langflow.custom import Component from langflow.inputs import DefaultPromptField, SecretStrInput, StrInput from langflow.io import Output diff --git a/src/backend/base/langflow/components/tools/GleanSearchAPI.py b/src/backend/base/langflow/components/tools/GleanSearchAPI.py index 00db90499..a8c2e405d 100644 --- a/src/backend/base/langflow/components/tools/GleanSearchAPI.py +++ b/src/backend/base/langflow/components/tools/GleanSearchAPI.py @@ -3,7 +3,9 @@ from typing import Any from urllib.parse import urljoin import httpx +from langchain.tools import StructuredTool from langchain_core.pydantic_v1 import BaseModel +from pydantic.v1 import Field from langflow.base.langchain_utilities.model import LCToolComponent from langflow.field_typing import Tool @@ -28,90 +30,127 @@ class GleanSearchAPIComponent(LCToolComponent): NestedDictInput(name="request_options", display_name="Request Options", required=False), ] + class GleanAPIWrapper(BaseModel): + """ + Wrapper around Glean API. + """ + + glean_api_url: str + glean_access_token: str + act_as: str = "langflow-component@datastax.com" # TODO: Detect this + + def _prepare_request( + self, + query: str, + page_size: int = 10, + request_options: dict[str, Any] | None = None, + ) -> dict: + # Ensure there's a trailing slash + url = self.glean_api_url + if not url.endswith("/"): + url += "/" + + return { + "url": urljoin(url, "search"), + "headers": { + "Authorization": f"Bearer {self.glean_access_token}", + "X-Scio-ActAs": self.act_as, + }, + "payload": { + "query": query, + "pageSize": page_size, + "requestOptions": request_options, + }, + } + + def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: + results = self._search_api_results(query, **kwargs) + + if len(results) == 0: + raise AssertionError("No good Glean Search Result was found") + + return results + + def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: + results = self.results(query, **kwargs) + + processed_results = [] + for result in results: + if "title" in result: + result["snippets"] = result.get("snippets", [{"snippet": {"text": result["title"]}}]) + if "text" not in result["snippets"][0]: + result["snippets"][0]["text"] = result["title"] + + processed_results.append(result) + + return processed_results + + def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: + request_details = self._prepare_request(query, **kwargs) + + response = httpx.post( + request_details["url"], + json=request_details["payload"], + headers=request_details["headers"], + ) + + response.raise_for_status() + response_json = response.json() + + return response_json.get("results", []) + + @staticmethod + def _result_as_string(result: dict) -> str: + return json.dumps(result, indent=4) + + class GleanSearchAPISchema(BaseModel): + query: str = Field(..., description="The search query") + page_size: int = Field(10, description="Maximum number of results to return") + request_options: dict[str, Any] | None = Field(default_factory=dict, description="Request Options") + def build_tool(self) -> Tool: - wrapper = self._build_wrapper() - - return Tool(name="glean_search_api", description="Search with the Glean API", func=wrapper.run) - - def run_model(self) -> Data | list[Data]: - wrapper = self._build_wrapper() - - results = wrapper.results( - query=self.query, - page_size=self.page_size, - request_options=self.request_options, + wrapper = self._build_wrapper( + glean_api_url=self.glean_api_url, + glean_access_token=self.glean_access_token, ) - list_results = results.get("results", []) + tool = StructuredTool.from_function( + name="glean_search_api", + description="Search Glean for relevant results.", + func=wrapper.run, + args_schema=self.GleanSearchAPISchema, + ) + + self.status = "Glean Search API Tool for Langchain" + + return tool + + def run_model(self) -> list[Data]: + tool = self.build_tool() + + results = tool.run( + { + "query": self.query, + "page_size": self.page_size, + "request_options": self.request_options, + } + ) # Build the data data = [] - for result in list_results: - data.append(Data(data=result)) + for result in results: + data.append(Data(data=result, text=result["snippets"][0]["text"])) self.status = data return data - def _build_wrapper(self): - class GleanAPIWrapper(BaseModel): - """ - Wrapper around Glean API. - """ - - glean_api_url: str - glean_access_token: str - act_as: str = "langflow-component@datastax.com" # TODO: Detect this - - def _prepare_request( - self, - query: str, - page_size: int = 10, - request_options: dict[str, Any] | None = None, - ) -> dict: - # Ensure there's a trailing slash - url = self.glean_api_url - if not url.endswith("/"): - url += "/" - - return { - "url": urljoin(url, "search"), - "headers": { - "Authorization": f"Bearer {self.glean_access_token}", - "X-Scio-ActAs": self.act_as, - }, - "payload": { - "query": query, - "pageSize": page_size, - "requestOptions": request_options, - }, - } - - def run(self, query: str, **kwargs: Any) -> str: - results = self.results(query, **kwargs) - - return self._result_as_string(results) - - def results(self, query: str, **kwargs: Any) -> dict: - results = self._search_api_results(query, **kwargs) - - return results - - def _search_api_results(self, query: str, **kwargs: Any) -> dict[str, Any]: - request_details = self._prepare_request(query, **kwargs) - - response = httpx.post( - request_details["url"], - json=request_details["payload"], - headers=request_details["headers"], - ) - - response.raise_for_status() - - return response.json() - - @staticmethod - def _result_as_string(result: dict) -> str: - return json.dumps(result, indent=4) - - return GleanAPIWrapper(glean_api_url=self.glean_api_url, glean_access_token=self.glean_access_token) + def _build_wrapper( + self, + glean_api_url: str, + glean_access_token: str, + ): + return self.GleanAPIWrapper( + glean_api_url=glean_api_url, + glean_access_token=glean_access_token, + )