src.utils
from html import escape
from pathlib import Path

import matplotlib
import matplotlib.cm as cm
import matplotlib.colors
import pandas as pd
import streamlit as st
import tokenizers
import torch
import torch.nn.functional as F
from st_aggrid import AgGrid, GridOptionsBuilder, GridUpdateMode

# Directory containing this module.
PROJ = Path(__file__).parent

# Hash tokenizer objects to a constant so that st.cache leaves them out of the
# cache key (presumably because they cannot be hashed directly).
tokenizer_hash_funcs = {
    tokenizers.Tokenizer: lambda _: None,
    tokenizers.AddedToken: lambda _: None,
}

# Default device: CUDA when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Entity class name -> icon shown in the HTML views; "O" maps to itself.
classmap = {
    "O": "O",
    "PER": "🙎",
    "person": "🙎",
    "LOC": "🌎",
    "location": "🌎",
    "ORG": "🏤",
    "corporation": "🏤",
    "product": "📱",
    "creative": "🎷",
    "MISC": "🎷",
}


def aggrid_interactive_table(df: pd.DataFrame) -> dict:
    """Creates an st-aggrid interactive table based on a dataframe.

    Row grouping, value aggregation and pivoting are enabled, the side bar is
    shown and a single row can be selected.

    Args:
        df (pd.DataFrame): Source dataframe

    Returns:
        dict: The response returned by ``AgGrid`` (presumably holding the
            selected row -- TODO confirm against the st-aggrid version in use).
    """
    options = GridOptionsBuilder.from_dataframe(
        df, enableRowGroup=True, enableValue=True, enablePivot=True
    )
    options.configure_side_bar()
    options.configure_selection("single")

    # NOTE(review): no JsCode is passed to the grid, so allow_unsafe_jscode may
    # be unnecessary.
    selection = AgGrid(
        df,
        enable_enterprise_modules=True,
        gridOptions=options.build(),
        theme="light",
        update_mode=GridUpdateMode.NO_UPDATE,
        allow_unsafe_jscode=True,
    )
    return selection


def explode_df(df: pd.DataFrame) -> pd.DataFrame:
    """Takes a dataframe and explodes all the fields.

    Every list-valued column is exploded so each element gets its own row. If a
    ``losses`` column exists it is cast to float afterwards.
    """
    df_tokens = df.apply(pd.Series.explode)
    if "losses" in df.columns:
        df_tokens["losses"] = df_tokens["losses"].astype(float)
    return df_tokens  # type: ignore


def align_sample(row: pd.Series):
    """Uses word_ids to align all lists in a sample.

    Sub-word tokens with the same consecutive word id are merged into one word,
    with the sub-word markers ``▁``, ``##`` and ``@@`` stripped. Tokens whose
    word id is negative (``-1`` marks tokens without a word) are dropped. For
    each per-token list column that is present (``preds``, ``labels``,
    ``losses``, ``probs``, ``hidden_states``), only the value at the first
    sub-token of each word is kept. ``ids`` is passed through unchanged.

    Args:
        row (pd.Series): One sample; must contain ``tokens`` and ``word_ids``.

    Returns:
        dict: Column name -> aligned list (``ids`` is copied as-is).
    """
    columns = row.axes[0].to_list()
    word_ids = row.word_ids

    def starts_word(i):
        # Position 0 has no predecessor; checking it explicitly avoids
        # comparing against word_ids[-1] (the last token) via wraparound.
        return word_ids[i] >= 0 and (i == 0 or word_ids[i] != word_ids[i - 1])

    indices = [i for i in range(len(word_ids)) if starts_word(i)]

    tokens = []
    for i, tok in enumerate(row.tokens):
        if word_ids[i] < 0:
            continue
        piece = tok.lstrip("▁").lstrip("##").rstrip("@@")
        if starts_word(i):
            tokens.append(piece)
        else:
            tokens[-1] += piece

    out = {"tokens": tokens}
    for name in ("preds", "labels", "losses", "probs", "hidden_states"):
        if name in columns:
            values = row[name]
            out[name] = [values[i] for i in indices]
    if "ids" in columns:
        out["ids"] = row.ids

    if "preds" in out:
        assert len(tokens) == len(out["preds"]), (tokens, row.tokens)
    return out


