src.subpages.losses

Show count, mean and median loss per token and label.

 1"""Show count, mean and median loss per token and label."""
 2import streamlit as st
 3
 4from src.subpages.page import Context, Page
 5from src.utils import AgGrid, aggrid_interactive_table
 6
 7
 8@st.cache
 9def get_loss_by_token(df_tokens):
10    return (
11        df_tokens.groupby("tokens")[["losses"]]
12        .agg(["count", "mean", "median", "sum"])
13        .droplevel(level=0, axis=1)  # Get rid of multi-level columns
14        .sort_values(by="sum", ascending=False)
15        .reset_index()
16    )
17
18
@st.cache
def get_loss_by_label(df_tokens):
    """Summarise the ``losses`` column of ``df_tokens`` for each distinct label.

    The frame is grouped by its ``labels`` column and the ``losses`` column is
    reduced to four statistics: count, mean, median and sum.

    Returns a frame with one row per label, the grouping key restored as a
    regular ``labels`` column, ordered by mean loss (largest first).
    """
    statistics = ["count", "mean", "median", "sum"]
    per_label = df_tokens.groupby("labels")[["losses"]].agg(statistics)
    # Drop the outer "losses" level so the columns are just the statistic names.
    per_label = per_label.droplevel(level=0, axis=1)
    per_label = per_label.sort_values(by="mean", ascending=False)
    return per_label.reset_index()
28
29
class LossesPage(Page):
    """Page that tabulates aggregated loss statistics per token and per label."""

    name = "Loss by Token/Label"
    icon = "sort-alpha-down"

    def render(self, context: Context):
        """Render the title, an explanatory expander and the two loss tables.

        The token table is built from ``context.df_tokens_merged`` while the
        "Merge tokens" checkbox is ticked and from ``context.df_tokens_cleaned``
        otherwise; the label table always uses ``context.df_tokens_cleaned``.
        Both tables are rounded to three decimals before being displayed.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Show count, mean and median loss per token and label.")
            st.write(
                "Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues."
            )

        # The middle column is never written to; it only separates the tables.
        token_column, _, label_column = st.columns([8, 1, 6])

        with token_column:
            st.subheader("💬 Loss by Token")

            # The widget is registered under "merge_tokens"; its value is also
            # copied to the separate "_merge_tokens" session-state entry.
            merge_choice = st.checkbox("Merge tokens", value=True, key="merge_tokens")
            st.session_state["_merge_tokens"] = merge_choice

            if st.session_state["merge_tokens"]:
                token_frame = context.df_tokens_merged
            else:
                token_frame = context.df_tokens_cleaned
            aggrid_interactive_table(get_loss_by_token(token_frame).round(3))

            st.write(
                "_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._"
            )

        with label_column:
            st.subheader("🏷️ Loss by Label")
            label_stats = get_loss_by_label(context.df_tokens_cleaned)
            AgGrid(label_stats.round(3), height=200)
@st.cache
def get_loss_by_token(df_tokens)
 9@st.cache
10def get_loss_by_token(df_tokens):
11    return (
12        df_tokens.groupby("tokens")[["losses"]]
13        .agg(["count", "mean", "median", "sum"])
14        .droplevel(level=0, axis=1)  # Get rid of multi-level columns
15        .sort_values(by="sum", ascending=False)
16        .reset_index()
17    )
@st.cache
def get_loss_by_label(df_tokens)
@st.cache
def get_loss_by_label(df_tokens):
    """Summarise the ``losses`` column of ``df_tokens`` for each distinct label.

    The frame is grouped by its ``labels`` column and the ``losses`` column is
    reduced to four statistics: count, mean, median and sum.

    Returns a frame with one row per label, the grouping key restored as a
    regular ``labels`` column, ordered by mean loss (largest first).
    """
    statistics = ["count", "mean", "median", "sum"]
    per_label = df_tokens.groupby("labels")[["losses"]].agg(statistics)
    # Drop the outer "losses" level so the columns are just the statistic names.
    per_label = per_label.droplevel(level=0, axis=1)
    per_label = per_label.sort_values(by="mean", ascending=False)
    return per_label.reset_index()
class LossesPage(src.subpages.page.Page):
class LossesPage(Page):
    """Page that tabulates aggregated loss statistics per token and per label."""

    name = "Loss by Token/Label"
    icon = "sort-alpha-down"

    def render(self, context: Context):
        """Render the title, an explanatory expander and the two loss tables.

        The token table is built from ``context.df_tokens_merged`` while the
        "Merge tokens" checkbox is ticked and from ``context.df_tokens_cleaned``
        otherwise; the label table always uses ``context.df_tokens_cleaned``.
        Both tables are rounded to three decimals before being displayed.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Show count, mean and median loss per token and label.")
            st.write(
                "Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues."
            )

        # The middle column is never written to; it only separates the tables.
        token_column, _, label_column = st.columns([8, 1, 6])

        with token_column:
            st.subheader("💬 Loss by Token")

            # The widget is registered under "merge_tokens"; its value is also
            # copied to the separate "_merge_tokens" session-state entry.
            merge_choice = st.checkbox("Merge tokens", value=True, key="merge_tokens")
            st.session_state["_merge_tokens"] = merge_choice

            if st.session_state["merge_tokens"]:
                token_frame = context.df_tokens_merged
            else:
                token_frame = context.df_tokens_cleaned
            aggrid_interactive_table(get_loss_by_token(token_frame).round(3))

            st.write(
                "_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._"
            )

        with label_column:
            st.subheader("🏷️ Loss by Label")
            label_stats = get_loss_by_label(context.df_tokens_cleaned)
            AgGrid(label_stats.round(3), height=200)

Base class for all pages.

LossesPage()
name: str = 'Loss by Token/Label'
icon: str = 'sort-alpha-down'
def render(self, context: src.subpages.page.Context)
35    def render(self, context: Context):
36        st.title(self.name)
37        with st.expander("💡", expanded=True):
38            st.write("Show count, mean and median loss per token and label.")
39            st.write(
40                "Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues."
41            )
42
43        col1, _, col2 = st.columns([8, 1, 6])
44
45        with col1:
46            st.subheader("💬 Loss by Token")
47
48            st.session_state["_merge_tokens"] = st.checkbox(
49                "Merge tokens", value=True, key="merge_tokens"
50            )
51            loss_by_token = (
52                get_loss_by_token(context.df_tokens_merged)
53                if st.session_state["merge_tokens"]
54                else get_loss_by_token(context.df_tokens_cleaned)
55            )
56            aggrid_interactive_table(loss_by_token.round(3))
57            # st.subheader("🏷️ Loss by Label")
58            # loss_by_label = get_loss_by_label(df_tokens_cleaned)
59            # st.dataframe(loss_by_label)
60
61            st.write(
62                "_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._"
63            )
64
65        with col2:
66            st.subheader("🏷️ Loss by Label")
67            loss_by_label = get_loss_by_label(context.df_tokens_cleaned)
68            AgGrid(loss_by_label.round(3), height=200)

This function renders the page.