src.app

The App module is the main entry point for the application.

Run `streamlit run app.py` to start the app.

  1"""The App module is the main entry point for the application.
  2
  3    Run `streamlit run app.py` to start the app.
  4
  5"""
  6
  7import pandas as pd
  8import streamlit as st
  9from streamlit_option_menu import option_menu
 10
 11from src.load import load_context
 12from src.subpages import (
 13    DebugPage,
 14    FindDuplicatesPage,
 15    HomePage,
 16    LossesPage,
 17    LossySamplesPage,
 18    MetricsPage,
 19    MisclassifiedPage,
 20    Page,
 21    ProbingPage,
 22    RandomSamplesPage,
 23    RawDataPage,
 24)
 25from src.subpages.attention import AttentionPage
 26from src.subpages.hidden_states import HiddenStatesPage
 27from src.subpages.inspect import InspectPage
 28from src.utils import classmap
 29
# Shorthand alias for the Streamlit sidebar container.
sts = st.sidebar
# Page-wide settings: wide layout plus the browser-tab title and icon.
st.set_page_config(
    layout="wide",
    page_title="Error Analysis",
    page_icon="🏷️",
)
 36
 37
 38def _show_menu(pages: list[Page]) -> int:
 39    with st.sidebar:
 40        page_names = [p.name for p in pages]
 41        page_icons = [p.icon for p in pages]
 42        selected_menu_item = st.session_state.active_page = option_menu(
 43            menu_title="ExplaiNER",
 44            options=page_names,
 45            icons=page_icons,
 46            menu_icon="layout-wtf",
 47            default_index=0,
 48        )
 49        return page_names.index(selected_menu_item)
 50    assert False
 51
 52
 53def _initialize_session_state(pages: list[Page]):
 54    if "active_page" not in st.session_state:
 55        for page in pages:
 56            st.session_state.update(**page.get_widget_defaults())
 57    st.session_state.update(st.session_state)
 58
 59
 60def _write_color_legend(context):
 61    def style(x):
 62        return [f"background-color: {rgb}; opacity: 1;" for rgb in colors]
 63
 64    labels = list(set([lbl.split("-")[1] if "-" in lbl else lbl for lbl in context.labels]))
 65    colors = [st.session_state.get(f"color_{lbl}", "#000000") for lbl in labels]
 66
 67    color_legend_df = pd.DataFrame(
 68        [classmap[l] for l in labels], columns=["label"], index=labels
 69    ).T
 70    st.sidebar.write(
 71        color_legend_df.T.style.apply(style, axis=0).set_properties(
 72            **{"color": "white", "text-align": "center"}
 73        )
 74    )
 75
 76
 77def main():
 78    """The main entry point for the application."""
 79    pages: list[Page] = [
 80        HomePage(),
 81        AttentionPage(),
 82        HiddenStatesPage(),
 83        ProbingPage(),
 84        MetricsPage(),
 85        LossySamplesPage(),
 86        LossesPage(),
 87        MisclassifiedPage(),
 88        RandomSamplesPage(),
 89        FindDuplicatesPage(),
 90        InspectPage(),
 91        RawDataPage(),
 92        DebugPage(),
 93    ]
 94
 95    _initialize_session_state(pages)
 96
 97    selected_page_idx = _show_menu(pages)
 98    selected_page = pages[selected_page_idx]
 99
100    if isinstance(selected_page, HomePage):
101        selected_page.render()
102        return
103
104    if "model_name" not in st.session_state:
105        # this can happen if someone loads another page directly (without going through home)
106        st.error("Setup not complete. Please click on 'Home / Setup in left menu bar'")
107        return
108
109    context = load_context(**st.session_state)
110    _write_color_legend(context)
111    selected_page.render(context)
112
113
# Run the app when this file is executed as a script (e.g. `streamlit run app.py`).
if __name__ == "__main__":
    main()
def main()
 78def main():
 79    """The main entry point for the application."""
 80    pages: list[Page] = [
 81        HomePage(),
 82        AttentionPage(),
 83        HiddenStatesPage(),
 84        ProbingPage(),
 85        MetricsPage(),
 86        LossySamplesPage(),
 87        LossesPage(),
 88        MisclassifiedPage(),
 89        RandomSamplesPage(),
 90        FindDuplicatesPage(),
 91        InspectPage(),
 92        RawDataPage(),
 93        DebugPage(),
 94    ]
 95
 96    _initialize_session_state(pages)
 97
 98    selected_page_idx = _show_menu(pages)
 99    selected_page = pages[selected_page_idx]
100
101    if isinstance(selected_page, HomePage):
102        selected_page.render()
103        return
104
105    if "model_name" not in st.session_state:
106        # this can happen if someone loads another page directly (without going through home)
107        st.error("Setup not complete. Please click on 'Home / Setup in left menu bar'")
108        return
109
110    context = load_context(**st.session_state)
111    _write_color_legend(context)
112    selected_page.render(context)

The main entry point for the application.