src.utils

  1from pathlib import Path
  2
  3import matplotlib as matplotlib
  4import matplotlib.cm as cm
  5import pandas as pd
  6import streamlit as st
  7import tokenizers
  8import torch
  9import torch.nn.functional as F
 10from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode
 11
 12PROJ = Path(__file__).parent
 13
 14tokenizer_hash_funcs = {
 15    tokenizers.Tokenizer: lambda _: None,
 16    tokenizers.AddedToken: lambda _: None,
 17}
 18# device = torch.device("cuda" if torch.cuda.is_available() else "cpu" if torch.has_mps else "cpu")
 19device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 20
 21classmap = {
 22    "O": "O",
 23    "PER": "🙎",
 24    "person": "🙎",
 25    "LOC": "🌎",
 26    "location": "🌎",
 27    "ORG": "🏤",
 28    "corporation": "🏤",
 29    "product": "📱",
 30    "creative": "🎷",
 31    "MISC": "🎷",
 32}
 33
 34
 35def aggrid_interactive_table(df: pd.DataFrame) -> dict:
 36    """Creates an st-aggrid interactive table based on a dataframe.
 37
 38    Args:
 39        df (pd.DataFrame]): Source dataframe
 40    Returns:
 41        dict: The selected row
 42    """
 43    options = GridOptionsBuilder.from_dataframe(
 44        df, enableRowGroup=True, enableValue=True, enablePivot=True
 45    )
 46
 47    options.configure_side_bar()
 48    # options.configure_default_column(cellRenderer=JsCode('''function(params) {return '<a href="#samples-loss">'+params.value+'</a>'}'''))
 49
 50    options.configure_selection("single")
 51    selection = AgGrid(
 52        df,
 53        enable_enterprise_modules=True,
 54        gridOptions=options.build(),
 55        theme="light",
 56        update_mode=GridUpdateMode.NO_UPDATE,
 57        allow_unsafe_jscode=True,
 58    )
 59
 60    return selection
 61
 62
 63def explode_df(df: pd.DataFrame) -> pd.DataFrame:
 64    """Takes a dataframe and explodes all the fields."""
 65
 66    df_tokens = df.apply(pd.Series.explode)
 67    if "losses" in df.columns:
 68        df_tokens["losses"] = df_tokens["losses"].astype(float)
 69    return df_tokens  # type: ignore
 70
 71
 72def align_sample(row: pd.Series):
 73    """Uses word_ids to align all lists in a sample."""
 74
 75    columns = row.axes[0].to_list()
 76    indices = [i for i, id in enumerate(row.word_ids) if id >= 0 and id != row.word_ids[i - 1]]
 77
 78    out = {}
 79
 80    tokens = []
 81    for i, tok in enumerate(row.tokens):
 82        if row.word_ids[i] == -1:
 83            continue
 84
 85        if row.word_ids[i] != row.word_ids[i - 1]:
 86            tokens.append(tok.lstrip("▁").lstrip("##").rstrip("@@"))
 87        else:
 88            tokens[-1] += tok.lstrip("▁").lstrip("##").rstrip("@@")
 89    out["tokens"] = tokens
 90
 91    if "preds" in columns:
 92        out["preds"] = [row.preds[i] for i in indices]
 93
 94    if "labels" in columns:
 95        out["labels"] = [row.labels[i] for i in indices]
 96
 97    if "losses" in columns:
 98        out["losses"] = [row.losses[i] for i in indices]
 99
