src.subpages.lossy_samples

Show every example with a total loss above 0.5, sorted by loss (descending), for close inspection.

  1"""Show every example sorted by loss (descending) for close inspection."""
  2import pandas as pd
  3import streamlit as st
  4
  5from src.subpages.page import Context, Page
  6from src.utils import (
  7    colorize_classes,
  8    get_bg_color,
  9    get_fg_color,
 10    htmlify_labeled_example,
 11)
 12
 13
 14class LossySamplesPage(Page):
 15    name = "Samples by Loss"
 16    icon = "sort-numeric-down-alt"
 17
 18    def get_widget_defaults(self):
 19        return {
 20            "skip_correct": True,
 21            "samples_by_loss_show_df": True,
 22        }
 23
 24    def render(self, context: Context):
 25        st.title(self.name)
 26        with st.expander("💡", expanded=True):
 27            st.write("Show every example sorted by loss (descending) for close inspection.")
 28            st.write(
 29                "The **dataframe** is mostly self-explanatory. The cells are color-coded by label, a lighter color signifies a continuation label. Cells in the loss row are filled red from left to right relative to the top loss."
 30            )
 31            st.write(
 32                "The **numbers to the left**: Top (black background) are sample number (listed here) and sample index (from the dataset). Below on yellow background is the total loss for the given sample."
 33            )
 34            st.write(
 35                "The **annotated sample**: Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of token also showing the label's icon. If (and only if) the prediction is wrong, a small little box after the entity (token) contains the correct target class, with a background color corresponding to that class."
 36            )
 37
 38        st.subheader("💥 Samples ⬇loss")
 39        skip_correct = st.checkbox("Skip correct examples", value=True, key="skip_correct")
 40        show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")
 41
 42        st.write(
 43            """<style>
 44thead {
 45    display: none;
 46}
 47td {
 48    white-space: nowrap;
 49    padding: 0 5px !important;
 50}
 51</style>""",
 52            unsafe_allow_html=True,
 53        )
 54
 55        top_indices = (
 56            context.df.sort_values(by="total_loss", ascending=False)
 57            .query("total_loss > 0.5")
 58            .index
 59        )
 60
 61        cnt = 0
 62        for idx in top_indices:
 63            sample = context.df_tokens_merged.loc[idx]
 64
 65            if isinstance(sample, pd.Series):
 66                continue
 67
 68            if skip_correct and sum(sample.labels != sample.preds) == 0:
 69                continue
 70
 71            if show_df:
 72
 73                def colorize_col(col):
 74                    if col.name == "labels" or col.name == "preds":
 75                        bgs = []
 76                        fgs = []
 77                        ops = []
 78                        for v in col.values:
 79                            bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
 80                            fgs.append(get_fg_color(bgs[-1]))
 81                            ops.append("1" if v.split("-")[0] == "B" or v == "O" else "0.5")
 82                        return [
 83                            f"background-color: {bg}; color: {fg}; opacity: {op};"
 84                            for bg, fg, op in zip(bgs, fgs, ops)
 85                        ]
 86                    return [""] * len(col)
 87
 88                df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
 89                losses_slice = pd.IndexSlice["losses", :]
 90                # x = df.T.astype(str)
 91                # st.dataframe(x)
 92                # st.dataframe(x.loc[losses_slice])
 93                styler = (
 94                    df.T.style.apply(colorize_col, axis=1)
 95                    .bar(subset=losses_slice, axis=1)
 96                    .format(precision=3)
 97                )
 98                # styler.data = styler.data.astype(str)
 99                st.write(styler.to_html(), unsafe_allow_html=True)
100                st.write("")
101                # st.dataframe(colorize_classes(sample.drop("hidden_states", axis=1)))#.bar(subset='losses'))  # type: ignore
102                # st.write(
103                #     colorize_errors(sample.round(3).drop("hidden_states", axis=1).astype(str))
104                # )
105
106            col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])
107
108            cnt += 1
109            counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
110            loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
111            col1.write(f"{counter}{loss}", unsafe_allow_html=True)
112            col1.write("")
113
114            col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
115            # st.write(f"[{i};{idx}] " + htmlify_corr_sample(sample), unsafe_allow_html=True)
class LossySamplesPage(src.subpages.page.Page):
 15class LossySamplesPage(Page):
 16    name = "Samples by Loss"
 17    icon = "sort-numeric-down-alt"
 18
 19    def get_widget_defaults(self):
 20        return {
 21            "skip_correct": True,
 22            "samples_by_loss_show_df": True,
 23        }
 24
 25    def render(self, context: Context):
 26        st.title(self.name)
 27        with st.expander("💡", expanded=True):
 28            st.write("Show every example sorted by loss (descending) for close inspection.")
 29            st.write(
 30                "The **dataframe** is mostly self-explanatory. The cells are color-coded by label, a lighter color signifies a continuation label. Cells in the loss row are filled red from left to right relative to the top loss."
 31            )
 32            st.write(
 33                "The **numbers to the left**: Top (black background) are sample number (listed here) and sample index (from the dataset). Below on yellow background is the total loss for the given sample."
 34            )
 35            st.write(
 36                "The **annotated sample**: Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of token also showing the label's icon. If (and only if) the prediction is wrong, a small little box after the entity (token) contains the correct target class, with a background color corresponding to that class."
 37            )
 38
 39        st.subheader("💥 Samples ⬇loss")
 40        skip_correct = st.checkbox("Skip correct examples", value=True, key="skip_correct")
 41        show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")
 42
 43        st.write(
 44            """<style>
 45thead {
 46    display: none;
 47}
 48td {
 49    white-space: nowrap;
 50    padding: 0 5px !important;
 51}
 52</style>""",
 53            unsafe_allow_html=True,
 54        )
 55
 56        top_indices = (
 57            context.df.sort_values(by="total_loss", ascending=False)
 58            .query("total_loss > 0.5")
 59            .index
 60        )
 61
 62        cnt = 0
 63        for idx in top_indices:
 64            sample = context.df_tokens_merged.loc[idx]
 65
 66            if isinstance(sample, pd.Series):
 67                continue
 68
 69            if skip_correct and sum(sample.labels != sample.preds) == 0:
 70                continue
 71
 72            if show_df:
 73
 74                def colorize_col(col):
 75                    if col.name == "labels" or col.name == "preds":
 76                        bgs = []
 77                        fgs = []
 78                        ops = []
 79                        for v in col.values:
 80                            bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
 81                            fgs.append(get_fg_color(bgs[-1]))
 82                            ops.append("1" if v.split("-")[0] == "B" or v == "O" else "0.5")
 83                        return [
 84                            f"background-color: {bg}; color: {fg}; opacity: {op};"
 85                            for bg, fg, op in zip(bgs, fgs, ops)
 86                        ]
 87                    return [""] * len(col)
 88
 89                df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
 90                losses_slice = pd.IndexSlice["losses", :]
 91                # x = df.T.astype(str)
 92                # st.dataframe(x)
 93                # st.dataframe(x.loc[losses_slice])
 94                styler = (
 95                    df.T.style.apply(colorize_col, axis=1)
 96                    .bar(subset=losses_slice, axis=1)
 97                    .format(precision=3)
 98                )
 99                # styler.data = styler.data.astype(str)
100                st.write(styler.to_html(), unsafe_allow_html=True)
101                st.write("")
102                # st.dataframe(colorize_classes(sample.drop("hidden_states", axis=1)))#.bar(subset='losses'))  # type: ignore
103                # st.write(
104                #     colorize_errors(sample.round(3).drop("hidden_states", axis=1).astype(str))
105                # )
106
107            col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])
108
109            cnt += 1
110            counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
111            loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
112            col1.write(f"{counter}{loss}", unsafe_allow_html=True)
113            col1.write("")
114
115            col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
116            # st.write(f"[{i};{idx}] " + htmlify_corr_sample(sample), unsafe_allow_html=True)

Page that lists samples sorted by total loss (descending). Subclass of the base `Page` class.

LossySamplesPage()
name: str = 'Samples by Loss'
icon: str = 'sort-numeric-down-alt'
def get_widget_defaults(self)
19    def get_widget_defaults(self):
20        return {
21            "skip_correct": True,
22            "samples_by_loss_show_df": True,
23        }

This function returns the default settings for all of the page's widgets.

Returns

dict: A dictionary of widget defaults, where the keys are the widget keys (as passed to `key=`) and the values are their default values.

def render(self, context: src.subpages.page.Context)
 25    def render(self, context: Context):
 26        st.title(self.name)
 27        with st.expander("💡", expanded=True):
 28            st.write("Show every example sorted by loss (descending) for close inspection.")
 29            st.write(
 30                "The **dataframe** is mostly self-explanatory. The cells are color-coded by label, a lighter color signifies a continuation label. Cells in the loss row are filled red from left to right relative to the top loss."
 31            )
 32            st.write(
 33                "The **numbers to the left**: Top (black background) are sample number (listed here) and sample index (from the dataset). Below on yellow background is the total loss for the given sample."
 34            )
 35            st.write(
 36                "The **annotated sample**: Every predicted entity (every token, really) gets a black border. The text color signifies the predicted label, with the first token of a sequence of token also showing the label's icon. If (and only if) the prediction is wrong, a small little box after the entity (token) contains the correct target class, with a background color corresponding to that class."
 37            )
 38
 39        st.subheader("💥 Samples ⬇loss")
 40        skip_correct = st.checkbox("Skip correct examples", value=True, key="skip_correct")
 41        show_df = st.checkbox("Show dataframes", key="samples_by_loss_show_df")
 42
 43        st.write(
 44            """<style>
 45thead {
 46    display: none;
 47}
 48td {
 49    white-space: nowrap;
 50    padding: 0 5px !important;
 51}
 52</style>""",
 53            unsafe_allow_html=True,
 54        )
 55
 56        top_indices = (
 57            context.df.sort_values(by="total_loss", ascending=False)
 58            .query("total_loss > 0.5")
 59            .index
 60        )
 61
 62        cnt = 0
 63        for idx in top_indices:
 64            sample = context.df_tokens_merged.loc[idx]
 65
 66            if isinstance(sample, pd.Series):
 67                continue
 68
 69            if skip_correct and sum(sample.labels != sample.preds) == 0:
 70                continue
 71
 72            if show_df:
 73
 74                def colorize_col(col):
 75                    if col.name == "labels" or col.name == "preds":
 76                        bgs = []
 77                        fgs = []
 78                        ops = []
 79                        for v in col.values:
 80                            bgs.append(get_bg_color(v.split("-")[1]) if "-" in v else "#ffffff")
 81                            fgs.append(get_fg_color(bgs[-1]))
 82                            ops.append("1" if v.split("-")[0] == "B" or v == "O" else "0.5")
 83                        return [
 84                            f"background-color: {bg}; color: {fg}; opacity: {op};"
 85                            for bg, fg, op in zip(bgs, fgs, ops)
 86                        ]
 87                    return [""] * len(col)
 88
 89                df = sample.reset_index().drop(["index", "hidden_states", "ids"], axis=1).round(3)
 90                losses_slice = pd.IndexSlice["losses", :]
 91                # x = df.T.astype(str)
 92                # st.dataframe(x)
 93                # st.dataframe(x.loc[losses_slice])
 94                styler = (
 95                    df.T.style.apply(colorize_col, axis=1)
 96                    .bar(subset=losses_slice, axis=1)
 97                    .format(precision=3)
 98                )
 99                # styler.data = styler.data.astype(str)
100                st.write(styler.to_html(), unsafe_allow_html=True)
101                st.write("")
102                # st.dataframe(colorize_classes(sample.drop("hidden_states", axis=1)))#.bar(subset='losses'))  # type: ignore
103                # st.write(
104                #     colorize_errors(sample.round(3).drop("hidden_states", axis=1).astype(str))
105                # )
106
107            col1, _, col2 = st.columns([3.5 / 32, 0.5 / 32, 28 / 32])
108
109            cnt += 1
110            counter = f"<span title='#sample | index' style='display: block; background-color: black; opacity: 1; color: white; padding: 0 5px'>[{cnt} | {idx}]</span>"
111            loss = f"<span title='total loss' style='display: block; background-color: yellow; color: gray; padding: 0 5px;'>𝐿 {sample.losses.sum():.3f}</span>"
112            col1.write(f"{counter}{loss}", unsafe_allow_html=True)
113            col1.write("")
114
115            col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
116            # st.write(f"[{i};{idx}] " + htmlify_corr_sample(sample), unsafe_allow_html=True)

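This function renders the page. It shows an explanatory expander and two option checkboxes ("Skip correct examples", "Show dataframes"). It then lists every sample with a total loss above 0.5 in descending order of loss, optionally preceded by its styled token dataframe.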