src.load
from typing import Optional

import pandas as pd
import streamlit as st
from datasets import Dataset  # type: ignore

from src.data import encode_dataset, get_collator, get_data, predict
from src.model import get_encoder, get_model, get_tokenizer
from src.subpages import Context
from src.utils import align_sample, device, explode_df

# Tokenizer used for the comma-prediction model. Alternatives tried during
# development, kept for reference:
#   "gagan3012/bert-tiny-finetuned-ner", "distilbert-base-german-cased"
_TOKENIZER_NAME = "xlm-roberta-base"


def _load_models_and_tokenizer(
    encoder_model_name: str,
    model_name: str,
    tokenizer_name: Optional[str],
    device: str = "cpu",
):
    """Load the sentence encoder, NER model, and tokenizer.

    Args:
        encoder_model_name (str): Name of the sentence encoder to load.
        model_name (str): Name of the NER model to load.
        tokenizer_name (Optional[str]): Tokenizer name; falls back to
            ``model_name`` when None.
        device (str): Device string passed to the sentence encoder.
            NOTE(review): this parameter shadows ``src.utils.device``
            inside the function body — intentional here, but rename if
            the helper ever needs the module-level ``device``.

    Returns:
        tuple: ``(sentence_encoder, model, tokenizer)``.
    """
    sentence_encoder = get_encoder(encoder_model_name, device=device)
    tokenizer = get_tokenizer(tokenizer_name if tokenizer_name else model_name)
    # The comma-restoration model uses a fixed two-label scheme; every other
    # model infers its labels from the checkpoint config.
    labels = ["O", "B-COMMA"] if "comma" in model_name else None
    model = get_model(model_name, labels=labels)
    return sentence_encoder, model, tokenizer


@st.cache(allow_output_mutation=True)
def load_context(
    encoder_model_name: str,
    model_name: str,
    ds_name: str,
    ds_config_name: str,
    ds_split_name: str,
    split_sample_size: int,
    **kw_args,
) -> Context:
    """Utility method loading (almost) everything we need for the application.
    This exists just because we want to cache the results of this function.

    Args:
        encoder_model_name (str): Name of the sentence encoder to load.
        model_name (str): Name of the NER model to load.
        ds_name (str): Dataset name or path.
        ds_config_name (str): Dataset config name.
        ds_split_name (str): Dataset split name.
        split_sample_size (int): Number of examples to load from the split.

    Returns:
        Context: An object containing everything we need for the application.
    """

    sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
        encoder_model_name=encoder_model_name,
        model_name=model_name,
        # Only the comma model needs the dedicated tokenizer; all others
        # reuse the tokenizer that matches the model checkpoint.
        tokenizer_name=_TOKENIZER_NAME if "comma" in model_name else None,
        device=str(device),
    )
    collator = get_collator(tokenizer)

    # Load data related stuff.
    split: Dataset = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
    tags = split.features["ner_tags"].feature
    split_encoded, word_ids, ids = encode_dataset(split, tokenizer)

    # Transform into dataframe; keep per-sample word ids for later alignment.
    df = predict(split_encoded, model, tokenizer, collator, tags)
    df["word_ids"] = word_ids
    df["ids"] = ids

    # Explode to token level, drop ignored (special/subword) positions,
    # then merge subword pieces back into word-level rows.
    df_tokens = explode_df(df)
    df_tokens_cleaned = df_tokens.query("labels != 'IGN'")
    df_merged = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
    df_tokens_merged = explode_df(df_merged)

    return Context(
        model=model,
        tokenizer=tokenizer,
        sentence_encoder=sentence_encoder,
        df=df,
        df_tokens=df_tokens,
        df_tokens_cleaned=df_tokens_cleaned,
        df_tokens_merged=df_tokens_merged,
        tags=tags,
        labels=tags.names,
        split_sample_size=split_sample_size,
        ds_name=ds_name,
        ds_config_name=ds_config_name,
        ds_split_name=ds_split_name,
        split=split,
    )
Utility method loading (almost) everything we need for the application. This exists just because we want to cache the results of this function.
Args
- encoder_model_name (str): Name of the sentence encoder to load.
- model_name (str): Name of the NER model to load.
- ds_name (str): Dataset name or path.
- ds_config_name (str): Dataset config name.
- ds_split_name (str): Dataset split name.
- split_sample_size (int): Number of examples to load from the split.
Returns
Context: An object containing everything we need for the application.