100    if "probs" in columns:
101        out["probs"] = [row.probs[i] for i in indices]
102
103    if "hidden_states" in columns:
104        out["hidden_states"] = [row.hidden_states[i] for i in indices]
105
106    if "ids" in columns:
107        out["ids"] = row.ids
108
109    assert len(tokens) == len(out["preds"]), (tokens, row.tokens)
110
111    return out
112
113
# NOTE(review): newer Streamlit releases replace st.cache with
# st.cache_data / st.cache_resource — confirm against the pinned version.
@st.cache(
    allow_output_mutation=True,
    hash_funcs=tokenizer_hash_funcs,
)
def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
    """Tags a given text and creates an (exploded) DataFrame with the predicted labels and scores.

    Args:
        text (str): The text to be processed
        tokenizer: Tokenizer to use
        model: Model to use
        device (torch.device): The device we want pytorch to use for its calculations.

    Returns:
        pd.DataFrame: A data frame holding the tagged text, one row per word,
            with the columns ``tokens``, ``preds``, ``probs`` and ``hidden_states``.
    """
    # Tokenize once: tokens, word ids and input ids all come from the same encoding.
    tokenized = tokenizer(text, return_tensors="pt")
    tokens = tokenized.tokens()
    # Positions without a word id are marked with -1.
    word_ids = [w if w is not None else -1 for w in tokenized.word_ids()]
    input_ids = tokenized.input_ids.to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)

    pred_ids = torch.argmax(outputs.logits, dim=2)[0].cpu().tolist()
    preds = [model.config.id2label[p] for p in pred_ids]
    # Last hidden-state tensor of the first (only) sequence.
    hidden_states = outputs.hidden_states[-1][0].detach().cpu().numpy()
    # hidden_states = np.mean([hidden_states, outputs.hidden_states[0][0].detach().cpu().numpy()], axis=0)

    # NOTE(review): despite the column name this is not a probability — it is
    # the floor of 1 / (smallest class probability) per token. Confirm intent.
    probs = 1 // (
        torch.min(F.softmax(outputs.logits, dim=-1), dim=-1).values[0].detach().cpu().numpy()
    )

    df = pd.DataFrame(
        [[tokens, word_ids, preds, probs, hidden_states]],
        columns="tokens word_ids preds probs hidden_states".split(),
    )
    merged_df = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
    return explode_df(merged_df).reset_index(drop=True)
151
152
def get_bg_color(label: str):
    """Retrieves a label's color from the session state.

    Args:
        label (str): Label whose color is stored under the session-state key
            ``color_<label>``.

    Returns:
        The value stored under that key (the rest of this module treats it as
        a HEX color string).
    """
    return st.session_state[f"color_{label}"]
156
157
def get_fg_color(bg_color_hex: str) -> str:
    """Chooses the proper (foreground) text color (black/white) for a given background color, maximizing contrast.

    Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/

    Args:
        bg_color_hex (str): The background color given as a HEX string
            ("#rrggbb"). The leading "#" is optional and the shorthand
            "#rgb" form is accepted as well.

    Returns:
        str: Either "black" or "white".

    Raises:
        ValueError: If the string does not contain enough valid HEX digits.
    """
    digits = bg_color_hex[1:] if bg_color_hex.startswith("#") else bg_color_hex
    if len(digits) == 3:
        # Expand shorthand: "fa0" -> "ffaa00".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError(f"Expected a HEX color like '#rrggbb', got {bg_color_hex!r}")

    r, g, b = (int(digits[start : start + 2], 16) for start in (0, 2, 4))
    # YIQ-weighted brightness of the background; bright backgrounds get black text.
    yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
    return "black" if (yiq >= 128) else "white"
174
175
def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
    """Returns the dataframe transposed, with its rows renumbered from 0.

    Despite the name, no styling is applied: the per-cell coloring this
    function was meant to add is disabled, so only the reshaping remains.

    Args:
        df (pd.DataFrame): The dataframe to reshape.

    Returns:
        pd.DataFrame: The transposed dataframe; its columns are 0..n-1 and its
            index holds the original column names.
    """
    # drop=True discards the old index whatever its name is, instead of
    # relying on it being materialized as a column called "index".
    return df.reset_index(drop=True).T
198
199
def htmlify_labeled_example(example: pd.DataFrame) -> str:
    """Builds an HTML (string) representation of a single example.

    Tokens whose label and prediction are both "O" become plain spans. Every
    other token gets a bordered span in the color of its predicted class,
    prefixed with the class symbol unless the prediction is an "I-" tag. When
    the prediction differs from the label, a second span showing the true
    class follows. Loss values are currently not rendered.

    Args:
        example (pd.DataFrame): The example to process; each row must provide
            ``tokens``, ``preds`` and ``labels``.

    Returns:
        str: An HTML string representation of a single example.

    Raises:
        AssertionError: If a row carries the label "IGN".
    """
    from html import escape  # function-scope import keeps this block self-contained

    def class_of(tag: str) -> str:
        # "B-PER" -> "PER"; tags without a dash count as "O".
        return tag.split("-")[1] if "-" in tag else "O"

    def color_of(tag: str) -> str:
        return get_bg_color(tag.split("-")[1]) if "-" in tag else "#000000"

    spans = []

    for _, row in example.iterrows():
        # Token text is arbitrary input, so escape it before embedding it in markup.
        token = escape(str(row.tokens))

        if row.labels == row.preds == "O":
            spans.append(f"<span>{token}</span>")
            continue
        if row.labels == "IGN":
            # Explicit raise: a bare assert would vanish under ``python -O``.
            raise AssertionError("Rows labeled 'IGN' must be removed before rendering.")

        pred_class = class_of(row.preds)
        color = color_of(row.preds)

        if row.preds == row.labels:
            correct = ""
        else:
            label_class = class_of(row.labels)
            true_color = color_of(row.labels)
            true_font_color = get_fg_color(true_color) if true_color else "white"
            # Classes without a symbol fall back to their raw name.
            true_icon = classmap.get(label_class, label_class)
            correct = f"<span title='{escape(row.labels)}' style='background-color: {true_color}; opacity: 1; color: {true_font_color}; padding: 0 5px; border: 1px solid black; min-width: 30px'>{true_icon}</span>"

        show_icon = pred_class != "O" and row.preds[:2] != "I-"
        pred_icon = classmap.get(pred_class, pred_class) if show_icon else ""
        spans.append(
            f"<span style='border: 1px solid black; color: {color}; padding: 0 5px;' title='{escape(row.preds)}'>{pred_icon + ' '}{token}</span>{correct}"
        )

    return " ".join(spans)
