#!/usr/bin/env python
from __future__ import annotations
import logging
import re
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any
import spacy
from faster_whisper import WhisperModel as FasterWhisper
from pydub import AudioSegment
from spacy.cli.download import download as download_spacy_model
from spacy.tokenizer import Tokenizer
from spacy.util import compile_infix_regex
from tqdm import tqdm
from .. import __version__
from ..pipeline import transcript_model
from .sr_model import SRModel
if TYPE_CHECKING:
from spacy.language import Language
###############################################################################
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
###############################################################################
# Name of the spaCy pipeline that WhisperModel loads for sentence splitting
# (WhisperModel.__init__ downloads it and retries when loading fails).
DEFAULT_SPACY_MODEL = "en_core_web_lg"
# NOTE: executed at import time. Per the spaCy docs this asks spaCy to use a
# GPU when one is available; the boolean result is intentionally ignored.
spacy.prefer_gpu()
###############################################################################
# Placeholder confidence values, keyed by Whisper model name.
# WhisperModel.__init__ reads this table only when the caller does not pass an
# explicit `confidence`. "large" is rewritten to "large-v2" before the lookup,
# so it has no entry of its own here.
MODEL_NAME_FAKE_CONFIDENCE_LUT = {
"tiny": 0.5,
"base": 0.6,
"small": 0.65,
"medium": 0.71,
"large-v2": 0.72,
}
# (Stray Sphinx "[docs]" link label removed; it is not part of the module.)
class WhisperModel(SRModel):
    """
    Speech recognition model that transcribes audio with faster-whisper and
    uses a spaCy pipeline to split the resulting text into sentences.
    """

    @staticmethod
    def _load_spacy_model() -> "Language":
        """
        Load the spaCy pipeline used for sentence splitting.

        Returns
        -------
        nlp: Language
            The loaded `DEFAULT_SPACY_MODEL` pipeline with a tokenizer that
            does not split hyphenated words or numbers.
        """
        nlp = spacy.load(
            DEFAULT_SPACY_MODEL,
            # Only keep the parser
            # We are only using this for sentence parsing
            disable=[
                "tagger",
                "ner",
                "lemmatizer",
                "textcat",
            ],
        )

        # Do not split hyphenated words and numbers
        # "re-upping" should not be split into ["re", "-", "upping"].
        # Credit: https://stackoverflow.com/a/59996153
        def custom_tokenizer(nlp: "Language") -> Tokenizer:
            numeric_operator_infix = r"(?<=[0-9])[+\-\*^](?=[0-9-])"
            hyphen_infix_marker = "-|–|—|--|---|——|~"  # noqa: RUF001

            # Filter instead of `list.remove` so a spaCy release that does not
            # ship this exact pattern cannot raise ValueError here.
            infixes = [
                x for x in nlp.Defaults.infixes if x != numeric_operator_infix
            ]
            # Re-add the numeric operators, minus the hyphen
            infixes.extend(
                [r"(?<=[0-9])[+*^](?=[0-9-])", r"(?<=[0-9])-(?=-)"]
            )
            # Drop the rule that splits words on hyphens / dashes
            infix_re = compile_infix_regex(
                tuple(x for x in infixes if hyphen_infix_marker not in x)
            )

            return Tokenizer(
                nlp.vocab,
                prefix_search=nlp.tokenizer.prefix_search,
                suffix_search=nlp.tokenizer.suffix_search,
                infix_finditer=infix_re.finditer,
                token_match=nlp.tokenizer.token_match,
                rules=nlp.Defaults.tokenizer_exceptions,
            )

        nlp.tokenizer = custom_tokenizer(nlp)
        return nlp

    def __init__(
        self,
        model_name: str = "medium",
        confidence: float | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize an OpenAI Whisper Model Transcription processor.

        Parameters
        ----------
        model_name: str
            The model version to use. Default: "medium"
            See:
            https://github.com/openai/whisper/tree/0b5dcfdef7ec04250b76e13f1630e32b0935ce76#available-models-and-languages
        confidence: Optional[float]
            A confidence value to set for all transcripts produced by this SR Model.
            See source code for issues related to this.
            Default: None (lookup a fake confidence to use depending on model selected)
        kwargs: Any
            Any extra arguments to catch.

        Raises
        ------
        KeyError
            If `confidence` is None and `model_name` has no entry in
            `MODEL_NAME_FAKE_CONFIDENCE_LUT`.
        """
        # Handle large -> large v2
        if model_name == "large":
            model_name = "large-v2"

        # TODO: whisper doesn't provide a confidence value
        # Additionally, we have been overloading confidence with webvtt
        # conversion. We may want to get rid of confidence?
        # Our current confidence default is 0.001 higher than our old
        # WebVTT parser to ensure that these transcripts are chosen.
        #
        # Resolve the confidence before loading any model so that an
        # unsupported model name fails immediately and with a useful message.
        if confidence is None:
            try:
                confidence = MODEL_NAME_FAKE_CONFIDENCE_LUT[model_name]
            except KeyError:
                raise KeyError(
                    f"No default confidence is known for Whisper model "
                    f"'{model_name}'. Pass `confidence` explicitly or use one "
                    f"of: {sorted(MODEL_NAME_FAKE_CONFIDENCE_LUT)}."
                ) from None

        self.model_name = model_name
        self.confidence = confidence

        # Load whisper model
        self.model = FasterWhisper(self.model_name)

        # Init spacy
        # Best-effort: if the pipeline cannot be loaded, download it and retry
        # once. The retry is not guarded, so a persistent problem still raises.
        try:
            self.nlp = self._load_spacy_model()
        except Exception:
            log.warning(
                "Could not load spaCy model '%s'; downloading it and retrying.",
                DEFAULT_SPACY_MODEL,
                exc_info=True,
            )
            download_spacy_model(DEFAULT_SPACY_MODEL)
            self.nlp = self._load_spacy_model()

    @staticmethod
    def _normalize_word_text(raw_text: str) -> str:
        """Strip caption symbols and surplus spaces from a single Whisper word."""
        text = raw_text.replace("♪", "").replace("≫", "")
        text = re.sub(r" +", " ", text)
        text = re.sub(r"( +)(\.)", ".", text)
        return text.strip()

    @staticmethod
    def _rescale_word_timestamps(
        words: list[dict[str, Any]],
        whisper_reported_duration: float,
        file_reported_duration: float,
    ) -> None:
        """Rescale word timestamps in place so the last word ends at the file end."""
        if whisper_reported_duration <= 0:
            # Nothing sensible to scale by (and dividing would raise).
            log.warning(
                "Whisper reported a non-positive duration; "
                "leaving timestamps unscaled"
            )
            return

        for word in words:
            # Scale to between 0 and 1, then rescale to the real duration
            word["start"] = (
                word["start"] / whisper_reported_duration * file_reported_duration
            )
            word["end"] = (
                word["end"] / whisper_reported_duration * file_reported_duration
            )

    @staticmethod
    def _group_words_by_sentence(
        doc: Any,
        words: list[dict[str, Any]],
    ) -> list[list[dict[str, Any]]]:
        """
        Group timestamped words by the spaCy sentence in which they start.

        `doc` must have been created from the texts of `words` joined by a
        single space, so each word's character offset can be computed here.
        """
        # Character offset at which each word starts inside the joined text
        word_offsets = []
        cursor = 0
        for word in words:
            word_offsets.append(cursor)
            cursor += len(word["text"]) + 1

        sentences: list[list[dict[str, Any]]] = []
        word_index = 0
        for doc_sent in doc.sents:
            # A word belongs to the sentence that contains its first character.
            # Sentences that only hold punctuation attached to the previous
            # word (".", ",", "?", ...) therefore receive no words and are
            # dropped, without shifting the alignment of later sentences.
            sentence_words = []
            while (
                word_index < len(words)
                and word_offsets[word_index] < doc_sent.end_char
            ):
                sentence_words.append(words[word_index])
                word_index += 1

            if sentence_words:
                log.info(f"Doc sent: '{doc_sent.text.strip()}'")
                log.info(f"\tWords: {[w_w_m['text'] for w_w_m in sentence_words]}")
                sentences.append(sentence_words)

        # Safety net: never drop words, even if the sentences did not cover
        # the whole text for some reason.
        if word_index < len(words):
            leftovers = words[word_index:]
            if sentences:
                sentences[-1].extend(leftovers)
            else:
                sentences.append(leftovers)

        return sentences

    def _build_transcript(
        self,
        structured_sentences: list[transcript_model.Sentence],
    ) -> transcript_model.Transcript:
        """Wrap the sentences in a Transcript carrying this model's metadata."""
        return transcript_model.Transcript(
            generator=(
                f"CDP Whisper Conversion "
                f"-- CDP v{__version__} "
                f"-- Whisper Model Name '{self.model_name}'"
            ),
            confidence=self.confidence,
            session_datetime=None,
            # NOTE: naive UTC timestamp; `utcnow` is deprecated as of
            # Python 3.12 but is kept so the emitted format does not change.
            created_datetime=datetime.utcnow().isoformat(),
            sentences=structured_sentences,
        )

    def transcribe(
        self,
        file_uri: str | Path,
        **kwargs: Any,
    ) -> transcript_model.Transcript:
        """
        Transcribe audio from file and return a Transcript model.

        Parameters
        ----------
        file_uri: Union[str, Path]
            The uri to the audio file or caption file to transcribe.
        kwargs: Any
            Any extra arguments to catch.

        Returns
        -------
        outputs: transcript_model.Transcript
            The transcript model for the supplied media file.
            When no words are recognized, the transcript has no sentences.
        """
        log.info(f"Transcribing '{file_uri}'")
        # Hand the model a plain string so a `Path` is accepted regardless of
        # what the underlying audio decoder supports.
        segments, _ = self.model.transcribe(str(file_uri), word_timestamps=True)

        timestamped_words_with_meta: list[dict[str, Any]] = []
        for segment in tqdm(segments, desc="Transcribing segment..."):
            for word in segment.words or []:
                word_text = self._normalize_word_text(word.word)
                if word_text:
                    timestamped_words_with_meta.append(
                        {
                            "text": word_text,
                            "start": word.start,
                            "end": word.end,
                        }
                    )

        # Nothing was recognized (e.g. silent audio): return an empty
        # transcript instead of failing on the empty word list below.
        if not timestamped_words_with_meta:
            log.warning(f"No words were transcribed from '{file_uri}'")
            return self._build_transcript([])

        # For some reason, whisper sometimes returns segments with
        # start and end times that are impossible
        # i.e. start and end of 185 second when the total audio duration is 180 seconds
        # Fix all timestamps by rescaling to audio duration
        # This is a hack -- but all of the word level timestamps are a hack anyway...
        whisper_reported_duration = timestamped_words_with_meta[-1]["end"]
        file_reported_duration = AudioSegment.from_file(file_uri).duration_seconds

        log.info("Ensuring timestamps fit within audio")
        self._rescale_word_timestamps(
            timestamped_words_with_meta,
            whisper_reported_duration,
            file_reported_duration,
        )

        # Process all text
        # Every word text is stripped, non-empty and free of repeated spaces,
        # so a single-space join needs no further whitespace cleanup and the
        # character offsets used for sentence grouping stay exact.
        joined_all_words = " ".join(
            word_with_meta["text"] for word_with_meta in timestamped_words_with_meta
        )
        doc = self.nlp(joined_all_words)

        # Process sentences
        log.info("Constructing sentences with word metadata")
        sentences_with_word_metas = self._group_words_by_sentence(
            doc,
            timestamped_words_with_meta,
        )

        # Reformat data to our structure
        structured_sentences: list[transcript_model.Sentence] = []
        log.info("Converting sentences with word meta to transcript format")
        for sent_index, sentence_with_word_metas in enumerate(
            sentences_with_word_metas,
        ):
            # Join all the sentence text
            sentence_text = " ".join(
                word_with_meta["text"] for word_with_meta in sentence_with_word_metas
            ).strip()

            # Make sure the first letter is capitalized
            # NOTE: we cannot use the `capitalize` string function
            # because it will lowercase the rest of the text
            sentence_text = sentence_text[:1].upper() + sentence_text[1:]

            # Create the sentence object
            structured_sentences.append(
                transcript_model.Sentence(
                    index=sent_index,
                    confidence=self.confidence,
                    start_time=sentence_with_word_metas[0]["start"],
                    end_time=sentence_with_word_metas[-1]["end"],
                    text=sentence_text,
                    words=[
                        transcript_model.Word(
                            index=word_index,
                            start_time=word_with_meta["start"],
                            end_time=word_with_meta["end"],
                            text=self._clean_word(word_with_meta["text"]),
                        )
                        for word_index, word_with_meta in enumerate(
                            sentence_with_word_metas
                        )
                    ],
                )
            )

        # Return complete transcript object
        return self._build_transcript(structured_sentences)