src.subpages.misclassified

This page contains all misclassified examples and allows filtering by specific error types.

 1"""This page contains all misclassified examples and allows filtering by specific error types."""
 2from collections import defaultdict
 3
 4import pandas as pd
 5import streamlit as st
 6from sklearn.metrics import confusion_matrix
 7
 8from src.subpages.page import Context, Page
 9from src.utils import htmlify_labeled_example
10
11
class MisclassifiedPage(Page):
    """Page that lists misclassified examples, filterable by (true, predicted) class pair."""

    name = "Misclassified"
    icon = "x-octagon"

    def render(self, context: Context) -> None:
        """Render the confusion table, the confusion filter and the matching examples.

        Args:
            context: Page context. This method reads ``context.df_tokens_merged``
                (using its ``labels`` and ``preds`` columns) and ``context.labels``,
                which fixes the row/column order of the confusion matrix.
        """
        import numpy as np

        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "This page contains all misclassified examples and allows filtering by specific error types."
            )

        # Keep every row whose index value has at least one wrong prediction.
        misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
        misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
        cm = confusion_matrix(
            misclassified_samples.labels,
            misclassified_samples.preds,
            labels=context.labels,
        )

        # Build the display table on a plain array we own: blank out zero counts
        # and the diagonal (correct predictions) so only real confusions show.
        # Writing through ``DataFrame.values`` is not guaranteed to reach the frame.
        cells = cm.astype(str)
        cells[cm == 0] = ""
        np.fill_diagonal(cells, "")
        st.dataframe(pd.DataFrame(cells, index=context.labels, columns=context.labels))

        # Map each (true label, predicted label) pair to its non-zero count.
        confusions = defaultdict(int)
        for i, row in enumerate(cm):
            for j, count in enumerate(row):
                if i == j or count == 0:
                    continue
                confusions[(context.labels[i], context.labels[j])] += count

        if not confusions:
            # With no options the radio widget has nothing to select and the
            # filter below would fail on a ``None`` selection.
            st.info("No misclassified examples found.")
            return

        def format_func(item):
            (true_label, pred_label), count = item
            return f"true: {true_label} <> pred: {pred_label} ||| count: {count}"

        conf = st.radio(
            "Filter by Class Confusion",
            options=list(confusions.items()),
            format_func=format_func,
        )
        if conf is None:
            return

        (true_label, pred_label), _ = conf
        # Boolean masks instead of a string-built query, so label text containing
        # quotes or other query syntax cannot break the filter.
        matches = misclassified_samples[
            (misclassified_samples.labels == true_label)
            & (misclassified_samples.preds == pred_label)
        ]
        # Index values may repeat among the matching rows; render each one once.
        for idx in matches.index.unique():
            sample = context.df_tokens_merged.loc[idx]
            st.write(
                htmlify_labeled_example(sample),
                unsafe_allow_html=True,
            )
            st.write("---")