@st.cache(
    allow_output_mutation=True,
    hash_funcs=tokenizer_hash_funcs,
)
def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
    """Tags a given text and creates an (exploded) DataFrame with the predicted labels and probabilities.

    Args:
        text (str): The text to be processed
        tokenizer: Tokenizer to use
        model: Model to use
        device (torch.device): The device we want pytorch to use for its calculations.

    Returns:
        pd.DataFrame: A data frame holding the tagged text, one row per word.
    """
    # Tokenize once; the encoding provides the token strings, the
    # token -> word mapping and the model input.
    tokenized = tokenizer(text, return_tensors="pt")
    tokens = tokenized.tokens()
    # Tokens without a word (word id None) are marked with -1 for align_sample.
    word_ids = [w if w is not None else -1 for w in tokenized.word_ids()]
    input_ids = tokenized.input_ids.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)

    # The batch holds a single sequence, so [0] selects it.
    pred_ids = torch.argmax(outputs.logits, dim=2)
    preds = [model.config.id2label[p] for p in pred_ids[0].cpu().numpy()]
    hidden_states = outputs.hidden_states[-1][0].detach().cpu().numpy()

    # 1 minus the smallest class probability per token (stays within [0, 1];
    # floor division here would produce integers >= 1).
    probs = 1 - (
        torch.min(F.softmax(outputs.logits, dim=-1), dim=-1).values[0].detach().cpu().numpy()
    )

    df = pd.DataFrame(
        [[tokens, word_ids, preds, probs, hidden_states]],
        columns="tokens word_ids preds probs hidden_states".split(),
    )
    merged_df = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
    return explode_df(merged_df).reset_index(drop=True)


def get_bg_color(label: str):
    """Retrieves a label's color from the session state (key ``color_<label>``)."""
    return st.session_state[f"color_{label}"]


def get_fg_color(bg_color_hex: str) -> str:
    """Chooses the proper (foreground) text color (black/white) for a given background color, maximizing contrast.

    Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/

    Args:
        bg_color_hex (str): The background color given as a hex string ("#rrggbb").

    Returns:
        str: Either "black" or "white".
    """
    r = int(bg_color_hex[1:3], 16)
    g = int(bg_color_hex[3:5], 16)
    b = int(bg_color_hex[5:7], 16)
    yiq = ((r * 299) + (g * 587) + (b * 114)) / 1000
    return "black" if (yiq >= 128) else "white"


def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
    """Prepares a token dataframe for the error view.

    The index is replaced by a fresh 0..n-1 range and the frame is transposed,
    so each token becomes a column. No color styling is applied.
    """
    return df.reset_index(drop=True).T


def htmlify_labeled_example(example: pd.DataFrame) -> str:
    """Builds an HTML (string) representation of a single example.

    Each row of ``example`` is one token and must provide ``tokens``, ``preds``
    and ``labels``. Tokens whose label and prediction are both ``"O"`` become a
    plain ``<span>``. Every other token becomes a bordered span colored after
    the predicted class. A wrong prediction is followed by a badge showing the
    gold class. Token text and label strings are HTML-escaped.

    Args:
        example (pd.DataFrame): The example to process.

    Returns:
        str: An HTML string representation of a single example (spans joined by spaces).

    Raises:
        ValueError: If a token carries the ``"IGN"`` label.
    """
    parts = []
    for _, row in example.iterrows():
        token = escape(str(row.tokens))

        if row.labels == row.preds == "O":
            parts.append(f"<span>{token}</span>")
            continue
        if row.labels == "IGN":
            # Explicit raise instead of `assert`, which `python -O` strips.
            raise ValueError(f"Unexpected 'IGN' label for token {row.tokens!r}")

        pred_class = row.preds.split("-")[1] if "-" in row.preds else "O"
        label_class = row.labels.split("-")[1] if "-" in row.labels else "O"
        color = get_bg_color(pred_class) if "-" in row.preds else "#000000"

        # Badge with the gold class, only for wrong predictions.
        correction = ""
        if row.preds != row.labels:
            true_color = get_bg_color(label_class) if "-" in row.labels else "#000000"
            true_font_color = get_fg_color(true_color)
            correction = (
                f"<span title='{escape(str(row.labels))}' style='background-color: {true_color}; "
                f"opacity: 1; color: {true_font_color}; padding: 0 5px; border: 1px solid black; "
                f"min-width: 30px'>{classmap[label_class]}</span>"
            )

        # No icon for "O" or continuation ("I-") predictions.
        pred_icon = classmap[pred_class] if pred_class != "O" and row.preds[:2] != "I-" else ""
        parts.append(
            f"<span style='border: 1px solid black; color: {color}; padding: 0 5px;' "
            f"title='{escape(str(row.preds))}'>{pred_icon} {token}</span>{correction}"
        )

    return " ".join(parts)


