src.subpages.inspect

Inspect your whole dataset, either unfiltered or by id.

 1"""Inspect your whole dataset, either unfiltered or by id."""
 2import streamlit as st
 3
 4from src.subpages.page import Context, Page
 5from src.utils import aggrid_interactive_table, colorize_classes
 6
 7
 8class InspectPage(Page):
 9    name = "Inspect"
10    icon = "search"
11
12    def render(self, context: Context):
13        st.title(self.name)
14        with st.expander("💡", expanded=True):
15            st.write("Inspect your whole dataset, either unfiltered or by id.")
16
17        df = context.df_tokens
18        cols = (
19            "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
20        )
21        if "token_type_ids" not in df.columns:
22            cols.remove("token_type_ids")
23        df = df.drop("hidden_states", axis=1).drop("attention_mask", axis=1)[cols]
24
25        if st.checkbox("Filter by id", value=True):
26            ids = list(sorted(map(int, df.ids.unique())))
27            next_id = st.session_state.get("next_id", 0)
28
29            example_id = st.selectbox("Select an example", ids, index=next_id)
30            df = df[df.ids == str(example_id)][1:-1]
31            # st.dataframe(colorize_classes(df).format(precision=3).bar(subset="losses"))  # type: ignore
32            st.dataframe(colorize_classes(df.round(3).astype(str)))
33
34            if st.button("Next example"):
35                st.session_state.next_id = (ids.index(example_id) + 1) % len(ids)
36            if st.button("Previous example"):
37                st.session_state.next_id = (ids.index(example_id) - 1) % len(ids)
38        else:
39            aggrid_interactive_table(df.round(3))
class InspectPage(src.subpages.page.Page):
 9class InspectPage(Page):
10    name = "Inspect"
11    icon = "search"
12
13    def render(self, context: Context):
14        st.title(self.name)
15        with st.expander("💡", expanded=True):
16            st.write("Inspect your whole dataset, either unfiltered or by id.")
17
18        df = context.df_tokens
19        cols = (
20            "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
21        )
22        if "token_type_ids" not in df.columns:
23            cols.remove("token_type_ids")
24        df = df.drop("hidden_states", axis=1).drop("attention_mask", axis=1)[cols]
25
26        if st.checkbox("Filter by id", value=True):
27            ids = list(sorted(map(int, df.ids.unique())))
28            next_id = st.session_state.get("next_id", 0)
29
30            example_id = st.selectbox("Select an example", ids, index=next_id)
31            df = df[df.ids == str(example_id)][1:-1]
32            # st.dataframe(colorize_classes(df).format(precision=3).bar(subset="losses"))  # type: ignore
33            st.dataframe(colorize_classes(df.round(3).astype(str)))
34
35            if st.button("Next example"):
36                st.session_state.next_id = (ids.index(example_id) + 1) % len(ids)
37            if st.button("Previous example"):
38                st.session_state.next_id = (ids.index(example_id) - 1) % len(ids)
39        else:
40            aggrid_interactive_table(df.round(3))

Page for browsing the token-level dataset, either one example at a time or unfiltered.

InspectPage()
name: str = 'Inspect'
icon: str = 'search'
def render(self, context: src.subpages.page.Context)
13    def render(self, context: Context):
14        st.title(self.name)
15        with st.expander("💡", expanded=True):
16            st.write("Inspect your whole dataset, either unfiltered or by id.")
17
18        df = context.df_tokens
19        cols = (
20            "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
21        )
22        if "token_type_ids" not in df.columns:
23            cols.remove("token_type_ids")
24        df = df.drop("hidden_states", axis=1).drop("attention_mask", axis=1)[cols]
25
26        if st.checkbox("Filter by id", value=True):
27            ids = list(sorted(map(int, df.ids.unique())))
28            next_id = st.session_state.get("next_id", 0)
29
30            example_id = st.selectbox("Select an example", ids, index=next_id)
31            df = df[df.ids == str(example_id)][1:-1]
32            # st.dataframe(colorize_classes(df).format(precision=3).bar(subset="losses"))  # type: ignore
33            st.dataframe(colorize_classes(df.round(3).astype(str)))
34
35            if st.button("Next example"):
36                st.session_state.next_id = (ids.index(example_id) + 1) % len(ids)
37            if st.button("Previous example"):
38                st.session_state.next_id = (ids.index(example_id) - 1) % len(ids)
39        else:
40            aggrid_interactive_table(df.round(3))

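Render the Inspect page: show the token-level table either for a single example selected by id (with Next/Previous navigation) or unfiltered in an interactive grid.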