247
248
def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
    """Turns a value into a color using a color map.

    Args:
        value (float): The value to map; its absolute value is used.
        cmap_name: Name of the Matplotlib color map.
        vmin: Value mapped to the lower end of the color map.
        vmax: Value mapped to the upper end of the color map.

    Returns:
        str: The color as a HEX string.
    """
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    # Prefer the colormap registry; ``cm.get_cmap`` is no longer available in
    # newer Matplotlib releases. Fall back to it where the registry is missing
    # or when something other than a name is passed.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and isinstance(cmap_name, str):
        cmap = registry[cmap_name]
    else:
        cmap = cm.get_cmap(cmap_name)  # PiYG
    rgba = cmap(norm(abs(value)))
    return matplotlib.colors.rgb2hex(rgba[:3])
def aggrid_interactive_table(df: pandas.core.frame.DataFrame) -> dict:
def aggrid_interactive_table(df: pd.DataFrame) -> dict:
    """Creates an st-aggrid interactive table based on a dataframe.

    The grid options are derived from the dataframe with row grouping, value
    aggregation and pivoting enabled, plus a side bar and single-row selection.

    Args:
        df (pd.DataFrame): Source dataframe

    Returns:
        dict: The value returned by ``AgGrid`` for the rendered grid
            (described by the original author as the selected row).
    """
    options = GridOptionsBuilder.from_dataframe(
        df, enableRowGroup=True, enableValue=True, enablePivot=True
    )

    options.configure_side_bar()
    # options.configure_default_column(cellRenderer=JsCode('''function(params) {return '<a href="#samples-loss">'+params.value+'</a>'}'''))

    options.configure_selection("single")

    return AgGrid(
        df,
        enable_enterprise_modules=True,
        gridOptions=options.build(),
        theme="light",
        update_mode=GridUpdateMode.NO_UPDATE,
        allow_unsafe_jscode=True,
    )

Creates an st-aggrid interactive table based on a dataframe.

Args
  • df (pd.DataFrame): Source dataframe
Returns

dict: The selected row

