bugfix: Properly output a Tool from Glean Search (#3851)

* bugfix: Properly output a Tool from Glean Search

* [autofix.ci] apply automated fixes

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Eric Hare 2024-09-25 11:33:16 -07:00 committed by GitHub
commit cf5b5951c4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 117 additions and 83 deletions

View file

@ -1,5 +1,4 @@
import datetime
from typing import Dict, List
from langflow.custom import Component
from langflow.io import DataInput, Output
@ -49,7 +48,7 @@ class AssemblyAITranscriptionParser(Component):
self.status = error_message
return Data(data={"error": error_message})
def parse_with_speakers(self, utterances: List[Dict]) -> str:
def parse_with_speakers(self, utterances: list[dict]) -> str:
parsed_result = []
for utterance in utterances:
speaker = utterance["speaker"]

View file

@ -1,5 +1,3 @@
from typing import List
import assemblyai as aai
from langflow.custom import Component
@ -48,7 +46,7 @@ class AssemblyAIListTranscripts(Component):
Output(display_name="Transcript List", name="transcript_list", method="list_transcripts"),
]
def list_transcripts(self) -> List[Data]:
def list_transcripts(self) -> list[Data]:
aai.settings.api_key = self.api_key
params = aai.ListTranscriptParameters()

View file

@ -2,8 +2,6 @@ import re
from langchain_core.prompts import HumanMessagePromptTemplate
from langchain_core.prompts import HumanMessagePromptTemplate
from langflow.custom import Component
from langflow.inputs import DefaultPromptField, SecretStrInput, StrInput
from langflow.io import Output

View file

