src.subpages.lossy_samples
Show every example sorted by loss (descending) for close inspection.
"""Show every example sorted by loss (descending) for close inspection."""
import pandas as pd
import streamlit as st

from src.subpages.page import Context, Page
from src.utils import (
    colorize_classes,
    get_bg_color,
    get_fg_color,
    htmlify_labeled_example,
)


class LossySamplesPage(Page):
    """Page that lists samples in order of descending ``total_loss``.

    Only samples whose ``total_loss`` exceeds 0.5 are listed.
    """

    name = "Samples by Loss"
    icon = "sort-numeric-down-alt"

    def get_widget_defaults(self):
        """Return the default value for each widget key used by this page."""
        return {
            "skip_correct": True,
            "samples_by_loss_show_df": True,
        }

    def render(self, context: Context):
        """Render the explanation, the filter checkboxes and one entry per sample.

        Args:
            context: Page context. This method reads ``context.df`` (sorted and
                filtered on its ``total_loss`` column) and
                ``context.df_tokens_merged`` (looked up by sample index).
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Show every example sorted by loss (descending) for close inspection.")
            st.write(
                "The **dataframe** is mostly self-explanatory. The cells are color-coded by label, a lighter color signifies a continuation label. Cells in the loss row are filled red from left to right relative to the top loss."
            )
            st.write(
                "The **numbers to the left**: Top (black background) are sample number (listed here) and sample index (from the dataset). Below on yellow background is the total loss for the given sample."
            )
            st.write(
                "The **annotated sample**: Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of token also showing the label's icon. If (and only if) the prediction is wrong, a small little box after the entity (token) contains the correct target class, with a background color corresponding to that class."
            )

        st.subheader("💥 Samples ⬇loss")
        skip_correct = st.checkbox("Skip correct examples", value=True, key="skip_correct")
        show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")

        # Hide table headers and tighten the cells of the HTML tables written below.
        st.write(
            """<style>
thead {
    display: none;
}
td {
    white-space: nowrap;
    padding: 0 5px !important;
}
</style>""",
            unsafe_allow_html=True,
        )

        top_indices = (
            context.df.sort_values(by="total_loss", ascending=False)
            .query("total_loss > 0.5")
            .index
        )

        # Defined once, outside the loop: neither the styling callback nor the
        # row selector depends on the current sample.
        def colorize_col(col):
            """Return one CSS string per cell; only "labels"/"preds" are styled."""
            if col.name not in ("labels", "preds"):
                return [""] * len(col)
            styles = []
            for v in col.values:
                parts = v.split("-")
                # Tags without a "-" get a white background.
                bg = get_bg_color(parts[1]) if len(parts) > 1 else "#ffffff"
                fg = get_fg_color(bg)
                # Full opacity for "B-..." tags and "O"; everything else is dimmed.
                op = "1" if parts[0] == "B" or v == "O" else "0.5"
                styles.append(f"background-color: {bg}; color: {fg}; opacity: {op};")
            return styles

        losses_slice = pd.IndexSlice["losses", :]

        cnt = 0
        for idx in top_indices:
            sample = context.df_tokens_merged.loc[idx]

            # `.loc` returned a single row (a Series) instead of a frame; such
            # samples are skipped. NOTE(review): presumably one-token samples.
            if isinstance(sample, pd.Series):
                continue

            if skip_correct and not (sample.labels != sample.preds).any():
                continue

            if show_df:
                df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
                # The frame is transposed, so axis=1 hands each original
                # column (now a row) to the styling callback.
                styler = (
                    df.T.style.apply(colorize_col, axis=1)
                    .bar(subset=losses_slice, axis=1)
                    .format(precision=3)
                )
                st.write(styler.to_html(), unsafe_allow_html=True)
                st.write("")

            # Narrow column for the counter/loss box, a thin spacer, then the sample.
            col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])

            cnt += 1
            counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
            loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
            col1.write(f"{counter}{loss}", unsafe_allow_html=True)
            col1.write("")

            col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
class LossySamplesPage(Page):
    """Page that lists samples in order of descending ``total_loss``.

    Only samples whose ``total_loss`` exceeds 0.5 are listed.
    """

    name = "Samples by Loss"
    icon = "sort-numeric-down-alt"

    def get_widget_defaults(self):
        """Return the default value for each widget key used by this page."""
        return {
            "skip_correct": True,
            "samples_by_loss_show_df": True,
        }

    def render(self, context: Context):
        """Render the explanation, the filter checkboxes and one entry per sample.

        Args:
            context: Page context. This method reads ``context.df`` (sorted and
                filtered on its ``total_loss`` column) and
                ``context.df_tokens_merged`` (looked up by sample index).
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Show every example sorted by loss (descending) for close inspection.")
            st.write(
                "The **dataframe** is mostly self-explanatory. The cells are color-coded by label, a lighter color signifies a continuation label. Cells in the loss row are filled red from left to right relative to the top loss."
            )
            st.write(
                "The **numbers to the left**: Top (black background) are sample number (listed here) and sample index (from the dataset). Below on yellow background is the total loss for the given sample."
            )
            st.write(
                "The **annotated sample**: Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of token also showing the label's icon. If (and only if) the prediction is wrong, a small little box after the entity (token) contains the correct target class, with a background color corresponding to that class."
            )

        st.subheader("💥 Samples ⬇loss")
        skip_correct = st.checkbox("Skip correct examples", value=True, key="skip_correct")
        show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")

        # Hide table headers and tighten the cells of the HTML tables written below.
        st.write(
            """<style>
thead {
    display: none;
}
td {
    white-space: nowrap;
    padding: 0 5px !important;
}
</style>""",
            unsafe_allow_html=True,
        )

        top_indices = (
            context.df.sort_values(by="total_loss", ascending=False)
            .query("total_loss > 0.5")
            .index
        )

        # Defined once, outside the loop: neither the styling callback nor the
        # row selector depends on the current sample.
        def colorize_col(col):
            """Return one CSS string per cell; only "labels"/"preds" are styled."""
            if col.name not in ("labels", "preds"):
                return [""] * len(col)
            styles = []
            for v in col.values:
                parts = v.split("-")
                # Tags without a "-" get a white background.
                bg = get_bg_color(parts[1]) if len(parts) > 1 else "#ffffff"
                fg = get_fg_color(bg)
                # Full opacity for "B-..." tags and "O"; everything else is dimmed.
                op = "1" if parts[0] == "B" or v == "O" else "0.5"
                styles.append(f"background-color: {bg}; color: {fg}; opacity: {op};")
            return styles

        losses_slice = pd.IndexSlice["losses", :]

        cnt = 0
        for idx in top_indices:
            sample = context.df_tokens_merged.loc[idx]

            # `.loc` returned a single row (a Series) instead of a frame; such
            # samples are skipped. NOTE(review): presumably one-token samples.
            if isinstance(sample, pd.Series):
                continue

            if skip_correct and not (sample.labels != sample.preds).any():
                continue

            if show_df:
                df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
                # The frame is transposed, so axis=1 hands each original
                # column (now a row) to the styling callback.
                styler = (
                    df.T.style.apply(colorize_col, axis=1)
                    .bar(subset=losses_slice, axis=1)
                    .format(precision=3)
                )
                st.write(styler.to_html(), unsafe_allow_html=True)
                st.write("")

            # Narrow column for the counter/loss box, a thin spacer, then the sample.
            col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])

            cnt += 1
            counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
            loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
            col1.write(f"{counter}{loss}", unsafe_allow_html=True)
            col1.write("")

            col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