def explode_df(df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame:
def explode_df(df: pd.DataFrame) -> pd.DataFrame:
    """Takes a dataframe and explodes all the fields.

    Every column is exploded independently with ``pd.Series.explode``, so each
    list element ends up in its own row. If a ``losses`` column is present it
    is cast to float afterwards.

    Args:
        df (pd.DataFrame): Dataframe whose cells hold list-like values.

    Returns:
        pd.DataFrame: The exploded dataframe.
    """
    df_tokens = df.apply(pd.Series.explode)
    if "losses" in df.columns:
        # Exploding leaves an object column behind; make it numeric again.
        df_tokens["losses"] = df_tokens["losses"].astype(float)
    return df_tokens  # type: ignore

Takes a dataframe and explodes all the fields.

def align_sample(row: pandas.core.series.Series)
 73def align_sample(row: pd.Series):
 74    """Uses word_ids to align all lists in a sample."""
 75
 76    columns = row.axes[0].to_list()
 77    indices = [i for i, id in enumerate(row.word_ids) if id >= 0 and id != row.word_ids[i - 1]]
 78
 79    out = {}
 80
 81    tokens = []
 82    for i, tok in enumerate(row.tokens):
 83        if row.word_ids[i] == -1:
 84            continue
 85
 86        if row.word_ids[i] != row.word_ids[i - 1]:
 87            tokens.append(tok.lstrip("▁").lstrip("##").rstrip("@@"))
 88        else:
 89            tokens[-1] += tok.lstrip("▁").lstrip("##").rstrip("@@")
 90    out["tokens"] = tokens
 91
 92    if "preds" in columns:
 93        out["preds"] = [row.preds[i] for i in indices]
 94
 95    if "labels" in columns:
 96        out["labels"] = [row.labels[i] for i in indices]
 97
 98    if "losses" in columns:
 99        out["losses"] = [row.losses[i] for i in indices]
100
101    if "probs" in columns:
102        out["probs"] = [row.probs[i] for i in indices]
103
104    if "hidden_states" in columns:
105        out["hidden_states"] = [row.hidden_states[i] for i in indices]
106
107    if "ids" in columns:
108        out["ids"] = row.ids
109
110    assert len(tokens) == len(out["preds"]), (tokens, row.tokens)
111
112    return out

Uses word_ids to align all lists in a sample.

@st.cache(allow_output_mutation=True, hash_funcs=tokenizer_hash_funcs)
def tag_text( text: str, tokenizer, model, device: torch.device) -> pandas.core.frame.DataFrame:
# NOTE(review): newer Streamlit releases replace st.cache with
# st.cache_data / st.cache_resource — confirm against the pinned version.
@st.cache(
    allow_output_mutation=True,
    hash_funcs=tokenizer_hash_funcs,
)
def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
    """Tags a given text and creates an (exploded) DataFrame with the predicted labels and scores.

    Args:
        text (str): The text to be processed
        tokenizer: Tokenizer to use
        model: Model to use
        device (torch.device): The device we want pytorch to use for its calculations.

    Returns:
        pd.DataFrame: A data frame holding the tagged text, one row per word,
            with the columns ``tokens``, ``preds``, ``probs`` and ``hidden_states``.
    """
    # Tokenize once: tokens, word ids and input ids all come from the same encoding.
    tokenized = tokenizer(text, return_tensors="pt")
    tokens = tokenized.tokens()
    # Positions without a word id are marked with -1.
    word_ids = [w if w is not None else -1 for w in tokenized.word_ids()]
    input_ids = tokenized.input_ids.to(device)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)

    pred_ids = torch.argmax(outputs.logits, dim=2)[0].cpu().tolist()
    preds = [model.config.id2label[p] for p in pred_ids]
    # Last hidden-state tensor of the first (only) sequence.
    hidden_states = outputs.hidden_states[-1][0].detach().cpu().numpy()
    # hidden_states = np.mean([hidden_states, outputs.hidden_states[0][0].detach().cpu().numpy()], axis=0)

    # NOTE(review): despite the column name this is not a probability — it is
    # the floor of 1 / (smallest class probability) per token. Confirm intent.
    probs = 1 // (
        torch.min(F.softmax(outputs.logits, dim=-1), dim=-1).values[0].detach().cpu().numpy()
    )

    df = pd.DataFrame(
        [[tokens, word_ids, preds, probs, hidden_states]],
        columns="tokens word_ids preds probs hidden_states".split(),
    )
    merged_df = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
    return explode_df(merged_df).reset_index(drop=True)

Tags a given text and creates an (exploded) DataFrame with the predicted labels and probabilities.

Args
  • text (str): The text to be processed
  • tokenizer: Tokenizer to use
  • model: Model to use
  • device (torch.device): The device we want pytorch to use for its calculations.
Returns

pd.DataFrame: A data frame holding the tagged text.

def get_bg_color(label: str)
def get_bg_color(label: str):
    """Retrieves a label's color from the session state.

    Args:
        label (str): Label whose color is stored under the session-state key
            ``color_<label>``.

    Returns:
        The value stored under that key (the rest of this module treats it as
        a HEX color string).
    """
    return st.session_state[f"color_{label}"]

Retrieves a label's color from the session state.

def get_fg_color(bg_color_hex: str) -> str:
def get_fg_color(bg_color_hex: str) -> str:
    """Chooses the proper (foreground) text color (black/white) for a given background color, maximizing contrast.

    Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/

    Args:
        bg_color_hex (str): The background color given as a HEX string
            ("#rrggbb"). The leading "#" is optional and the shorthand
            "#rgb" form is accepted as well.

    Returns:
        str: Either "black" or "white".

    Raises:
        ValueError: If the string does not contain enough valid HEX digits.
    """
    digits = bg_color_hex[1:] if bg_color_hex.startswith("#") else bg_color_hex
    if len(digits) == 3:
        # Expand shorthand: "fa0" -> "ffaa00".
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) < 6:
        raise ValueError(f"Expected a HEX color like '#rrggbb', got {bg_color_hex!r}")

    r, g, b = (int(digits[start : start + 2], 16) for start in (0, 2, 4))
    # YIQ-weighted brightness of the background; bright backgrounds get black text.
    yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
    return "black" if (yiq >= 128) else "white"

