fix: enhancements for the AssemblyAI component (#3934)

Enhancements for AssemblyAI component

Co-authored-by: Patrick Loeber <98830383+ploeber@users.noreply.github.com>
This commit is contained in:
Patrick Loeber 2024-10-01 16:28:18 +02:00 committed by GitHub
commit d19c16462b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 137 additions and 116 deletions

View file

@ -1,65 +0,0 @@
import datetime
from langflow.custom import Component
from langflow.io import DataInput, Output
from langflow.schema import Data
class AssemblyAITranscriptionParser(Component):
    # Component that turns an AssemblyAI transcription result into plain text,
    # or into speaker-labelled text when utterances are present.
    display_name = "AssemblyAI Parse Transcript"
    description = (
        "Parse AssemblyAI transcription result. "
        "If Speaker Labels was enabled, format utterances with speakers and timestamps"
    )
    documentation = "https://www.assemblyai.com/docs"
    icon = "AssemblyAI"

    inputs = [
        DataInput(
            name="transcription_result",
            display_name="Transcription Result",
            info="The transcription result from AssemblyAI",
        ),
    ]

    outputs = [
        Output(display_name="Parsed Transcription", name="parsed_transcription", method="parse_transcription"),
    ]

    def parse_transcription(self) -> Data:
        """Convert the incoming transcription result into a text ``Data`` object.

        Returns:
            The incoming result unchanged when it already carries an "error"
            entry; otherwise ``Data`` with a "text" key on success, or with an
            "error" key when parsing fails.
        """
        incoming = self.transcription_result

        # An error reported by the previous step is passed through untouched.
        upstream_error = incoming.data.get("error")
        if upstream_error:
            self.status = upstream_error
            return incoming

        try:
            payload = incoming.data
            utterances = payload.get("utterances")
            if utterances:
                # Utterances are present: render one entry per speaker turn.
                rendered = self.parse_with_speakers(utterances)
            else:
                # No utterances: fall back to the plain transcript text.
                plain_text = payload.get("text")
                if not plain_text:
                    raise ValueError("Unexpected transcription format")
                rendered = plain_text

            self.status = rendered
            return Data(data={"text": rendered})
        except Exception as exc:
            failure = f"Error parsing transcription: {exc}"
            self.status = failure
            return Data(data={"error": failure})

    def parse_with_speakers(self, utterances: list[dict]) -> str:
        """Render utterances as ``Speaker <id> <timestamp>`` headers followed by quoted text.

        Each utterance dict is read through its "speaker", "start" and "text" keys.
        """
        return "\n".join(
            'Speaker {} {}\n"{}"\n'.format(
                entry["speaker"],
                self.format_timestamp(entry["start"]),
                entry["text"],
            )
            for entry in utterances
        )

    def format_timestamp(self, milliseconds: int) -> str:
        """Return ``milliseconds`` as a ``timedelta`` string with the fractional part removed."""
        whole, _, _ = str(datetime.timedelta(milliseconds=milliseconds)).partition(".")
        return whole

View file

@ -1,7 +1,7 @@
import assemblyai as aai
from langflow.custom import Component
from langflow.io import DataInput, DropdownInput, FloatInput, IntInput, MessageInput, Output, SecretStrInput
from langflow.io import DataInput, DropdownInput, FloatInput, IntInput, MultilineInput, Output, SecretStrInput
from langflow.schema import Data
@ -23,7 +23,7 @@ class AssemblyAILeMUR(Component):
display_name="Transcription Result",
info="The transcription result from AssemblyAI",
),
MessageInput(
MultilineInput(
name="prompt",
display_name="Input Prompt",
info="The text to prompt the model",
@ -34,6 +34,7 @@ class AssemblyAILeMUR(Component):
options=["claude3_5_sonnet", "claude3_opus", "claude3_haiku", "claude3_sonnet"],
value="claude3_5_sonnet",
info="The model that is used for the final prompt after compression is performed",
advanced=True,
),
FloatInput(
name="temperature",
@ -49,6 +50,32 @@ class AssemblyAILeMUR(Component):
value=2000,
info="Max output size in tokens, up to 4000",
),
DropdownInput(
name="endpoint",
display_name="Endpoint",
options=["task", "summary", "question-answer"],
value="task",
info=(
"The LeMUR endpoint to use. For 'summary' and 'question-answer',"
" no prompt input is needed. See https://www.assemblyai.com/docs/api-reference/lemur/ for more info."
),
advanced=True,
),
MultilineInput(
name="questions",
display_name="Questions",
info="Comma-separated list of your questions. Only used if Endpoint is 'question-answer'",
advanced=True,
),
MultilineInput(
name="transcript_ids",
display_name="Transcript IDs",
info=(
"Comma-separated list of transcript IDs. LeMUR can perform actions over multiple transcripts."
" If provided, the Transcription Result is ignored."
),
advanced=True,
),
]
outputs = [
@ -59,41 +86,87 @@ class AssemblyAILeMUR(Component):
"""Use the LeMUR task endpoint to input the LLM prompt."""
aai.settings.api_key = self.api_key
# check if it's an error message from the previous step
if self.transcription_result.data.get("error"):
if not self.transcription_result and not self.transcript_ids:
error = "Either a Transcription Result or Transcript IDs must be provided"
self.status = error
return Data(data={"error": error})
elif self.transcription_result and self.transcription_result.data.get("error"):
# error message from the previous step
self.status = self.transcription_result.data["error"]
return self.transcription_result
if not self.prompt or not self.prompt.text:
self.status = "No prompt specified"
elif self.endpoint == "task" and not self.prompt:
self.status = "No prompt specified for the task endpoint"
return Data(data={"error": "No prompt specified"})
try:
transcript = aai.Transcript.get_by_id(self.transcription_result.data["id"])
except Exception as e:
error = f"Getting transcription failed: {str(e)}"
elif self.endpoint == "question-answer" and not self.questions:
error = "No Questions were provided for the question-answer endpoint"
self.status = error
return Data(data={"error": error})
if transcript.status == aai.TranscriptStatus.completed:
try:
result = transcript.lemur.task(
prompt=self.prompt.text,
final_model=self.get_final_model(self.final_model),
temperature=self.temperature,
max_output_size=self.max_output_size,
)
# Check for valid transcripts
transcript_ids = None
if self.transcription_result and "id" in self.transcription_result.data:
transcript_ids = [self.transcription_result.data["id"]]
elif self.transcript_ids:
transcript_ids = self.transcript_ids.split(",") or []
transcript_ids = [t.strip() for t in transcript_ids]
result = Data(data=result.dict())
self.status = result
return result
except Exception as e:
error = f"An Exception happened while calling LeMUR: {str(e)}"
self.status = error
return Data(data={"error": error})
if not transcript_ids:
error = "Either a valid Transcription Result or valid Transcript IDs must be provided"
self.status = error
return Data(data={"error": error})
# Get TranscriptGroup and check if there is any error
transcript_group = aai.TranscriptGroup(transcript_ids=transcript_ids)
transcript_group, failures = transcript_group.wait_for_completion(return_failures=True)
if failures:
error = f"Getting transcriptions failed: {failures[0]}"
self.status = error
return Data(data={"error": error})
for t in transcript_group.transcripts:
if t.status == aai.TranscriptStatus.error:
self.status = t.error
return Data(data={"error": t.error})
# Perform LeMUR action
try:
response = self.perform_lemur_action(transcript_group, self.endpoint)
result = Data(data=response)
self.status = result
return result
except Exception as e:
error = f"An Error happened: {str(e)}"
self.status = error
return Data(data={"error": error})
def perform_lemur_action(self, transcript_group: aai.TranscriptGroup, endpoint: str) -> dict:
    """Run the selected LeMUR endpoint against a group of transcripts.

    Args:
        transcript_group: The transcripts the LeMUR request is run over.
        endpoint: One of "task", "summary" or "question-answer".

    Returns:
        The LeMUR response converted through its ``dict()`` method.

    Raises:
        ValueError: If ``endpoint`` is not supported, or if the
            "question-answer" endpoint is chosen and no non-empty question
            remains after splitting ``self.questions`` on commas.
    """
    # Reject unknown endpoints up front so the error is always the intended
    # ValueError rather than a failure further down.
    if endpoint not in ("task", "summary", "question-answer"):
        raise ValueError(f"Endpoint not supported: {endpoint}")

    # Options passed identically to every LeMUR endpoint.
    shared_options = {
        "final_model": self.get_final_model(self.final_model),
        "temperature": self.temperature,
        "max_output_size": self.max_output_size,
    }

    if endpoint == "task":
        result = transcript_group.lemur.task(prompt=self.prompt, **shared_options)
    elif endpoint == "summary":
        result = transcript_group.lemur.summarize(**shared_options)
    else:
        # Strip whitespace around each comma-separated question and drop
        # empty entries (e.g. from a trailing comma).
        question_texts = [q.strip() for q in self.questions.split(",")]
        questions = [aai.LemurQuestion(question=q) for q in question_texts if q]
        if not questions:
            raise ValueError("No Questions were provided for the question-answer endpoint")
        result = transcript_group.lemur.question(questions=questions, **shared_options)

    return result.dict()
def get_final_model(self, model_name: str) -> aai.LemurModel:
if model_name == "claude3_5_sonnet":

View file

@ -29,16 +29,19 @@ class AssemblyAIListTranscripts(Component):
options=["all", "queued", "processing", "completed", "error"],
value="all",
info="Filter by transcript status",
advanced=True,
),
MessageTextInput(
name="created_on",
display_name="Created On",
info="Only get transcripts created on this date (YYYY-MM-DD)",
advanced=True,
),
BoolInput(
name="throttled_only",
display_name="Throttled Only",
info="Only get throttled transcripts, overrides the status filter",
advanced=True,
),
]

