Source code for cdp_backend.pipeline.generate_event_index_pipeline

#!/usr/bin/env python

from __future__ import annotations

import logging
import math
import shutil
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import NamedTuple

import pandas as pd
import pytz
import rapidfuzz
from dataclasses_json import DataClassJsonMixin
from gcsfs import GCSFileSystem
from nltk import ngrams
from nltk.stem import SnowballStemmer
from prefect import Flow, task, unmapped

from ..database import functions as db_functions
from ..database import models as db_models
from ..file_store import functions as fs_functions
from ..utils import string_utils
from .pipeline_config import EventIndexPipelineConfig
from .transcript_model import Sentence, Transcript

REMOTE_INDEX_CHUNK_DIR = "index-chunks"

###############################################################################

log = logging.getLogger(__name__)

###############################################################################


@task
def get_transcripts(credentials_file: str) -> list[db_models.Transcript]:
    """
    Initialize fireo connection and pull all Transcript models.

    Parameters
    ----------
    credentials_file: str
        Path to Google Service Account Credentials JSON file.

    Returns
    -------
    transcripts: List[db_models.Transcript]
        All transcript documents found in the database.

    Notes
    -----
    This _can_ become a dangerous operation if the database has millions of transcript
    documents, as it will additionally pull all referenced Event, Session, Body,
    and File database models.
    """
    # Fetch all transcripts
    # This comes with references to Session and File models
    # Session models come with reference to Event model
    # Event models come with reference to Body model
    return db_functions.get_all_of_collection(
        db_model=db_models.Transcript,
        credentials_file=credentials_file,
    )


@task
def get_highest_confidence_transcript_for_each_session(
    transcripts: list[db_models.Transcript],
) -> list[db_models.Transcript]:
    """
    Filter down a list of transcript documents to a single transcript
    per session, taking the highest confidence transcript document.

    Parameters
    ----------
    transcripts: List[db_models.Transcript]
        List of transcript database documents.

    Returns
    -------
    transcripts: List[db_models.Transcript]
        Filtered list of transcript database documents where only a single transcript
        exists for each referenced session.
    """
    # We can't use pandas groupby because session objects can't be naively compared
    # Instead we create a dict of session id to document model
    # and update it as we iterate through the list of all transcripts
    selected_transcripts: dict[str, db_models.Transcript] = {}
    for transcript in transcripts:
        referenced_session_id = transcript.session_ref.ref.id
        if referenced_session_id not in selected_transcripts:
            selected_transcripts[referenced_session_id] = transcript

        # Multiple transcripts for a single session
        # pick the higher confidence
        elif (
            transcript.confidence
            > selected_transcripts[referenced_session_id].confidence
        ):
            selected_transcripts[referenced_session_id] = transcript

    return list(selected_transcripts.values())


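# Illustrative sketch, not part of the original module: the "keep the highest
# confidence per key" selection performed by
# get_highest_confidence_transcript_for_each_session, shown with plain tuples
# in place of database documents.
#
# >>> candidates = [("session-a", 0.88), ("session-a", 0.97), ("session-b", 0.90)]
# >>> best: dict[str, float] = {}
# >>> for session_id, confidence in candidates:
# ...     if session_id not in best or confidence > best[session_id]:
# ...         best[session_id] = confidence
# >>> best
# {'session-a': 0.97, 'session-b': 0.9}
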
class EventTranscripts(NamedTuple):
    event_id: str
    event_datetime: datetime
    transcript_db_files: list[db_models.File]

@task
def get_transcripts_per_event(
    transcripts: list[db_models.Transcript],
) -> list[EventTranscripts]:
    """
    Group all transcripts related to a single event together into
    EventTranscripts objects.
    """
    # Create event transcripts as event id mapped to EventTranscripts object
    event_transcripts: dict[str, EventTranscripts] = {}
    for transcript in transcripts:
        # Add new event
        session = transcript.session_ref.get()
        referenced_event_id = session.event_ref.ref.id
        if referenced_event_id not in event_transcripts:
            event_transcripts[referenced_event_id] = EventTranscripts(
                event_id=referenced_event_id,
                event_datetime=session.event_ref.get().event_datetime,
                transcript_db_files=[transcript.file_ref.get()],
            )

        # Update existing event_transcripts object
        else:
            event_transcripts[referenced_event_id].transcript_db_files.append(
                transcript.file_ref.get()
            )

    return list(event_transcripts.values())

@dataclass
class SentenceManager(DataClassJsonMixin):
    original_details: Sentence
    cleaned_text: str
    n_grams: list[tuple[str, ...]]

@dataclass
class ContextualizedGram(DataClassJsonMixin):
    # We attach the id for simpler gram grouping
    # We attach the datetime for simpler datetime weighting
    event_id: str
    event_datetime: datetime
    unstemmed_gram: str
    stemmed_gram: str
    context_span: str

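# A minimal sketch, not part of the original module, of the gram generation
# step used in read_transcripts_and_generate_grams below: nltk's ngrams
# windows the cleaned, split text and SnowballStemmer normalizes each term
# before the grams are re-joined. Example words are illustrative only.
#
# >>> from nltk import ngrams
# >>> from nltk.stem import SnowballStemmer
# >>> stemmer = SnowballStemmer("english")
# >>> grams = [*ngrams("housing levy budgets".split(), 2)]
# >>> grams
# [('housing', 'levy'), ('levy', 'budgets')]
# >>> [" ".join(stemmer.stem(term.lower()) for term in gram) for gram in grams]
# ['hous levi', 'levi budget']
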
@task
def read_transcripts_and_generate_grams(
    event_transcripts: EventTranscripts, n_grams: int, credentials_file: str
) -> list[ContextualizedGram]:
    """
    Parse all documents and create a list of contextualized grams
    for later weighting.

    Parameters
    ----------
    event_transcripts: EventTranscripts
        The EventTranscripts object to parse all transcripts for.
    n_grams: int
        N number of terms to act as a unique entity.
    credentials_file: str
        Path to Google Service Account Credentials JSON file.

    Returns
    -------
    grams: List[ContextualizedGram]
        All grams found in all transcripts provided.
    """
    fs = GCSFileSystem(token=credentials_file)

    # Store all n_gram results
    event_n_grams: list[ContextualizedGram] = []

    # Iter over each transcript
    for transcript_db_file in event_transcripts.transcript_db_files:
        with TemporaryDirectory() as temp_dir:
            temp_dir_path = Path(temp_dir)
            local_transcript_filepath = temp_dir_path / transcript_db_file.name

            # Download transcript
            fs.get(
                rpath=transcript_db_file.uri,
                lpath=str(local_transcript_filepath),
            )

            # Init transcript
            with open(local_transcript_filepath) as open_f:
                transcript = Transcript.from_json(open_f.read())

            # Get cleaned sentences by removing stop words
            cleaned_sentences: list[SentenceManager] = [
                SentenceManager(
                    original_details=sentence,
                    cleaned_text=string_utils.clean_text(
                        sentence.text,
                        clean_stop_words=True,
                        clean_emojis=True,
                    ),
                    n_grams=[],
                )
                for sentence in transcript.sentences
            ]

            # Filter any empty sentences
            cleaned_sentences = [
                sm for sm in cleaned_sentences if len(sm.cleaned_text) > 1
            ]

            # Get all n_grams for each sentence
            for sm in cleaned_sentences:
                sm.n_grams = [*ngrams(sm.cleaned_text.split(), n_grams)]

            # Init stemmer and stem all grams
            stemmer = SnowballStemmer("english")
            for sm in cleaned_sentences:
                for n_gram in sm.n_grams:
                    # Join into a single n gram
                    unstemmed_n_gram = " ".join(n_gram)

                    # Join, lower, and stem the n gram
                    stemmed_n_gram = " ".join(
                        [stemmer.stem(term.lower()) for term in n_gram]
                    )

                    # Get context span
                    # Because ngrams function, cleaning, and split may affect the exact
                    # matchup of the term, use fuzzy diff to find closest
                    closest_term = ""
                    closest_term_score = 0.0
                    for term in sm.original_details.text.split():
                        similarity = rapidfuzz.fuzz.QRatio(term, n_gram[0])
                        if similarity > closest_term_score:
                            closest_term = term
                            closest_term_score = similarity

                    # If the closest term is blank, skip processing this item
                    # Otherwise this throws an error
                    if closest_term == "":
                        continue

                    # Get surrounding terms
                    terms = sm.original_details.text.split()
                    target_term_index = terms.index(closest_term)

                    # Get left and right indices
                    left_i = 0 if target_term_index - 8 < 0 else target_term_index - 8
                    right_i = (
                        None
                        if target_term_index + 7 >= len(terms) - 1
                        else target_term_index + 7
                    )
                    context_span = " ".join(terms[left_i:right_i])

                    # Append ellipsis
                    if left_i != 0:
                        context_span = f"... {context_span}"
                    if right_i is not None:
                        context_span = f"{context_span}..."
                    # Append to event list
                    event_n_grams.append(
                        ContextualizedGram(
                            event_id=event_transcripts.event_id,
                            event_datetime=event_transcripts.event_datetime,
                            unstemmed_gram=unstemmed_n_gram,
                            stemmed_gram=stemmed_n_gram,
                            context_span=context_span,
                        )
                    )

    return event_n_grams


@task
def convert_all_n_grams_to_dataframe(
    all_events_n_grams: list[list[ContextualizedGram]],
) -> pd.DataFrame:
    """Flatten all n grams from all events into one single dataframe."""
    return pd.DataFrame(
        [
            n_gram.to_dict()
            for single_event_n_grams in all_events_n_grams
            for n_gram in single_event_n_grams
        ]
    )


@task
def compute_tfidf(
    n_grams: pd.DataFrame,
    datetime_weighting_days_decay: int = 30,
) -> pd.DataFrame:
    """
    Compute term frequencies, inverse document frequencies, tfidf, and
    weighted tfidf values for each n_gram in the dataframe.
    """
    # Get term frequencies
    n_grams["tf"] = n_grams.groupby(
        ["event_id", "stemmed_gram"]
    ).stemmed_gram.transform("count")

    # Drop duplicates for inverse-document-frequencies
    n_grams = n_grams.drop_duplicates(["event_id", "stemmed_gram"])

    # Get idf
    n = len(n_grams.event_id.unique())
    n_grams["idf"] = (
        n_grams.groupby("stemmed_gram")
        .event_id.transform("count")
        .apply(lambda df: math.log(n / df))
    )

    # Store tfidf
    n_grams["tfidf"] = n_grams.tf * n_grams.idf

    # Drop terms worth nothing
    n_grams = n_grams[n_grams.tfidf != 0]

    # Add datetime weighted tfidf
    utcnow = datetime.utcnow()
    utcnow = pytz.timezone("UTC").localize(utcnow)
    n_grams["datetime_weighted_tfidf"] = n_grams.apply(
        # Unit of decay is in months (`/ 30`)
        # `+ 2` protects against division by zero
        lambda row: row.tfidf
        / math.log(
            ((utcnow - row.event_datetime).days / datetime_weighting_days_decay) + 2
        ),
        axis=1,
    )

    return n_grams


@task
def chunk_index(
    n_grams_df: pd.DataFrame,
    n_grams: int,
    credentials_file: str,
    bucket_name: str,
    ngrams_per_chunk: int = 50_000,
    storage_dir: str | Path = "index/",
    store_remote: bool = False,
) -> None:
    """
    Split the large n_grams dataframe into multiple chunks for batched,
    mapped upload. Optionally store the chunks to remote storage as well.
    """
    # Clean the storage dir
    storage_dir = Path(storage_dir)
    if storage_dir.exists():
        shutil.rmtree(storage_dir)

    # Create storage dir
    storage_dir.mkdir(parents=True)

    # Split single large dataframe into many dataframes
    for chunk_index, chunk_offset in enumerate(
        range(0, n_grams_df.shape[0], ngrams_per_chunk)
    ):
        n_grams_chunk = n_grams_df[chunk_offset : chunk_offset + ngrams_per_chunk]
        save_filename = f"n_gram-{n_grams}--index_chunk-{chunk_index}.parquet"
        local_chunk_path = storage_dir / save_filename
        n_grams_chunk.to_parquet(local_chunk_path)

        # Optional remote storage
        if store_remote:
            fs_functions.upload_file(
                credentials_file=credentials_file,
                bucket=bucket_name,
                filepath=str(local_chunk_path),
                save_name=f"{REMOTE_INDEX_CHUNK_DIR}/{save_filename}",
                overwrite=True,
            )

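# Worked illustration, not part of the original module, of the datetime
# weighting applied in compute_tfidf above. With the default 30 day decay and
# the natural log used by math.log:
#
#   datetime_weighted_tfidf = tfidf / log(days_old / 30 + 2)
#
# >>> import math
# >>> round(1.0 / math.log((30 / 30) + 2), 2)   # event roughly a month old
# 0.91
# >>> round(1.0 / math.log((330 / 30) + 2), 2)  # event roughly eleven months old
# 0.39
#
# A gram with the same raw tfidf therefore contributes less than half as much
# weight when it comes from a year-old event as when it comes from a recent one.
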
def create_event_index_generation_pipeline(
    config: EventIndexPipelineConfig,
    n_grams: int = 1,
    ngrams_per_chunk: int = 50_000,
    store_remote: bool = False,
) -> Flow:
    """
    Create the Prefect Flow object to preview, run, or visualize for indexing
    all events in the database.

    Parameters
    ----------
    config: EventIndexPipelineConfig
        Configuration options for the pipeline.
    n_grams: int
        N number of terms to act as a unique entity. Default: 1
    ngrams_per_chunk: int
        The number of ngrams to store in a single chunk file. Default: 50_000
    store_remote: bool
        Should the generated index chunks be sent to cloud storage.
        Default: False (only store locally)

    Returns
    -------
    flow: Flow
        The constructed CDP Event Index Pipeline as a Prefect Flow.
    """
    with Flow("CDP Event Index Pipeline") as flow:
        # Ensure stopwords are downloaded
        # Do this once to ensure that we don't enter a race condition
        # with multiple workers trying to download / read overtop one another
        # later on.
        try:
            from nltk.corpus import stopwords

            stopwords.words("english")
        except LookupError:
            import nltk

            nltk.download("stopwords")
            log.info("Downloaded nltk stopwords")
            from nltk.corpus import stopwords

            stopwords.words("english")

        # Get all transcripts
        all_transcripts = get_transcripts(
            credentials_file=config.google_credentials_file
        )

        # Select highest confidence transcript for each session
        selected_transcripts = get_highest_confidence_transcript_for_each_session(
            transcripts=all_transcripts
        )

        # Get all transcripts for each event (multi-session events)
        event_transcripts = get_transcripts_per_event(
            transcripts=selected_transcripts
        )

        # Read all transcripts for each event and generate grams
        all_event_transcript_n_grams = read_transcripts_and_generate_grams.map(
            event_transcripts=event_transcripts,
            n_grams=unmapped(n_grams),
            credentials_file=unmapped(config.google_credentials_file),
        )

        # Convert to dataframe for tfidf calc
        all_events_n_grams = convert_all_n_grams_to_dataframe(
            all_events_n_grams=all_event_transcript_n_grams,
        )

        # Weight n grams by tfidf
        scored_n_grams = compute_tfidf(
            n_grams=all_events_n_grams,
            datetime_weighting_days_decay=config.datetime_weighting_days_decay,
        )

        # Create index chunks and store them locally and, optionally, remotely
        chunk_index(
            n_grams_df=scored_n_grams,
            n_grams=n_grams,
            credentials_file=config.google_credentials_file,
            bucket_name=config.validated_gcs_bucket_name,
            ngrams_per_chunk=ngrams_per_chunk,
            storage_dir=config.local_storage_dir,
            store_remote=store_remote,
        )

    return flow

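# A minimal usage sketch, not part of the original module. The constructor
# fields of EventIndexPipelineConfig shown here are assumptions (check
# pipeline_config for the real signature); the Prefect 1.x API imported above
# provides Flow.run() for local execution.

if __name__ == "__main__":
    # Assumed field names, for illustration only
    config = EventIndexPipelineConfig(
        google_credentials_file="path/to/service-account.json",
        gcs_bucket_name="example-cdp-bucket",
    )

    # Build the flow, keeping the generated index chunks local
    flow = create_event_index_generation_pipeline(
        config=config,
        n_grams=1,
        store_remote=False,
    )

    # Run the pipeline in-process
    flow.run()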