Source code for cdp_backend.bin.search_cdp_events

#!/usr/bin/env python

import argparse
import logging
import sys
import traceback
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Dict, List, NamedTuple

import dask.dataframe as dd
import fireo
from google.auth.credentials import AnonymousCredentials
from google.cloud.firestore import Client
from nltk import ngrams
from nltk.stem import SnowballStemmer

from cdp_backend.database import models as db_models
from cdp_backend.utils.string_utils import clean_text

###############################################################################

logging.basicConfig(
    level=logging.INFO,
    format="[%(levelname)4s: %(module)s:%(lineno)4s %(asctime)s] %(message)s",
)
log = logging.getLogger(__name__)

###############################################################################


class SearchSortByField(NamedTuple):
    name: str
    local_field: str
    remote_field: str


DATETIME_WEIGHTED_RELEVANCE = SearchSortByField(
    name="datetime_weighted_relevance",
    local_field="datetime_weighted_tfidf",
    remote_field="datetime_weighted_value",
)
RELEVANCE = SearchSortByField(
    name="relevance",
    local_field="tfidf",
    remote_field="value",
)

###############################################################################


class Args(argparse.Namespace):
    def __init__(self) -> None:
        self.__parse()

    def __parse(self) -> None:
        p = argparse.ArgumentParser(
            prog="search_cdp_events",
            description="Search CDP events given a query.",
        )
        p.add_argument(
            "instance",
            type=str,
            help="Which instance to query.",
        )
        p.add_argument(
            "-q",
            "--query",
            type=str,
            default="residential zoning and housing affordability",
            help="Query to search with.",
        )
        p.add_argument(
            "-s",
            "--sort_by",
            type=str,
            default=DATETIME_WEIGHTED_RELEVANCE.name,
            choices=[DATETIME_WEIGHTED_RELEVANCE.name, RELEVANCE.name],
            help="Choice between datetime-weighted relevance and pure relevance (TF-IDF score).",
        )
        p.add_argument(
            "-f",
            "--first",
            type=int,
            default=4,
            help="Number of results to return.",
        )
        p.add_argument(
            "-l",
            "--local_index_glob",
            type=str,
            default="tfidf-*.parquet",
            help="File glob for the local planned index files to read.",
        )
        p.parse_args(namespace=self)
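

# Hypothetical helper (not part of this module) sketching how the parsed
# args.sort_by string could be resolved to one of the SearchSortByField
# constants defined above; the name `_get_sort_by_field` is an assumption.
def _get_sort_by_field(name: str) -> SearchSortByField:
    for field in (DATETIME_WEIGHTED_RELEVANCE, RELEVANCE):
        if field.name == name:
            return field
    raise ValueError(f"Unknown sort_by option: {name}")
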
###############################################################################


def get_stemmed_grams_from_query(query: str) -> List[str]:
    # Spawn stemmer
    stemmer = SnowballStemmer("english")

    # Create stemmed grams for query
    query_terms = clean_text(query, clean_stop_words=True).split()
    stemmed_grams = []
    for n_gram_size in [1, 2, 3]:
        grams = ngrams(query_terms, n_gram_size)
        for gram in grams:
            stemmed_grams.append(
                " ".join(stemmer.stem(term.lower()) for term in gram)
            )

    return stemmed_grams
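

# Usage sketch (illustrative; exact stems depend on NLTK's SnowballStemmer
# and on what clean_text strips from the query):
#
#   get_stemmed_grams_from_query("residential zoning")
#   # -> stemmed unigrams plus the stemmed bigram, roughly
#   #    ["residenti", "zone", "residenti zone"]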


def _query_event_index(
    stemmed_gram: str,
) -> List[db_models.IndexedEventGram]:
    # Filter for stemmed gram
    filtered_set = db_models.IndexedEventGram.collection.filter(
        "stemmed_gram", "==", stemmed_gram
    )
    return list(filtered_set.fetch(limit=int(1e9)))
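

# Sketch (assumption): the remote search path (run_remote_search, defined
# elsewhere in this module) likely fans _query_event_index out over all
# stemmed grams, which would explain the ThreadPoolExecutor import above:
#
#   with ThreadPoolExecutor() as exe:
#       matching_iegs = [
#           ieg
#           for iegs in exe.map(_query_event_index, stemmed_grams)
#           for ieg in iegs
#       ]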


class EventMatch(NamedTuple):
    event: db_models.Event
    pure_relevance: float
    datetime_weighted_relevance: float
    contained_grams: List[str]
    selected_context_span: str
    keywords: List[str]
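

# Illustrative only (the real ranking lives in the search functions defined
# elsewhere in this module): a list of EventMatch results could be ordered
# by the chosen relevance measure before taking the top `first` results:
#
#   matches.sort(key=lambda m: m.datetime_weighted_relevance, reverse=True)
#   top_matches = matches[:first]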


def main() -> None:
    try:
        args = Args()

        # Connect to the database
        fireo.connection(
            client=Client(
                project=args.instance,
                credentials=AnonymousCredentials(),
            )
        )

        run_remote_search(args.query, args.sort_by, args.first)
        run_local_search(args.query, args.local_index_glob, args.sort_by, args.first)
    except Exception as e:
        log.error("=============================================")
        log.error("\n\n" + traceback.format_exc())
        log.error("=============================================")
        log.error("\n\n" + str(e) + "\n")
        log.error("=============================================")
        sys.exit(1)


###############################################################################
# Allow caller to directly run this module (usually in development scenarios)

if __name__ == "__main__":
    main()