View file

@ -27,6 +27,7 @@ class AssemblyAITranscriptionJobPoller(Component):
display_name="Polling Interval",
value=3.0,
info="The polling interval in seconds",
advanced=True,
),
]
@ -52,7 +53,13 @@ class AssemblyAITranscriptionJobPoller(Component):
return Data(data={"error": error})
if transcript.status == aai.TranscriptStatus.completed:
data = Data(data=transcript.json_response)
json_response = transcript.json_response
text = json_response.pop("text", None)
utterances = json_response.pop("utterances", None)
transcript_id = json_response.pop("id", None)
sorted_data = {"text": text, "utterances": utterances, "id": transcript_id}
sorted_data.update(json_response)
data = Data(data=sorted_data)
self.status = data
return data
else:

View file

@ -81,18 +81,24 @@ class AssemblyAITranscriptionJobCreator(Component):
],
value="best",
info="The speech model to use for the transcription",
advanced=True,
),
BoolInput(
name="language_detection",
display_name="Automatic Language Detection",
info="Enable automatic language detection",
advanced=True,
),
MessageTextInput(
name="language_code",
display_name="Language",
info="The language of the audio file. Can be set manually if automatic language detection is disabled.\n"
"See https://www.assemblyai.com/docs/getting-started/supported-languages "
"for a list of supported language codes.",
info=(
"""
The language of the audio file. Can be set manually if automatic language detection is disabled.
See https://www.assemblyai.com/docs/getting-started/supported-languages """
"for a list of supported language codes."
),
advanced=True,
),
BoolInput(
name="speaker_labels",

View file

@ -123,6 +123,7 @@ pytest-split = "^0.9.0"
devtools = "^0.12.2"
pytest-flakefinder = "^1.1.0"
types-markdown = "^3.7.0.20240822"
assemblyai = "^0.33.0"
[tool.pytest.ini_options]
@ -244,7 +245,8 @@ dependencies = [
"crewai>=0.36.0",
"spider-client>=0.0.27",
"diskcache>=5.6.3",
"clickhouse-connect==0.7.19"
"clickhouse-connect==0.7.19",
"assemblyai>=0.33.0"
]
# Optional dependencies for uv