Source code for speakerbox.main

#!/usr/bin/env python

import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

from transformers import pipeline

if TYPE_CHECKING:
    import datasets
    from datasets import Dataset, DatasetDict, arrow_dataset
    from pyannote.core.annotation import Annotation
    from transformers import EvalPrediction, feature_extraction_utils

###############################################################################

log = logging.getLogger(__name__)

###############################################################################

DEFAULT_BASE_MODEL = "superb/wav2vec2-base-superb-sid"

EVAL_RESULTS_TEMPLATE = """
## Results

* **Accuracy:** {accuracy}
* **Precision:** {precision}
* **Recall:** {recall}
* **Validation Loss:** {loss}

### Confusion
"""

DEFAULT_TRAINER_ARGUMENTS_ARGS = {
    "evaluation_strategy": "epoch",
    "save_strategy": "epoch",
    "learning_rate": 3e-5,
    "per_device_train_batch_size": 8,
    "gradient_accumulation_steps": 1,
    "eval_accumulation_steps": 40,
    "per_device_eval_batch_size": 8,
    "num_train_epochs": 5,
    "warmup_ratio": 0.1,
    "logging_steps": 10,
    "load_best_model_at_end": True,
    "metric_for_best_model": "accuracy",
    "gradient_checkpointing": True,
}
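
# A minimal sketch, assuming a memory-constrained GPU, of how these defaults
# can be overridden before being handed to `train` below. The merged dict is
# passed straight through to the HuggingFace TrainingArguments constructor,
# so any TrainingArguments keyword should be accepted. The name and the
# specific values here are illustrative, not part of the library.
LOW_MEMORY_TRAINER_ARGUMENTS_ARGS = {
    **DEFAULT_TRAINER_ARGUMENTS_ARGS,
    # Smaller per-device batches with more accumulation steps keeps the
    # effective batch size at 8 while lowering peak memory use.
    "per_device_train_batch_size": 2,
    "per_device_eval_batch_size": 2,
    "gradient_accumulation_steps": 4,
}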


###############################################################################


def eval_model(
    validation_dataset: "Dataset",
    model_name: str = "trained-speakerbox",
) -> Tuple[float, float, float, float]:
    """
    Evaluate a trained model.

    This stores two files in the model directory: a markdown file with the
    accuracy, precision, recall, and loss, and the generated top-one
    confusion matrix as a PNG file.

    Parameters
    ----------
    validation_dataset: Dataset
        The dataset to validate the model against.
    model_name: str
        A name for the model. This will also create a directory with the same
        name to store the produced model in.
        Default: "trained-speakerbox"

    Returns
    -------
    accuracy: float
        The model accuracy as returned by sklearn.metrics.accuracy_score.
    precision: float
        The model (weighted) precision as returned by
        sklearn.metrics.precision_score.
    recall: float
        The model (weighted) recall as returned by
        sklearn.metrics.recall_score.
    loss: float
        The model log loss as returned by sklearn.metrics.log_loss.
    """
    import matplotlib.pyplot as plt
    from sklearn.metrics import (
        ConfusionMatrixDisplay,
        accuracy_score,
        log_loss,
        precision_score,
        recall_score,
    )

    log.info("Setting up evaluation pipeline")
    classifier = pipeline(
        "audio-classification",
        model=model_name,
    )

    log.info("Running eval")

    def predict(example: "datasets.arrow_dataset.Example") -> Dict[str, Any]:
        pred = classifier(example["audio"], top_k=1000)
        pred_as_dict = {i["label"]: i["score"] for i in pred}
        top_pred = max(pred_as_dict, key=pred_as_dict.get)  # type: ignore
        return {
            "pred_label": top_pred,
            "true_label": classifier.model.config.id2label[example["label"]],
            "pred_scores": [i["score"] for i in pred],
        }

    validation_dataset = validation_dataset.map(predict)

    # Create confusion matrix
    ConfusionMatrixDisplay.from_predictions(
        validation_dataset["true_label"],
        validation_dataset["pred_label"],
    )
    plt.xticks(rotation=45)
    plt.yticks(rotation=45)
    plt.savefig(f"{model_name}/validation-confusion.png", bbox_inches="tight")

    # Compute metrics
    accuracy = accuracy_score(
        y_true=validation_dataset["true_label"],
        y_pred=validation_dataset["pred_label"],
    )
    precision = precision_score(
        y_true=validation_dataset["true_label"],
        y_pred=validation_dataset["pred_label"],
        average="weighted",
    )
    recall = recall_score(
        y_true=validation_dataset["true_label"],
        y_pred=validation_dataset["pred_label"],
        average="weighted",
    )
    loss = log_loss(
        y_true=validation_dataset["true_label"],
        y_pred=validation_dataset["pred_scores"],
    )

    # Store metrics
    with open(f"{model_name}/results.md", "w") as open_f:
        open_f.write(
            EVAL_RESULTS_TEMPLATE.format(
                accuracy=accuracy,
                precision=precision,
                recall=recall,
                loss=loss,
            )
        )

    return (accuracy, precision, recall, loss)
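

# A minimal usage sketch for `eval_model`, assuming a validation split was
# previously stored with the `datasets` save_to_disk API; the on-disk path
# and the helper name below are hypothetical, not part of the library.
def _example_eval_model_usage() -> None:  # pragma: no cover - illustrative only
    from datasets import load_from_disk

    # Hypothetical path to a Dataset with "label" and "audio" columns
    validation = load_from_disk("training-data/valid")
    accuracy, precision, recall, loss = eval_model(
        validation,
        model_name="trained-speakerbox",
    )
    log.info(f"accuracy={accuracy:.3f} loss={loss:.3f}")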


def train(
    dataset: "DatasetDict",
    model_name: str = "trained-speakerbox",
    model_base: str = DEFAULT_BASE_MODEL,
    max_duration: float = 2.0,
    seed: Optional[int] = None,
    use_cpu: bool = False,
    trainer_arguments_kws: Dict[str, Any] = DEFAULT_TRAINER_ARGUMENTS_ARGS,
) -> Path:
    """
    Train a speaker classification model.

    Parameters
    ----------
    dataset: DatasetDict
        The datasets to use for training, testing, and validation.
        Should only contain the columns/features: "label" and "audio".
        The values in the "audio" column should be paths to the audio files.
    model_name: str
        A name for the model. This will also create a directory with the same
        name to store the produced model in.
        Default: "trained-speakerbox"
    model_base: str
        The model base to use before fine tuning.
    max_duration: float
        The maximum duration to use for each audio clip.
        Any clips longer than this will be trimmed.
        Default: 2.0
    seed: Optional[int]
        Seed to pass to torch, numpy, and Python RNGs.
        Default: None (do not set a seed)
    use_cpu: bool
        Should the model be trained on the CPU.
        This also sets `no_cuda=True` on TrainingArguments.
        Default: False (use GPU if available)
    trainer_arguments_kws: Dict[str, Any]
        Any additional keyword arguments to be passed to the HuggingFace
        TrainingArguments object.
        Default: DEFAULT_TRAINER_ARGUMENTS_ARGS

    Returns
    -------
    model_storage_path: Path
        The path to the directory where the model is stored.
    """
    import numpy as np
    import torch
    import transformers
    from datasets import Audio, load_metric
    from transformers import (
        Trainer,
        TrainingArguments,
        Wav2Vec2FeatureExtractor,
        Wav2Vec2ForSequenceClassification,
    )

    from .utils import set_global_seed

    # Handle seed
    if seed:
        set_global_seed(seed)

    # Handle cpu
    if use_cpu:
        trainer_arguments_kws["no_cuda"] = True

    # Load feature extractor
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_base)

    # Convert dataset audios
    log.info("Casting all audio paths to torch Audio")
    dataset = dataset.cast_column("audio", Audio(feature_extractor.sampling_rate))

    # Construct label to id and vice-versa LUTs
    label2id, id2label = {}, {}
    for i, label in enumerate(dataset["train"].features["label"].names):
        label2id[label] = str(i)
        id2label[str(i)] = label

    # Construct preprocessing function
    def preprocess(
        examples: "arrow_dataset.Batch",
    ) -> "feature_extraction_utils.BatchFeature":
        audio_arrays = [x["array"] for x in examples["audio"]]
        inputs = feature_extractor(
            audio_arrays,
            sampling_rate=feature_extractor.sampling_rate,
            max_length=int(feature_extractor.sampling_rate * max_duration),
            do_normalize=True,
            truncation=True,
            padding=True,
        )
        return inputs

    # Encode the dataset
    log.info("Extracting features from audio")
    dataset = dataset.map(preprocess, batched=True)

    # Create the classification model for fine tuning
    log.info("Setting up Trainer")
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        model_base,
        num_labels=len(id2label),
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,
    )

    # Create fine tuning TrainingArguments
    args = TrainingArguments(
        model_name,
        **trainer_arguments_kws,
    )

    # Compute accuracy metrics
    metric = load_metric("accuracy")

    def compute_metrics(eval_pred: "EvalPrediction") -> Optional[Dict]:
        # Eval pred comes with both the predictions and the attention mask;
        # grab just the predictions
        predictions = np.argmax(eval_pred.predictions[0], axis=-1)
        return metric.compute(predictions=predictions, references=eval_pred.label_ids)

    # Construct the Trainer and train!
    trainer = Trainer(
        model,
        args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        tokenizer=feature_extractor,
        compute_metrics=compute_metrics,
    )
    torch.cuda.empty_cache()
    transformers.logging.set_verbosity_info()
    trainer.train()

    # Save model
    trainer.save_model()

    return Path(model_name).resolve()
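

# A minimal usage sketch for `train`, assuming a DatasetDict with "train",
# "test", and "valid" splits (each holding only "label" and "audio" columns)
# was previously saved to disk; the path and the helper name are hypothetical.
def _example_train_usage() -> None:  # pragma: no cover - illustrative only
    from datasets import load_from_disk

    dataset = load_from_disk("training-data/dataset-dict")
    model_storage_path = train(
        dataset,
        model_name="trained-speakerbox",
        seed=42,
    )
    log.info(f"Stored trained model to: {model_storage_path}")

    # Evaluate against the held-out validation split
    eval_model(dataset["valid"], model_name="trained-speakerbox")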


def apply(  # noqa: C901
    audio: Union[str, Path],
    model: str,
    mode: Literal["diarize", "naive"] = "diarize",
    min_chunk_duration: float = 0.5,
    max_chunk_duration: float = 2.0,
    confidence_threshold: float = 0.85,
) -> "Annotation":
    """
    Iteratively apply the model across chunks of an audio file.

    Parameters
    ----------
    audio: Union[str, Path]
        The audio filepath.
    model: str
        The path to the trained audio-classification model.
    mode: Literal["diarize", "naive"]
        Which mode to use for processing. "diarize" will diarize the audio
        prior to generating chunks to classify. "naive" will iteratively
        process chunks. "naive" is assumed to be faster but have worse
        performance.
        Default: "diarize"
    min_chunk_duration: float
        The minimum duration in seconds a chunk of audio must be for it to be
        run through the classification pipeline.
        Default: 0.5 seconds
    max_chunk_duration: float
        The maximum duration in seconds a chunk of audio is allowed to be for
        it to be run through the classification pipeline.
        Default: 2 seconds
    confidence_threshold: float
        A lower bound on the reported confidence of the model prediction.
        Any classification with a confidence lower than this value will be
        ignored and not added as a segment.
        Default: 0.85 (fairly strict / must have high confidence in prediction)

    Returns
    -------
    Annotation
        A pyannote.core Annotation with all labeled segments.
    """
    import numpy as np
    from pyannote.audio import Pipeline
    from pyannote.core.annotation import Annotation
    from pyannote.core.segment import Segment
    from pyannote.core.utils.types import Label, TrackName
    from pydub import AudioSegment
    from tqdm import tqdm

    # Just set track name to the same as the audio filepath
    track_name = str(audio)

    # Read audio file
    loaded_audio = AudioSegment.from_file(audio)

    # Load model
    classifier = pipeline("audio-classification", model=model)

    # Get number of speakers
    n_speakers = len(classifier.model.config.id2label)

    # Temp filename for storing audio chunks during apply
    tmp_audio_chunk_save_path = Path(".tmp-audio-chunk-during-apply.wav")

    def _diarize() -> List[Tuple[Segment, TrackName, Label]]:  # noqa: C901
        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization")
        dia = diarization_pipeline(audio)

        # Prep for calculations (pydub slices audio in milliseconds)
        max_chunk_duration_millis = max_chunk_duration * 1000
        min_chunk_duration_millis = min_chunk_duration * 1000

        # Return chunks for each diarized section
        records: List[Tuple[Segment, TrackName, Label]] = []
        for turn, _, _ in tqdm(dia.itertracks(yield_label=True)):
            # Keep track of each turn chunk classification and score
            chunk_scores: Dict[str, List[float]] = {}

            # Get audio slice for turn
            turn_start_millis = turn.start * 1000
            turn_end_millis = turn.end * 1000

            # Split into smaller chunks
            for chunk_start_millis_float in np.arange(
                turn_start_millis,
                turn_end_millis,
                max_chunk_duration_millis,
            ):
                # Round start to nearest int
                chunk_start_millis = round(chunk_start_millis_float)

                # Tentative chunk end
                chunk_end_millis = round(chunk_start_millis + max_chunk_duration_millis)

                # Determine chunk end time
                # If start + chunk duration is longer than the turn,
                # the chunk needs to be cut at the turn end
                if turn_end_millis < chunk_end_millis:
                    chunk_end_millis = round(turn_end_millis)

                # Only allow if duration is greater than
                # the min intra-turn chunk duration
                duration = chunk_end_millis - chunk_start_millis
                if duration >= min_chunk_duration_millis:
                    # Get chunk
                    chunk = loaded_audio[chunk_start_millis:chunk_end_millis]

                    # Write to temp
                    chunk.export(tmp_audio_chunk_save_path, format="wav")

                    # Predict and store scores for turn
                    preds = classifier(
                        str(tmp_audio_chunk_save_path),
                        top_k=n_speakers,
                    )
                    for pred in preds:
                        if pred["label"] not in chunk_scores:
                            chunk_scores[pred["label"]] = []
                        chunk_scores[pred["label"]].append(pred["score"])

            # Create mean score
            turn_speaker = None
            if len(chunk_scores) > 0:
                mean_scores: Dict[str, float] = {}
                for speaker, scores in chunk_scores.items():
                    mean_scores[speaker] = sum(scores) / len(scores)

                # Get highest scoring speaker and their score
                highest_mean_speaker = ""
                highest_mean_score = 0.0
                for speaker, score in mean_scores.items():
                    if score > highest_mean_score:
                        highest_mean_speaker = speaker
                        highest_mean_score = score

                # Threshold holdout
                if highest_mean_score >= confidence_threshold:
                    turn_speaker = highest_mean_speaker

            # Store record
            records.append(
                (
                    Segment(turn.start, turn.end),
                    track_name,
                    turn_speaker,
                )
            )

        return records

    def _naive() -> List[Tuple[Segment, TrackName, Label]]:
        # Move audio window, apply, and append annotation record
        records: List[Tuple[Segment, TrackName, Label]] = []
        for chunk_start_seconds in tqdm(
            np.arange(0, loaded_audio.duration_seconds, max_chunk_duration)
        ):
            # Calculate chunk end
            chunk_end_seconds = chunk_start_seconds + max_chunk_duration
            if chunk_end_seconds > loaded_audio.duration_seconds:
                chunk_end_seconds = loaded_audio.duration_seconds

            # Check if duration is long enough
            duration = chunk_end_seconds - chunk_start_seconds
            if duration >= min_chunk_duration:
                # Convert seconds to millis
                chunk_start_millis = chunk_start_seconds * 1000
                chunk_end_millis = chunk_end_seconds * 1000

                # Select chunk
                chunk = loaded_audio[chunk_start_millis:chunk_end_millis]

                # Write chunk to temp
                chunk.export(tmp_audio_chunk_save_path, format="wav")

                # Predict, keep top 1, and store to records
                pred = classifier(str(tmp_audio_chunk_save_path), top_k=1)[0]
                if pred["score"] >= confidence_threshold:
                    records.append(
                        (
                            Segment(chunk_start_seconds, chunk_end_seconds),
                            track_name,
                            pred["label"],
                        )
                    )

        return records

    # Classify based on the selected mode
    mode_lut = {
        "diarize": _diarize,
        "naive": _naive,
    }

    # Generate records and clean up
    try:
        records = mode_lut[mode]()

        # Merge segments that are touching
        merged_records: List[Tuple[Segment, TrackName, Label]] = []
        current_record: Optional[Tuple[Segment, TrackName, Label]] = None
        for record in records:
            if current_record is None:
                current_record = record
            else:
                # The label matches and the segment start and end points are
                # touching, merge
                if (
                    record[2] == current_record[2]
                    and record[0].start == current_record[0].end
                ):
                    # Make new record with merged data
                    # because tuples are immutable
                    current_record = (
                        Segment(current_record[0].start, record[0].end),
                        track_name,
                        current_record[2],
                    )
                else:
                    merged_records.append(current_record)
                    current_record = record

        # Add the last current segment
        # (the None check is only needed to satisfy the type checker)
        if current_record is not None:
            merged_records.append(current_record)

        return Annotation.from_records(merged_records)
    finally:
        # Always clean up tmp file
        if tmp_audio_chunk_save_path.exists():
            tmp_audio_chunk_save_path.unlink()
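

# A minimal usage sketch for `apply`, assuming a trained model directory and
# an audio file on disk; both paths and the helper name are placeholders.
def _example_apply_usage() -> None:  # pragma: no cover - illustrative only
    annotation = apply(
        "meeting.wav",
        model="trained-speakerbox",
        mode="diarize",
        confidence_threshold=0.85,
    )
    # Annotation.itertracks yields (segment, track name, label) triples
    for segment, _, label in annotation.itertracks(yield_label=True):
        log.info(f"{label}: {segment.start:.1f}s -> {segment.end:.1f}s")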