def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
    """Turns a value into a color using a color map.

    Args:
        value (float): Value to map; its absolute value is used.
        cmap_name (str): Name of a registered matplotlib colormap.
        vmin (float): Value mapped to the low end of the colormap.
        vmax (float): Value mapped to the high end of the colormap.

    Returns:
        str: The color as a hex string (alpha dropped).
    """
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    try:
        # Colormap registry (matplotlib >= 3.5); cm.get_cmap was removed in 3.9.
        cmap = matplotlib.colormaps[cmap_name]
    except AttributeError:
        cmap = cm.get_cmap(cmap_name)
    rgba = cmap(norm(abs(value)))
    return matplotlib.colors.rgb2hex(rgba[:3])
def aggrid_interactive_table(df: pd.DataFrame) -> dict:
    """Renders ``df`` as an interactive st-aggrid table.

    Row grouping, value aggregation and pivoting are enabled, the side bar is
    shown and a single row can be selected.

    Args:
        df (pd.DataFrame): Source dataframe

    Returns:
        dict: The response returned by ``AgGrid`` (presumably holding the
            selected row -- TODO confirm against the st-aggrid version in use).
    """
    builder = GridOptionsBuilder.from_dataframe(
        df, enableRowGroup=True, enableValue=True, enablePivot=True
    )
    builder.configure_side_bar()
    builder.configure_selection("single")
    grid_options = builder.build()

    # NOTE(review): no JsCode is passed, so allow_unsafe_jscode may be unnecessary.
    return AgGrid(
        df,
        enable_enterprise_modules=True,
        gridOptions=grid_options,
        theme="light",
        update_mode=GridUpdateMode.NO_UPDATE,
        allow_unsafe_jscode=True,
    )
Creates an st-aggrid interactive table based on a dataframe.
Args
- df (pd.DataFrame): Source dataframe
Returns
dict: The selected row
def explode_df(df: pd.DataFrame) -> pd.DataFrame:
    """Explodes every list-valued column so each list element gets its own row.

    Index labels of the source rows are repeated. When a ``losses`` column is
    present it is cast to float after exploding.
    """
    exploded = df.apply(pd.Series.explode)
    if "losses" not in df.columns:
        return exploded  # type: ignore
    exploded["losses"] = exploded["losses"].astype(float)
    return exploded  # type: ignore
Takes a dataframe and explodes all the fields.
def align_sample(row: pd.Series):
    """Uses word_ids to align all lists in a sample.

    Sub-word tokens with the same consecutive word id are merged into one word,
    with the sub-word markers ``▁``, ``##`` and ``@@`` stripped. Tokens whose
    word id is negative (``-1`` marks tokens without a word) are dropped. For
    each per-token list column that is present (``preds``, ``labels``,
    ``losses``, ``probs``, ``hidden_states``), only the value at the first
    sub-token of each word is kept. ``ids`` is passed through unchanged.

    Args:
        row (pd.Series): One sample; must contain ``tokens`` and ``word_ids``.

    Returns:
        dict: Column name -> aligned list (``ids`` is copied as-is).
    """
    columns = row.axes[0].to_list()
    word_ids = row.word_ids

    def starts_word(i):
        # Position 0 has no predecessor; checking it explicitly avoids
        # comparing against word_ids[-1] (the last token) via wraparound.
        return word_ids[i] >= 0 and (i == 0 or word_ids[i] != word_ids[i - 1])

    indices = [i for i in range(len(word_ids)) if starts_word(i)]

    tokens = []
    for i, tok in enumerate(row.tokens):
        if word_ids[i] < 0:
            continue
        piece = tok.lstrip("▁").lstrip("##").rstrip("@@")
        if starts_word(i):
            tokens.append(piece)
        else:
            tokens[-1] += piece

    out = {"tokens": tokens}
    for name in ("preds", "labels", "losses", "probs", "hidden_states"):
        if name in columns:
            values = row[name]
            out[name] = [values[i] for i in indices]
    if "ids" in columns:
        out["ids"] = row.ids

    # Sanity check only when predictions exist (previously a KeyError otherwise).
    if "preds" in out:
        assert len(tokens) == len(out["preds"]), (tokens, row.tokens)
    return out
