Source code for cdp_backend.annotation.speaker_labels

#!/usr/bin/env python

from __future__ import annotations

import logging
from pathlib import Path
from uuid import uuid4

import numpy as np
from pydub import AudioSegment
from tqdm import tqdm
from transformers import pipeline
from transformers.pipelines.base import Pipeline

from ..pipeline.transcript_model import Transcript

###############################################################################

log = logging.getLogger(__name__)

###############################################################################

DEFAULT_MODEL = "trained-speakerbox"

###############################################################################


def annotate(  # noqa: C901
    transcript: str | Path | Transcript,
    audio: str | Path | AudioSegment,
    model: str | Pipeline = DEFAULT_MODEL,
    min_intra_sentence_chunk_duration: float = 0.5,
    max_intra_sentence_chunk_duration: float = 2.0,
    min_sentence_mean_confidence: float = 0.985,
) -> Transcript:
    """
    Annotate a transcript using a pre-trained speaker identification model.

    Parameters
    ----------
    transcript: Union[str, Path, Transcript]
        The path to, or an already in-memory, Transcript object to fill with
        speaker name annotations.
    audio: Union[str, Path, AudioSegment]
        The path to the matching audio for the associated transcript.
    model: Union[str, Pipeline]
        The path to the trained Speakerbox audio classification model or a
        preloaded audio classification Pipeline.
        Default: "trained-speakerbox"
    min_intra_sentence_chunk_duration: float
        The minimum duration of a sentence to annotate. Anything shorter than
        this is ignored during annotation.
        Default: 0.5 seconds
    max_intra_sentence_chunk_duration: float
        The maximum duration to split a sentence's audio into. This should
        match whatever was used during model training (i.e. a model trained on
        two-second audio chunks should be applied to two-second audio chunks).
        Default: 2 seconds
    min_sentence_mean_confidence: float
        The minimum allowable mean confidence, across all predicted chunk
        confidences, for the predicted label to be committed as an annotation.
        Default: 0.985

    Returns
    -------
    Transcript
        The annotated transcript. Note that the transcript is updated in
        place; this does not return a new Transcript object.

    Notes
    -----
    A plain-text walkthrough of this function is as follows:

    For each sentence in the provided transcript, the matching audio portion
    is retrieved. For example, if the sentence start time is 12.05 seconds
    and the end time is 20.47 seconds, that exact portion of audio is pulled
    from the full audio file.

    If the audio duration for the sentence is less than
    `min_intra_sentence_chunk_duration`, the sentence is ignored.

    If the audio duration for the sentence is greater than
    `max_intra_sentence_chunk_duration`, the audio is split into chunks of
    length `max_intra_sentence_chunk_duration` (i.e. from 12.05 to 14.05,
    from 14.05 to 16.05, and so on; if the last chunk is shorter than
    `min_intra_sentence_chunk_duration`, it is ignored).

    Each audio chunk is run through the audio classification model, which
    predicts a confidence for every person known to the model. The confidence
    values are used to build a dictionary of
    `label -> list of confidence values`.

    Once all chunks for a single sentence are predicted, the mean confidence
    is computed for each label. The label with the highest mean confidence is
    used as the sentence speaker name. If the highest mean confidence is less
    than `min_sentence_mean_confidence`, no label is stored in the
    `speaker_name` sentence property (it is thresholded out).
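
    Examples
    --------
    A minimal usage sketch (the file paths below are hypothetical
    placeholders)::

        from cdp_backend.annotation.speaker_labels import annotate

        annotated = annotate(
            transcript="event-transcript.json",
            audio="event-audio.wav",
            model="trained-speakerbox",
        )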
""" # Generate random uuid filename for storing temp audio chunks tmp_audio_chunk_save_path = f"tmp-audio-chunk--{str(uuid4())}.wav" # Load transcript if isinstance(transcript, (str, Path)): with open(transcript) as open_f: loaded_transcript = Transcript.from_json(open_f.read()) else: loaded_transcript = transcript # Load audio if isinstance(audio, (str, Path)): audio = AudioSegment.from_file(audio) # Load model if isinstance(model, str): classifier = pipeline("audio-classification", model=model) else: classifier = model n_speakers = len(classifier.model.config.id2label) # Convert to millis min_intra_sentence_chunk_duration_millis = min_intra_sentence_chunk_duration * 1000 max_intra_sentence_chunk_duration_millis = max_intra_sentence_chunk_duration * 1000 # Iter transcript, get sentence audio, chunk into sections # of two seconds or less, classify each, and take most common, # with thresholding confidence * segments met_threshold = 0 missed_threshold = 0 log.info("Annotating sentences with speaker names") for i, sentence in tqdm( enumerate(loaded_transcript.sentences), desc="Sentences annotated", ): # Keep track of each sentence chunk classification and score chunk_scores: dict[str, list[float]] = {} # Get audio slice for sentence sentence_start_millis = sentence.start_time * 1000 sentence_end_millis = sentence.end_time * 1000 # Split into smaller chunks for chunk_start_millis in np.arange( sentence_start_millis, sentence_end_millis, max_intra_sentence_chunk_duration_millis, ): # Tentative chunk end chunk_end_millis = ( chunk_start_millis + max_intra_sentence_chunk_duration_millis ) # Determine chunk end time # If start + chunk duration is longer than sentence # Chunk needs to be cut at sentence end if sentence_end_millis < chunk_end_millis: chunk_end_millis = sentence_end_millis # Only allow if duration is greater than min intra sentence chunk duration duration = chunk_end_millis - chunk_start_millis if duration >= min_intra_sentence_chunk_duration_millis: # Get chunk chunk = audio[chunk_start_millis:chunk_end_millis] # Write to temp chunk.export(tmp_audio_chunk_save_path, format="wav") # Predict and store scores for sentence preds = classifier(tmp_audio_chunk_save_path, top_k=n_speakers) for pred in preds: if pred["label"] not in chunk_scores: chunk_scores[pred["label"]] = [] chunk_scores[pred["label"]].append(pred["score"]) # Create mean score sentence_speaker = None if len(chunk_scores) > 0: mean_scores: dict[str, float] = {} for speaker, scores in chunk_scores.items(): mean_scores[speaker] = sum(scores) / len(scores) # Get highest scoring speaker and their score highest_mean_speaker = "" highest_mean_score = 0.0 for speaker, score in mean_scores.items(): if score > highest_mean_score: highest_mean_speaker = speaker highest_mean_score = score # Threshold holdout if highest_mean_score >= min_sentence_mean_confidence: sentence_speaker = highest_mean_speaker met_threshold += 1 else: log.debug( f"Missed speaker annotation confidence threshold for sentence {i} " f"-- Highest Mean Confidence: {highest_mean_score}" ) missed_threshold += 1 # Store to transcript sentence.speaker_name = sentence_speaker # Remove last made chunk file Path(tmp_audio_chunk_save_path).unlink() log.info( f"Total sentences: {len(loaded_transcript.sentences)}, " f"Sentences Annotated: {met_threshold}, " f"Missed Threshold: {missed_threshold}" ) return loaded_transcript