@ -3,7 +3,9 @@ from typing import Any
from urllib.parse import urljoin
import httpx
from langchain.tools import StructuredTool
from langchain_core.pydantic_v1 import BaseModel
from pydantic.v1 import Field
from langflow.base.langchain_utilities.model import LCToolComponent
from langflow.field_typing import Tool
@ -28,90 +30,127 @@ class GleanSearchAPIComponent(LCToolComponent):
NestedDictInput(name="request_options", display_name="Request Options", required=False),
]
class GleanAPIWrapper(BaseModel):
    """
    Wrapper around Glean API.

    Sends search queries to the ``search`` endpoint below ``glean_api_url``
    and returns the entries found under the ``"results"`` key of the JSON
    response.
    """

    glean_api_url: str
    glean_access_token: str
    act_as: str = "langflow-component@datastax.com"  # TODO: Detect this

    def _prepare_request(
        self,
        query: str,
        page_size: int = 10,
        request_options: dict[str, Any] | None = None,
    ) -> dict:
        """Build the URL, headers and JSON payload for one search request.

        Args:
            query: Search text, sent as ``"query"``.
            page_size: Sent as ``"pageSize"``.
            request_options: Sent unchanged as ``"requestOptions"``.

        Returns:
            A dict with the keys ``"url"``, ``"headers"`` and ``"payload"``.
        """
        # Ensure there's a trailing slash: without it urljoin() would replace
        # the last path segment of the base URL instead of appending "search".
        url = self.glean_api_url
        if not url.endswith("/"):
            url += "/"

        return {
            "url": urljoin(url, "search"),
            "headers": {
                "Authorization": f"Bearer {self.glean_access_token}",
                "X-Scio-ActAs": self.act_as,
            },
            "payload": {
                "query": query,
                "pageSize": page_size,
                "requestOptions": request_options,
            },
        }

    def results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
        """Return the raw search results for ``query``.

        Raises:
            AssertionError: If the search produced no results.
        """
        results = self._search_api_results(query, **kwargs)

        # Truthiness covers both an empty list and a missing/null result set.
        if not results:
            raise AssertionError("No good Glean Search Result was found")

        return results

    def run(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
        """Search and make sure every titled result has ``snippets[0]["text"]``.

        Results without a ``"title"`` are returned untouched. For titled
        results, a missing, null or empty ``"snippets"`` value is replaced by
        a single snippet built from the title, and a first snippet lacking
        ``"text"`` gets the title as its text. Result dicts are updated in
        place.
        """
        results = self.results(query, **kwargs)

        processed_results = []
        for result in results:
            if "title" in result:
                title = result["title"]
                # `or` (rather than a .get() default) also handles a present
                # but empty/null "snippets", which would otherwise make the
                # [0] lookup below raise.
                snippets = result.get("snippets") or [{"snippet": {"text": title}}]
                result["snippets"] = snippets
                if "text" not in snippets[0]:
                    snippets[0]["text"] = title

            processed_results.append(result)

        return processed_results

    def _search_api_results(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
        """POST the search request and return the ``"results"`` list.

        A non-success HTTP status is raised by ``response.raise_for_status()``.
        A missing or null ``"results"`` entry yields an empty list.
        """
        request_details = self._prepare_request(query, **kwargs)

        response = httpx.post(
            request_details["url"],
            json=request_details["payload"],
            headers=request_details["headers"],
        )

        response.raise_for_status()
        response_json = response.json()

        return response_json.get("results") or []

    @staticmethod
    def _result_as_string(result: dict) -> str:
        """Serialize ``result`` as JSON indented by four spaces."""
        return json.dumps(result, indent=4)
class GleanSearchAPISchema(BaseModel):
    # Named arguments accepted by the Glean search tool: the query text, a
    # result-count limit (default 10) and optional request options (default {}).
    query: str = Field(default=..., description="The search query")
    page_size: int = Field(default=10, description="Maximum number of results to return")
    request_options: dict[str, Any] | None = Field(
        default_factory=dict,
        description="Request Options",
    )
# build_tool: builds and returns the tool object that exposes Glean search.
# NOTE(review): this listing is a rendered diff whose +/- markers and
# indentation were lost, so two versions of the code are interleaved below.
# The "earlier"/"later" labels are inferred from the call signatures only
# (zero-argument vs. two-argument _build_wrapper) -- confirm against the
# repository before relying on them.
def build_tool(self) -> Tool:
# Presumably the earlier body: zero-argument wrapper, plain Tool around wrapper.run.
wrapper = self._build_wrapper()
return Tool(name="glean_search_api", description="Search with the Glean API", func=wrapper.run)
# Presumably the earlier run_model header and opening statements, which called
# wrapper.results(...) directly with the component's query settings.
def run_model(self) -> Data | list[Data]:
wrapper = self._build_wrapper()
results = wrapper.results(
query=self.query,
page_size=self.page_size,
request_options=self.request_options,
# Presumably the later build_tool body starts here: the wrapper is created
# from the component's API URL and access token.
wrapper = self._build_wrapper(
glean_api_url=self.glean_api_url,
glean_access_token=self.glean_access_token,
)
# Presumably an earlier run_model line: reads the "results" list out of a dict response.
list_results = results.get("results", [])
# StructuredTool built from wrapper.run, with GleanSearchAPISchema supplied
# as args_schema.
tool = StructuredTool.from_function(
name="glean_search_api",
description="Search Glean for relevant results.",
func=wrapper.run,
args_schema=self.GleanSearchAPISchema,
)
# Record a status string on the component, then hand the tool back.
self.status = "Glean Search API Tool for Langchain"
return tool
# run_model: runs the tool from build_tool() with the component's query,
# page size and request options, and wraps each result in a Data object.
# NOTE(review): this listing is a rendered diff whose +/- markers and
# indentation were lost; two loop variants appear below and only one of them
# belongs to this version -- confirm against the repository.
def run_model(self) -> list[Data]:
tool = self.build_tool()
# The tool input is one dict keyed by the argument names.
results = tool.run(
{
"query": self.query,
"page_size": self.page_size,
"request_options": self.request_options,
}
)
# Build the data
data = []
# Presumably the earlier loop: iterates list_results, which is not assigned
# anywhere in this function.
for result in list_results:
data.append(Data(data=result))
# Presumably the later loop: also sets Data.text from the first snippet.
# NOTE(review): result["snippets"][0]["text"] raises if a result has no
# "snippets" entry, an empty list, or a first snippet without "text" --
# verify that every result passed in here is guaranteed to have it.
for result in results:
data.append(Data(data=result, text=result["snippets"][0]["text"]))
# Expose the built list as the component status and return it.
self.status = data
return data
def _build_wrapper(self):
    """Define the Glean wrapper model locally and return an instance of it.

    The instance is configured with this component's ``glean_api_url`` and
    ``glean_access_token``.
    """

    class GleanAPIWrapper(BaseModel):
        """
        Wrapper around Glean API.
        """

        glean_api_url: str
        glean_access_token: str
        act_as: str = "langflow-component@datastax.com"  # TODO: Detect this

        def _prepare_request(
            self,
            query: str,
            page_size: int = 10,
            request_options: dict[str, Any] | None = None,
        ) -> dict:
            """Assemble the URL, headers and JSON payload for one search call."""
            # Ensure there's a trailing slash
            base_url = self.glean_api_url if self.glean_api_url.endswith("/") else self.glean_api_url + "/"
            endpoint = urljoin(base_url, "search")
            auth_headers = {
                "Authorization": f"Bearer {self.glean_access_token}",
                "X-Scio-ActAs": self.act_as,
            }
            body = {
                "query": query,
                "pageSize": page_size,
                "requestOptions": request_options,
            }
            return {"url": endpoint, "headers": auth_headers, "payload": body}

        def run(self, query: str, **kwargs: Any) -> str:
            """Search and return the response rendered as a JSON string."""
            return self._result_as_string(self.results(query, **kwargs))

        def results(self, query: str, **kwargs: Any) -> dict:
            """Search and return the decoded JSON response."""
            return self._search_api_results(query, **kwargs)

        def _search_api_results(self, query: str, **kwargs: Any) -> dict[str, Any]:
            """POST the prepared request and return the decoded JSON body."""
            prepared = self._prepare_request(query, **kwargs)
            reply = httpx.post(
                prepared["url"],
                json=prepared["payload"],
                headers=prepared["headers"],
            )
            reply.raise_for_status()
            return reply.json()

        @staticmethod
        def _result_as_string(result: dict) -> str:
            """Dump ``result`` as JSON indented by four spaces."""
            return json.dumps(result, indent=4)

    return GleanAPIWrapper(glean_api_url=self.glean_api_url, glean_access_token=self.glean_access_token)
def _build_wrapper(
self,
glean_api_url: str,
glean_access_token: str,
):
return self.GleanAPIWrapper(
glean_api_url=glean_api_url,
glean_access_token=glean_access_token,
)