src.load

from typing import Optional

import pandas as pd
import streamlit as st
from datasets import Dataset  # type: ignore

from src.data import encode_dataset, get_collator, get_data, predict
from src.model import get_encoder, get_model, get_tokenizer
from src.subpages import Context
from src.utils import align_sample, device, explode_df

# Candidate tokenizers; the trailing [0] selects the one currently in use.
_TOKENIZER_NAME = (
    "xlm-roberta-base",
    "gagan3012/bert-tiny-finetuned-ner",
    "distilbert-base-german-cased",
)[0]


def _load_models_and_tokenizer(
    encoder_model_name: str,
    model_name: str,
    tokenizer_name: Optional[str],
    device: str = "cpu",
):
    sentence_encoder = get_encoder(encoder_model_name, device=device)
    tokenizer = get_tokenizer(tokenizer_name if tokenizer_name else model_name)
    # Comma-restoration models use a fixed two-label scheme; for all other
    # models the label set is taken from the checkpoint itself.
    labels = "O B-COMMA".split() if "comma" in model_name else None
    model = get_model(model_name, labels=labels)
    return sentence_encoder, model, tokenizer


@st.cache(allow_output_mutation=True)
def load_context(
    encoder_model_name: str,
    model_name: str,
    ds_name: str,
    ds_config_name: str,
    ds_split_name: str,
    split_sample_size: int,
    **kw_args,
) -> Context:
    """Load (almost) everything the application needs.

    This function exists mainly so that Streamlit can cache its results.

    Args:
        encoder_model_name (str): Name of the sentence encoder to load.
        model_name (str): Name of the NER model to load.
        ds_name (str): Dataset name or path.
        ds_config_name (str): Dataset config name.
        ds_split_name (str): Dataset split name.
        split_sample_size (int): Number of examples to load from the split.

    Returns:
        Context: An object bundling everything the application needs.
    """

    sentence_encoder, model, tokenizer = _load_models_and_tokenizer(
        encoder_model_name=encoder_model_name,
        model_name=model_name,
        tokenizer_name=_TOKENIZER_NAME if "comma" in model_name else None,
        device=str(device),
    )
    collator = get_collator(tokenizer)

    # load the dataset split and encode it with the tokenizer
    split: Dataset = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
    tags = split.features["ner_tags"].feature
    split_encoded, word_ids, ids = encode_dataset(split, tokenizer)

    # run the model and collect its predictions into a dataframe
    df = predict(split_encoded, model, tokenizer, collator, tags)
    df["word_ids"] = word_ids
    df["ids"] = ids

    # explode into per-token rows, drop ignored (IGN) tokens, and build a
    # word-aligned view of the predictions
    df_tokens = explode_df(df)
    df_tokens_cleaned = df_tokens.query("labels != 'IGN'")
    df_merged = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
    df_tokens_merged = explode_df(df_merged)

    return Context(
        model=model,
        tokenizer=tokenizer,
        sentence_encoder=sentence_encoder,
        df=df,
        df_tokens=df_tokens,
        df_tokens_cleaned=df_tokens_cleaned,
        df_tokens_merged=df_tokens_merged,
        tags=tags,
        labels=tags.names,
        split_sample_size=split_sample_size,
        ds_name=ds_name,
        ds_config_name=ds_config_name,
        ds_split_name=ds_split_name,
        split=split,
    )
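
For reference, a minimal usage sketch from a Streamlit page. The argument values below are placeholders, not values prescribed by this module (any HuggingFace dataset with a `ner_tags` feature and a compatible token-classification model should work), and it assumes `Context` exposes its fields as attributes, consistent with how it is constructed from keyword arguments above:

from src.load import load_context

# Placeholder values; substitute the dataset/model pair your app actually uses.
context = load_context(
    encoder_model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_name="gagan3012/bert-tiny-finetuned-ner",
    ds_name="conll2003",
    ds_config_name="conll2003",
    ds_split_name="validation",
    split_sample_size=100,
)
print(f"{len(context.df)} examples, labels: {context.labels}")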
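
One caveat worth noting: `st.cache(allow_output_mutation=True)` is deprecated as of Streamlit 1.18 in favor of `st.cache_data` and `st.cache_resource`. If the app ever moves to a newer Streamlit, the decorator swap would look like this (a sketch only; the function body stays exactly as in the listing above):

# st.cache_resource is the recommended cache for unhashable global
# resources such as models, tokenizers, and dataset handles.
@st.cache_resource
def load_context(
    encoder_model_name: str,
    model_name: str,
    ds_name: str,
    ds_config_name: str,
    ds_split_name: str,
    split_sample_size: int,
    **kw_args,
) -> Context:
    ...  # body unchanged from the listing above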