src.subpages.inspect
Inspect your whole dataset, either unfiltered or by id.
"""Inspect your whole dataset, either unfiltered or by id."""
import streamlit as st

from src.subpages.page import Context, Page
from src.utils import aggrid_interactive_table, colorize_classes


class InspectPage(Page):
    """Page that shows the token-level dataframe, for one example or as a whole."""

    name = "Inspect"
    icon = "search"

    def render(self, context: Context) -> None:
        """Render the Inspect page.

        With "Filter by id" checked (the default) a single example is picked
        from a select box and shown as a colorized dataframe; otherwise the
        whole dataframe is shown in an interactive grid.

        Args:
            context: Supplies ``df_tokens``, the dataframe to display. It must
                contain the columns listed in ``columns`` below;
                ``token_type_ids`` is optional.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Inspect your whole dataset, either unfiltered or by id.")

        df = context.df_tokens
        columns = [
            "ids",
            "input_ids",
            "token_type_ids",
            "word_ids",
            "losses",
            "tokens",
            "labels",
            "preds",
            "total_loss",
        ]
        if "token_type_ids" not in df.columns:
            columns.remove("token_type_ids")
        # Selecting the display columns already leaves out everything else
        # (e.g. hidden_states / attention_mask), so no explicit drop is needed
        # and frames that lack those columns no longer raise KeyError.
        df = df[columns]

        if st.checkbox("Filter by id", value=True):
            ids = sorted(map(int, df.ids.unique()))

            # The stored position may stem from an earlier, larger dataset;
            # fall back to the first example rather than handing st.selectbox
            # an out-of-range index.
            next_id = st.session_state.get("next_id", 0)
            if not 0 <= next_id < len(ids):
                next_id = 0

            example_id = st.selectbox("Select an example", ids, index=next_id)
            # The first and last rows of the example are skipped -- presumably
            # the special start/end tokens; TODO confirm.
            df = df[df.ids == str(example_id)][1:-1]
            st.dataframe(colorize_classes(df.round(3).astype(str)))

            # NOTE(review): next_id is read above, before these buttons are
            # evaluated, so a click only changes the selection on the
            # following script run.
            if st.button("Next example"):
                st.session_state.next_id = (ids.index(example_id) + 1) % len(ids)
            if st.button("Previous example"):
                st.session_state.next_id = (ids.index(example_id) - 1) % len(ids)
        else:
            aggrid_interactive_table(df.round(3))
class InspectPage(Page):
    """Page that shows the token-level dataframe, for one example or as a whole."""

    name = "Inspect"
    icon = "search"

    def render(self, context: Context) -> None:
        """Render the Inspect page.

        With "Filter by id" checked (the default) a single example is picked
        from a select box and shown as a colorized dataframe; otherwise the
        whole dataframe is shown in an interactive grid.

        Args:
            context: Supplies ``df_tokens``, the dataframe to display. It must
                contain the columns listed in ``columns`` below;
                ``token_type_ids`` is optional.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("Inspect your whole dataset, either unfiltered or by id.")

        df = context.df_tokens
        columns = [
            "ids",
            "input_ids",
            "token_type_ids",
            "word_ids",
            "losses",
            "tokens",
            "labels",
            "preds",
            "total_loss",
        ]
        if "token_type_ids" not in df.columns:
            columns.remove("token_type_ids")
        # Selecting the display columns already leaves out everything else
        # (e.g. hidden_states / attention_mask), so no explicit drop is needed
        # and frames that lack those columns no longer raise KeyError.
        df = df[columns]

        if st.checkbox("Filter by id", value=True):
            ids = sorted(map(int, df.ids.unique()))

            # The stored position may stem from an earlier, larger dataset;
            # fall back to the first example rather than handing st.selectbox
            # an out-of-range index.
            next_id = st.session_state.get("next_id", 0)
            if not 0 <= next_id < len(ids):
                next_id = 0

            example_id = st.selectbox("Select an example", ids, index=next_id)
            # The first and last rows of the example are skipped -- presumably
            # the special start/end tokens; TODO confirm.
            df = df[df.ids == str(example_id)][1:-1]
            st.dataframe(colorize_classes(df.round(3).astype(str)))

            # NOTE(review): next_id is read above, before these buttons are
            # evaluated, so a click only changes the selection on the
            # following script run.
            if st.button("Next example"):
                st.session_state.next_id = (ids.index(example_id) + 1) % len(ids)
            if st.button("Previous example"):
                st.session_state.next_id = (ids.index(example_id) - 1) % len(ids)
        else:
            aggrid_interactive_table(df.round(3))
Base class for all pages.
def render(self, context: Context) -> None:
    """Render the Inspect page.

    With "Filter by id" checked (the default) a single example is picked
    from a select box and shown as a colorized dataframe; otherwise the
    whole dataframe is shown in an interactive grid.

    Args:
        context: Supplies ``df_tokens``, the dataframe to display. It must
            contain the columns listed in ``columns`` below;
            ``token_type_ids`` is optional.
    """
    st.title(self.name)
    with st.expander("💡", expanded=True):
        st.write("Inspect your whole dataset, either unfiltered or by id.")

    df = context.df_tokens
    columns = [
        "ids",
        "input_ids",
        "token_type_ids",
        "word_ids",
        "losses",
        "tokens",
        "labels",
        "preds",
        "total_loss",
    ]
    if "token_type_ids" not in df.columns:
        columns.remove("token_type_ids")
    # Selecting the display columns already leaves out everything else
    # (e.g. hidden_states / attention_mask), so no explicit drop is needed
    # and frames that lack those columns no longer raise KeyError.
    df = df[columns]

    if st.checkbox("Filter by id", value=True):
        ids = sorted(map(int, df.ids.unique()))

        # The stored position may stem from an earlier, larger dataset;
        # fall back to the first example rather than handing st.selectbox
        # an out-of-range index.
        next_id = st.session_state.get("next_id", 0)
        if not 0 <= next_id < len(ids):
            next_id = 0

        example_id = st.selectbox("Select an example", ids, index=next_id)
        # The first and last rows of the example are skipped -- presumably
        # the special start/end tokens; TODO confirm.
        df = df[df.ids == str(example_id)][1:-1]
        st.dataframe(colorize_classes(df.round(3).astype(str)))

        # NOTE(review): next_id is read above, before these buttons are
        # evaluated, so a click only changes the selection on the
        # following script run.
        if st.button("Next example"):
            st.session_state.next_id = (ids.index(example_id) + 1) % len(ids)
        if st.button("Previous example"):
            st.session_state.next_id = (ids.index(example_id) - 1) % len(ids)
    else:
        aggrid_interactive_table(df.round(3))
This function renders the page.