class MisclassifiedPage(src.subpages.page.Page):
class MisclassifiedPage(Page):
    """Page that lists misclassified examples, filterable by (true, predicted) class pair."""

    name = "Misclassified"
    icon = "x-octagon"

    def render(self, context: Context) -> None:
        """Render the confusion table, the confusion filter and the matching examples.

        Args:
            context: Page context. This method reads ``context.df_tokens_merged``
                (using its ``labels`` and ``preds`` columns) and ``context.labels``,
                which fixes the row/column order of the confusion matrix.
        """
        import numpy as np

        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "This page contains all misclassified examples and allows filtering by specific error types."
            )

        # Keep every row whose index value has at least one wrong prediction.
        misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
        misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
        cm = confusion_matrix(
            misclassified_samples.labels,
            misclassified_samples.preds,
            labels=context.labels,
        )

        # Build the display table on a plain array we own: blank out zero counts
        # and the diagonal (correct predictions) so only real confusions show.
        # Writing through ``DataFrame.values`` is not guaranteed to reach the frame.
        cells = cm.astype(str)
        cells[cm == 0] = ""
        np.fill_diagonal(cells, "")
        st.dataframe(pd.DataFrame(cells, index=context.labels, columns=context.labels))

        # Map each (true label, predicted label) pair to its non-zero count.
        confusions = defaultdict(int)
        for i, row in enumerate(cm):
            for j, count in enumerate(row):
                if i == j or count == 0:
                    continue
                confusions[(context.labels[i], context.labels[j])] += count

        if not confusions:
            # With no options the radio widget has nothing to select and the
            # filter below would fail on a ``None`` selection.
            st.info("No misclassified examples found.")
            return

        def format_func(item):
            (true_label, pred_label), count = item
            return f"true: {true_label} <> pred: {pred_label} ||| count: {count}"

        conf = st.radio(
            "Filter by Class Confusion",
            options=list(confusions.items()),
            format_func=format_func,
        )
        if conf is None:
            return

        (true_label, pred_label), _ = conf
        # Boolean masks instead of a string-built query, so label text containing
        # quotes or other query syntax cannot break the filter.
        matches = misclassified_samples[
            (misclassified_samples.labels == true_label)
            & (misclassified_samples.preds == pred_label)
        ]
        # Index values may repeat among the matching rows; render each one once.
        for idx in matches.index.unique():
            sample = context.df_tokens_merged.loc[idx]
            st.write(
                htmlify_labeled_example(sample),
                unsafe_allow_html=True,
            )
            st.write("---")

Base class for all pages.

MisclassifiedPage()
name: str = 'Misclassified'
icon: str = 'x-octagon'
def render(self, context: src.subpages.page.Context)
def render(self, context: Context) -> None:
    """Render the confusion table, the confusion filter and the matching examples.

    Args:
        context: Page context. This method reads ``context.df_tokens_merged``
            (using its ``labels`` and ``preds`` columns) and ``context.labels``,
            which fixes the row/column order of the confusion matrix.
    """
    import numpy as np

    st.title(self.name)
    with st.expander("💡", expanded=True):
        st.write(
            "This page contains all misclassified examples and allows filtering by specific error types."
        )

    # Keep every row whose index value has at least one wrong prediction.
    misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
    misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
    cm = confusion_matrix(
        misclassified_samples.labels,
        misclassified_samples.preds,
        labels=context.labels,
    )

    # Build the display table on a plain array we own: blank out zero counts
    # and the diagonal (correct predictions) so only real confusions show.
    # Writing through ``DataFrame.values`` is not guaranteed to reach the frame.
    cells = cm.astype(str)
    cells[cm == 0] = ""
    np.fill_diagonal(cells, "")
    st.dataframe(pd.DataFrame(cells, index=context.labels, columns=context.labels))

    # Map each (true label, predicted label) pair to its non-zero count.
    confusions = defaultdict(int)
    for i, row in enumerate(cm):
        for j, count in enumerate(row):
            if i == j or count == 0:
                continue
            confusions[(context.labels[i], context.labels[j])] += count

    if not confusions:
        # With no options the radio widget has nothing to select and the
        # filter below would fail on a ``None`` selection.
        st.info("No misclassified examples found.")
        return

    def format_func(item):
        (true_label, pred_label), count = item
        return f"true: {true_label} <> pred: {pred_label} ||| count: {count}"

    conf = st.radio(
        "Filter by Class Confusion",
        options=list(confusions.items()),
        format_func=format_func,
    )
    if conf is None:
        return

    (true_label, pred_label), _ = conf
    # Boolean masks instead of a string-built query, so label text containing
    # quotes or other query syntax cannot break the filter.
    matches = misclassified_samples[
        (misclassified_samples.labels == true_label)
        & (misclassified_samples.preds == pred_label)
    ]
    # Index values may repeat among the matching rows; render each one once.
    for idx in matches.index.unique():
        sample = context.df_tokens_merged.loc[idx]
        st.write(
            htmlify_labeled_example(sample),
            unsafe_allow_html=True,
        )
        st.write("---")

This function renders the page.