src.subpages.losses
Show count, mean and median loss per token and label.
1"""Show count, mean and median loss per token and label.""" 2import streamlit as st 3 4from src.subpages.page import Context, Page 5from src.utils import AgGrid, aggrid_interactive_table 6 7 8@st.cache 9def get_loss_by_token(df_tokens): 10 return ( 11 df_tokens.groupby("tokens")[["losses"]] 12 .agg(["count", "mean", "median", "sum"]) 13 .droplevel(level=0, axis=1) # Get rid of multi-level columns 14 .sort_values(by="sum", ascending=False) 15 .reset_index() 16 ) 17 18 19@st.cache 20def get_loss_by_label(df_tokens): 21 return ( 22 df_tokens.groupby("labels")[["losses"]] 23 .agg(["count", "mean", "median", "sum"]) 24 .droplevel(level=0, axis=1) 25 .sort_values(by="mean", ascending=False) 26 .reset_index() 27 ) 28 29 30class LossesPage(Page): 31 name = "Loss by Token/Label" 32 icon = "sort-alpha-down" 33 34 def render(self, context: Context): 35 st.title(self.name) 36 with st.expander("💡", expanded=True): 37 st.write("Show count, mean and median loss per token and label.") 38 st.write( 39 "Look out for tokens that have a big gap between mean and median, indicating systematic labeling issues." 40 ) 41 42 col1, _, col2 = st.columns([8, 1, 6]) 43 44 with col1: 45 st.subheader("💬 Loss by Token") 46 47 st.session_state["_merge_tokens"] = st.checkbox( 48 "Merge tokens", value=True, key="merge_tokens" 49 ) 50 loss_by_token = ( 51 get_loss_by_token(context.df_tokens_merged) 52 if st.session_state["merge_tokens"] 53 else get_loss_by_token(context.df_tokens_cleaned) 54 ) 55 aggrid_interactive_table(loss_by_token.round(3)) 56 # st.subheader("🏷️ Loss by Label") 57 # loss_by_label = get_loss_by_label(df_tokens_cleaned) 58 # st.dataframe(loss_by_label) 59 60 st.write( 61 "_Caveat: Even though tokens have contextual representations, we average them to get these summary statistics._" 62 ) 63 64 with col2: 65 st.subheader("🏷️ Loss by Label") 66 loss_by_label = get_loss_by_label(context.df_tokens_cleaned) 67 AgGrid(loss_by_label.round(3), height=200)
@st.cache
def get_loss_by_token(df_tokens)

Group the token-level losses by token and return count, mean, median and sum per token, sorted by summed loss in descending order.
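A minimal sketch of what this aggregation yields, assuming only that the token-level DataFrame has the "tokens" and "losses" columns the groupby relies on and that the project's src package is importable; the values below are made up for illustration.

import pandas as pd

from src.subpages.losses import get_loss_by_token

# Hypothetical token-level data; in the app this comes from the Context.
df_tokens = pd.DataFrame(
    {
        "tokens": ["the", "Paris", "the", "Paris", "in"],
        "losses": [0.01, 2.30, 0.02, 0.10, 0.05],
    }
)

# Result columns: tokens, count, mean, median, sum -- sorted by summed loss,
# so "Paris" (sum 2.40) ends up at the top.
print(get_loss_by_token(df_tokens))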
@st.cache
def get_loss_by_label(df_tokens)

Group the token-level losses by gold label and return count, mean, median and sum per label, sorted by mean loss in descending order.
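The label-level variant can be sketched the same way; it groups by the "labels" column and sorts by mean rather than summed loss, so labels that are rare but consistently hard still surface at the top. The values are again invented for illustration.

import pandas as pd

from src.subpages.losses import get_loss_by_label

df_tokens = pd.DataFrame(
    {
        "labels": ["O", "B-PER", "O", "B-ORG", "B-ORG"],
        "losses": [0.01, 1.80, 0.02, 0.40, 0.60],
    }
)

# Sorted by mean loss: B-PER (1.80), then B-ORG (0.50), then O (0.015).
print(get_loss_by_label(df_tokens))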
class LossesPage(Page)

Class variables: name = "Loss by Token/Label", icon = "sort-alpha-down"
Page showing count, mean and median loss per token and per label. Inherits from Page, the base class for all pages.
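For orientation only, here is a rough reconstruction of the interface this page relies on; the real Page and Context live in src.subpages.page and may differ, so treat the attributes below as assumptions inferred from how LossesPage uses them.

# Hypothetical sketch of the base class, inferred from LossesPage's usage.
class Page:
    name: str = "Page"   # shown as the page title
    icon: str = "file"   # icon name used in the navigation

    def render(self, context: "Context"):
        # Each concrete page draws itself with Streamlit calls.
        raise NotImplementedError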
def render(self, context: Context)
Renders the page: a two-column layout with per-token loss statistics on the left (with a checkbox to merge sub-word tokens) and per-label loss statistics on the right, both displayed as interactive AgGrid tables.
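To try the page outside the full app, render can be driven with a stand-in context; only the df_tokens_merged and df_tokens_cleaned attribute names are taken from the render code above, everything else here (the SimpleNamespace stand-in, the data, the script name) is a made-up example, and the real app constructs a proper Context instead.

from types import SimpleNamespace

import pandas as pd

from src.subpages.losses import LossesPage

# Hypothetical stand-in for Context, providing only what render() reads.
context = SimpleNamespace(
    df_tokens_merged=pd.DataFrame(
        {"tokens": ["Paris", "is"], "labels": ["B-LOC", "O"], "losses": [0.8, 0.01]}
    ),
    df_tokens_cleaned=pd.DataFrame(
        {"tokens": ["Par", "##is", "is"], "labels": ["B-LOC", "B-LOC", "O"], "losses": [0.5, 0.3, 0.01]}
    ),
)

# Run via `streamlit run this_script.py` so the widgets actually appear.
LossesPage().render(context)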