Chooses the proper (foreground) text color (black/white) for a given background color, maximizing contrast.

Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/

Args
  • bg_color_hex (str): The background color given as a HEX string.
Returns

str: Either "black" or "white".

def colorize_classes(df: pandas.core.frame.DataFrame) -> pandas.core.frame.DataFrame:
def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
    """Returns the dataframe transposed, with its rows renumbered from 0.

    Despite the name, no styling is applied: the per-cell coloring this
    function was meant to add is disabled, so only the reshaping remains.

    Args:
        df (pd.DataFrame): The dataframe to reshape.

    Returns:
        pd.DataFrame: The transposed dataframe; its columns are 0..n-1 and its
            index holds the original column names.
    """
    # drop=True discards the old index whatever its name is, instead of
    # relying on it being materialized as a column called "index".
    return df.reset_index(drop=True).T

Colorizes the errors in the dataframe.

def htmlify_labeled_example(example: pandas.core.frame.DataFrame) -> str:
def htmlify_labeled_example(example: pd.DataFrame) -> str:
    """Builds an HTML (string) representation of a single example.

    Tokens whose label and prediction are both "O" become plain spans. Every
    other token gets a bordered span in the color of its predicted class,
    prefixed with the class symbol unless the prediction is an "I-" tag. When
    the prediction differs from the label, a second span showing the true
    class follows. Loss values are currently not rendered.

    Args:
        example (pd.DataFrame): The example to process; each row must provide
            ``tokens``, ``preds`` and ``labels``.

    Returns:
        str: An HTML string representation of a single example.

    Raises:
        AssertionError: If a row carries the label "IGN".
    """
    from html import escape  # function-scope import keeps this block self-contained

    def class_of(tag: str) -> str:
        # "B-PER" -> "PER"; tags without a dash count as "O".
        return tag.split("-")[1] if "-" in tag else "O"

    def color_of(tag: str) -> str:
        return get_bg_color(tag.split("-")[1]) if "-" in tag else "#000000"

    spans = []

    for _, row in example.iterrows():
        # Token text is arbitrary input, so escape it before embedding it in markup.
        token = escape(str(row.tokens))

        if row.labels == row.preds == "O":
            spans.append(f"<span>{token}</span>")
            continue
        if row.labels == "IGN":
            # Explicit raise: a bare assert would vanish under ``python -O``.
            raise AssertionError("Rows labeled 'IGN' must be removed before rendering.")

        pred_class = class_of(row.preds)
        color = color_of(row.preds)

        if row.preds == row.labels:
            correct = ""
        else:
            label_class = class_of(row.labels)
            true_color = color_of(row.labels)
            true_font_color = get_fg_color(true_color) if true_color else "white"
            # Classes without a symbol fall back to their raw name.
            true_icon = classmap.get(label_class, label_class)
            correct = f"<span title='{escape(row.labels)}' style='background-color: {true_color}; opacity: 1; color: {true_font_color}; padding: 0 5px; border: 1px solid black; min-width: 30px'>{true_icon}</span>"

        show_icon = pred_class != "O" and row.preds[:2] != "I-"
        pred_icon = classmap.get(pred_class, pred_class) if show_icon else ""
        spans.append(
            f"<span style='border: 1px solid black; color: {color}; padding: 0 5px;' title='{escape(row.preds)}'>{pred_icon + ' '}{token}</span>{correct}"
        )

    return " ".join(spans)

Builds an HTML (string) representation of a single example.

Args
  • example (pd.DataFrame): The example to process.
Returns

str: An HTML string representation of a single example.

def color_map_color(value: float, cmap_name='Set1', vmin=0, vmax=1) -> str:
def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
    """Turns a value into a color using a color map.

    Args:
        value (float): The value to map; its absolute value is used.
        cmap_name: Name of the Matplotlib color map.
        vmin: Value mapped to the lower end of the color map.
        vmax: Value mapped to the upper end of the color map.

    Returns:
        str: The color as a HEX string.
    """
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    # Prefer the colormap registry; ``cm.get_cmap`` is no longer available in
    # newer Matplotlib releases. Fall back to it where the registry is missing
    # or when something other than a name is passed.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None and isinstance(cmap_name, str):
        cmap = registry[cmap_name]
    else:
        cmap = cm.get_cmap(cmap_name)  # PiYG
    rgba = cmap(norm(abs(value)))
    return matplotlib.colors.rgb2hex(rgba[:3])

Turns a value into a color using a color map.