Uses word_ids to align all lists in a sample.
@st.cache(
    allow_output_mutation=True,
    hash_funcs=tokenizer_hash_funcs,
)
def tag_text(text: str, tokenizer, model, device: torch.device) -> pd.DataFrame:
    """Tags a given text and creates an (exploded) DataFrame with the predicted labels and probabilities.

    Args:
        text (str): The text to be processed
        tokenizer: Tokenizer to use
        model: Model to use
        device (torch.device): The device we want pytorch to use for its calculations.

    Returns:
        pd.DataFrame: A data frame holding the tagged text, one row per word.
    """
    # Tokenize once; the encoding provides the token strings, the
    # token -> word mapping and the model input.
    tokenized = tokenizer(text, return_tensors="pt")
    tokens = tokenized.tokens()
    # Tokens without a word (word id None) are marked with -1 for align_sample.
    word_ids = [w if w is not None else -1 for w in tokenized.word_ids()]
    input_ids = tokenized.input_ids.to(device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(input_ids, output_hidden_states=True)

    # The batch holds a single sequence, so [0] selects it.
    pred_ids = torch.argmax(outputs.logits, dim=2)
    preds = [model.config.id2label[p] for p in pred_ids[0].cpu().numpy()]
    hidden_states = outputs.hidden_states[-1][0].detach().cpu().numpy()

    # 1 minus the smallest class probability per token (stays within [0, 1];
    # floor division here would produce integers >= 1).
    probs = 1 - (
        torch.min(F.softmax(outputs.logits, dim=-1), dim=-1).values[0].detach().cpu().numpy()
    )

    df = pd.DataFrame(
        [[tokens, word_ids, preds, probs, hidden_states]],
        columns="tokens word_ids preds probs hidden_states".split(),
    )
    merged_df = pd.DataFrame(df.apply(align_sample, axis=1).tolist())
    return explode_df(merged_df).reset_index(drop=True)
Tags a given text and creates an (exploded) DataFrame with the predicted labels and probabilities.
Args
- text (str): The text to be processed
- tokenizer: Tokenizer to use
- model: Model to use
- device (torch.device): The device PyTorch should use for its calculations.
Returns
pd.DataFrame: A data frame holding the tagged text.
def get_bg_color(label: str):
    """Looks up the background color configured for ``label``.

    The value is read from Streamlit's session state under the key ``color_<label>``.
    """
    state_key = f"color_{label}"
    return st.session_state[state_key]
Retrieves a label's color from the session state.
def get_fg_color(bg_color_hex: str) -> str:
    """Picks black or white text, whichever contrasts better with a background color.

    Uses the YIQ brightness formula, adapted from
    https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/

    Args:
        bg_color_hex (str): The background color as a hex string of the form "#rrggbb".

    Returns:
        str: Either "black" or "white".
    """
    red, green, blue = (int(bg_color_hex[start:start + 2], 16) for start in (1, 3, 5))
    brightness = (299 * red + 587 * green + 114 * blue) / 1000
    if brightness < 128:
        return "white"
    return "black"
Chooses the proper (foreground) text color (black/white) for a given background color, maximizing contrast.
Adapted from https://gomakethings.com/dynamically-changing-the-text-color-based-on-background-color-contrast-with-vanilla-js/
Args
- bg_color_hex (str): The background color given as a hex string ("#rrggbb").
Returns
str: Either "black" or "white".
def colorize_classes(df: pd.DataFrame) -> pd.DataFrame:
    """Prepares a token dataframe for the error view.

    The index is replaced by a fresh 0..n-1 range and the frame is transposed,
    so each token becomes a column. No color styling is applied.

    Args:
        df (pd.DataFrame): Token-level dataframe.

    Returns:
        pd.DataFrame: The re-indexed, transposed dataframe.
    """
    # drop=True discards the old index directly, so this also works when the
    # index has a name (reset_index() + drop(columns=["index"]) raised KeyError).
    return df.reset_index(drop=True).T
Colorizes the errors in the dataframe. (The styling step is currently disabled, so the dataframe is only re-indexed and transposed.)
def htmlify_labeled_example(example: pd.DataFrame) -> str:
    """Builds an HTML (string) representation of a single example.

    Each row of ``example`` is one token and must provide ``tokens``, ``preds``
    and ``labels``. Tokens whose label and prediction are both ``"O"`` become a
    plain ``<span>``. Every other token becomes a bordered span colored after
    the predicted class. A wrong prediction is followed by a badge showing the
    gold class. Token text and label strings are HTML-escaped.

    Args:
        example (pd.DataFrame): The example to process.

    Returns:
        str: An HTML string representation of a single example (spans joined by spaces).

    Raises:
        ValueError: If a token carries the ``"IGN"`` label.
    """
    from html import escape  # local import keeps this function self-contained

    parts = []
    for _, row in example.iterrows():
        token = escape(str(row.tokens))

        if row.labels == row.preds == "O":
            parts.append(f"<span>{token}</span>")
            continue
        if row.labels == "IGN":
            # Explicit raise instead of `assert`, which `python -O` strips.
            raise ValueError(f"Unexpected 'IGN' label for token {row.tokens!r}")

        pred_class = row.preds.split("-")[1] if "-" in row.preds else "O"
        label_class = row.labels.split("-")[1] if "-" in row.labels else "O"
        color = get_bg_color(pred_class) if "-" in row.preds else "#000000"

        # Badge with the gold class, only for wrong predictions.
        correction = ""
        if row.preds != row.labels:
            true_color = get_bg_color(label_class) if "-" in row.labels else "#000000"
            true_font_color = get_fg_color(true_color)
            correction = (
                f"<span title='{escape(str(row.labels))}' style='background-color: {true_color}; "
                f"opacity: 1; color: {true_font_color}; padding: 0 5px; border: 1px solid black; "
                f"min-width: 30px'>{classmap[label_class]}</span>"
            )

        # No icon for "O" or continuation ("I-") predictions.
        pred_icon = classmap[pred_class] if pred_class != "O" and row.preds[:2] != "I-" else ""
        parts.append(
            f"<span style='border: 1px solid black; color: {color}; padding: 0 5px;' "
            f"title='{escape(str(row.preds))}'>{pred_icon} {token}</span>{correction}"
        )

    return " ".join(parts)
Builds an HTML (string) representation of a single example.
Args
- example (pd.DataFrame): The example to process.
Returns
str: An HTML string representation of a single example.
def color_map_color(value: float, cmap_name="Set1", vmin=0, vmax=1) -> str:
    """Turns a value into a color using a color map.

    Args:
        value (float): Value to map; its absolute value is used.
        cmap_name (str): Name of a registered matplotlib colormap.
        vmin (float): Value mapped to the low end of the colormap.
        vmax (float): Value mapped to the high end of the colormap.

    Returns:
        str: The color as a hex string (alpha dropped).
    """
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    try:
        # Colormap registry (matplotlib >= 3.5); cm.get_cmap was deprecated in
        # 3.7 and removed in 3.9.
        cmap = matplotlib.colormaps[cmap_name]
    except AttributeError:
        # Older matplotlib without the registry attribute.
        cmap = cm.get_cmap(cmap_name)
    rgba = cmap(norm(abs(value)))
    return matplotlib.colors.rgb2hex(rgba[:3])
Turns a value into a color using a color map.