Base class for all pages.
def
get_widget_defaults(self)
def get_widget_defaults(self):
    """Return the default value for each widget key used by this page.

    Returns:
        dict: Maps the widget keys ``"skip_correct"`` and
        ``"samples_by_loss_show_df"`` to their default value (both ``True``).
    """
    defaults = dict(
        skip_correct=True,
        samples_by_loss_show_df=True,
    )
    return defaults
This function holds the default settings for all the page's widgets.
Returns
dict: A dictionary of widget defaults, where the keys are the widget names and the values are their defaults.
def render(self, context: Context):
    """Render the explanation, the filter checkboxes and one entry per sample.

    Samples are taken from ``context.df`` in order of descending
    ``total_loss``; only those with ``total_loss`` above 0.5 are listed.

    Args:
        context: Page context. This method reads ``context.df`` (sorted and
            filtered on its ``total_loss`` column) and
            ``context.df_tokens_merged`` (looked up by sample index).
    """
    st.title(self.name)
    with st.expander("💡", expanded=True):
        st.write("Show every example sorted by loss (descending) for close inspection.")
        st.write(
            "The **dataframe** is mostly self-explanatory. The cells are color-coded by label, a lighter color signifies a continuation label. Cells in the loss row are filled red from left to right relative to the top loss."
        )
        st.write(
            "The **numbers to the left**: Top (black background) are sample number (listed here) and sample index (from the dataset). Below on yellow background is the total loss for the given sample."
        )
        st.write(
            "The **annotated sample**: Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of token also showing the label's icon. If (and only if) the prediction is wrong, a small little box after the entity (token) contains the correct target class, with a background color corresponding to that class."
        )

    st.subheader("💥 Samples ⬇loss")
    skip_correct = st.checkbox("Skip correct examples", value=True, key="skip_correct")
    show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")

    # Hide table headers and tighten the cells of the HTML tables written below.
    st.write(
        """<style>
thead {
    display: none;
}
td {
    white-space: nowrap;
    padding: 0 5px !important;
}
</style>""",
        unsafe_allow_html=True,
    )

    top_indices = (
        context.df.sort_values(by="total_loss", ascending=False)
        .query("total_loss > 0.5")
        .index
    )

    # Defined once, outside the loop: neither the styling callback nor the
    # row selector depends on the current sample.
    def colorize_col(col):
        """Return one CSS string per cell; only "labels"/"preds" are styled."""
        if col.name not in ("labels", "preds"):
            return [""] * len(col)
        styles = []
        for v in col.values:
            parts = v.split("-")
            # Tags without a "-" get a white background.
            bg = get_bg_color(parts[1]) if len(parts) > 1 else "#ffffff"
            fg = get_fg_color(bg)
            # Full opacity for "B-..." tags and "O"; everything else is dimmed.
            op = "1" if parts[0] == "B" or v == "O" else "0.5"
            styles.append(f"background-color: {bg}; color: {fg}; opacity: {op};")
        return styles

    losses_slice = pd.IndexSlice["losses", :]

    cnt = 0
    for idx in top_indices:
        sample = context.df_tokens_merged.loc[idx]

        # `.loc` returned a single row (a Series) instead of a frame; such
        # samples are skipped. NOTE(review): presumably one-token samples.
        if isinstance(sample, pd.Series):
            continue

        if skip_correct and not (sample.labels != sample.preds).any():
            continue

        if show_df:
            df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
            # The frame is transposed, so axis=1 hands each original
            # column (now a row) to the styling callback.
            styler = (
                df.T.style.apply(colorize_col, axis=1)
                .bar(subset=losses_slice, axis=1)
                .format(precision=3)
            )
            st.write(styler.to_html(), unsafe_allow_html=True)
            st.write("")

        # Narrow column for the counter/loss box, a thin spacer, then the sample.
        col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])

        cnt += 1
        counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
        loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
        col1.write(f"{counter}{loss}", unsafe_allow_html=True)
        col1.write("")

        col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
This function renders the page.