src.subpages.misclassified
This page contains all misclassified examples and allows filtering by specific error types.
"""This page contains all misclassified examples and allows filtering by specific error types."""
from collections import defaultdict

import pandas as pd
import streamlit as st
from sklearn.metrics import confusion_matrix

from src.subpages.page import Context, Page
from src.utils import htmlify_labeled_example


class MisclassifiedPage(Page):
    """Page that lists misclassified examples, filterable by class confusion."""

    # Title shown at the top of the page.
    name = "Misclassified"
    # Icon identifier associated with the page.
    icon = "x-octagon"

    def render(self, context: Context):
        """Render the confusion table, the confusion filter and the matching examples.

        Args:
            context: Page context. This method reads ``context.df_tokens_merged``
                (queried on its ``labels`` and ``preds`` columns) and
                ``context.labels`` (the ordered class names used for the
                confusion matrix).
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "This page contains all misclassified examples and allows filtering by specific error types."
            )

        # Collect the index values that own at least one wrong prediction, then
        # pull in *all* rows sharing those index values.
        misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
        misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
        cm = confusion_matrix(
            misclassified_samples.labels,
            misclassified_samples.preds,
            labels=context.labels,
        )

        # Build the display table directly as strings: the diagonal (correct
        # predictions) and zero counts are blanked. This avoids writing through
        # ``DataFrame.values`` (read-only under pandas copy-on-write) and the
        # deprecated ``DataFrame.applymap``.
        cells = [
            ["" if i == j or count == 0 else str(count) for j, count in enumerate(row)]
            for i, row in enumerate(cm)
        ]
        st.dataframe(pd.DataFrame(cells, index=context.labels, columns=context.labels))

        # One radio option per non-empty off-diagonal cell: ((true, pred), count).
        confusions = [
            ((context.labels[i], context.labels[j]), count)
            for i, row in enumerate(cm)
            for j, count in enumerate(row)
            if i != j and count != 0
        ]
        if not confusions:
            # Nothing to select; indexing into the radio result would fail.
            st.info("No class confusions to display.")
            return

        def format_func(item):
            (true_label, pred_label), count = item
            return f"true: {true_label} <> pred: {pred_label} ||| count: {count}"

        conf = st.radio(
            "Filter by Class Confusion",
            options=confusions,
            format_func=format_func,
        )
        (true_label, pred_label), _ = conf

        # A boolean mask instead of an interpolated query string: works for
        # labels containing quotes and for non-string labels.
        is_selected = (misclassified_samples["labels"] == true_label) & (
            misclassified_samples["preds"] == pred_label
        )
        # De-duplicate so an example with several matching rows is shown once.
        filtered_indices = misclassified_samples[is_selected].index.unique()
        for idx in filtered_indices:
            sample = context.df_tokens_merged.loc[idx]
            st.write(
                htmlify_labeled_example(sample),
                unsafe_allow_html=True,
            )
            st.write("---")
class MisclassifiedPage(Page):
    """Page that lists misclassified examples, filterable by class confusion."""

    # Title shown at the top of the page.
    name = "Misclassified"
    # Icon identifier associated with the page.
    icon = "x-octagon"

    def render(self, context: Context):
        """Render the confusion table, the confusion filter and the matching examples.

        Args:
            context: Page context. This method reads ``context.df_tokens_merged``
                (queried on its ``labels`` and ``preds`` columns) and
                ``context.labels`` (the ordered class names used for the
                confusion matrix).
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write(
                "This page contains all misclassified examples and allows filtering by specific error types."
            )

        # Collect the index values that own at least one wrong prediction, then
        # pull in *all* rows sharing those index values.
        misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
        misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
        cm = confusion_matrix(
            misclassified_samples.labels,
            misclassified_samples.preds,
            labels=context.labels,
        )

        # Build the display table directly as strings: the diagonal (correct
        # predictions) and zero counts are blanked. This avoids writing through
        # ``DataFrame.values`` (read-only under pandas copy-on-write) and the
        # deprecated ``DataFrame.applymap``.
        cells = [
            ["" if i == j or count == 0 else str(count) for j, count in enumerate(row)]
            for i, row in enumerate(cm)
        ]
        st.dataframe(pd.DataFrame(cells, index=context.labels, columns=context.labels))

        # One radio option per non-empty off-diagonal cell: ((true, pred), count).
        confusions = [
            ((context.labels[i], context.labels[j]), count)
            for i, row in enumerate(cm)
            for j, count in enumerate(row)
            if i != j and count != 0
        ]
        if not confusions:
            # Nothing to select; indexing into the radio result would fail.
            st.info("No class confusions to display.")
            return

        def format_func(item):
            (true_label, pred_label), count = item
            return f"true: {true_label} <> pred: {pred_label} ||| count: {count}"

        conf = st.radio(
            "Filter by Class Confusion",
            options=confusions,
            format_func=format_func,
        )
        (true_label, pred_label), _ = conf

        # A boolean mask instead of an interpolated query string: works for
        # labels containing quotes and for non-string labels.
        is_selected = (misclassified_samples["labels"] == true_label) & (
            misclassified_samples["preds"] == pred_label
        )
        # De-duplicate so an example with several matching rows is shown once.
        filtered_indices = misclassified_samples[is_selected].index.unique()
        for idx in filtered_indices:
            sample = context.df_tokens_merged.loc[idx]
            st.write(
                htmlify_labeled_example(sample),
                unsafe_allow_html=True,
            )
            st.write("---")
Base class for all pages.
def render(self, context: Context):
    """Render the confusion table, the confusion filter and the matching examples.

    Args:
        context: Page context. This method reads ``context.df_tokens_merged``
            (queried on its ``labels`` and ``preds`` columns) and
            ``context.labels`` (the ordered class names used for the
            confusion matrix).
    """
    st.title(self.name)
    with st.expander("💡", expanded=True):
        st.write(
            "This page contains all misclassified examples and allows filtering by specific error types."
        )

    # Collect the index values that own at least one wrong prediction, then
    # pull in *all* rows sharing those index values.
    misclassified_indices = context.df_tokens_merged.query("labels != preds").index.unique()
    misclassified_samples = context.df_tokens_merged.loc[misclassified_indices]
    cm = confusion_matrix(
        misclassified_samples.labels,
        misclassified_samples.preds,
        labels=context.labels,
    )

    # Build the display table directly as strings: the diagonal (correct
    # predictions) and zero counts are blanked. This avoids writing through
    # ``DataFrame.values`` (read-only under pandas copy-on-write) and the
    # deprecated ``DataFrame.applymap``.
    cells = [
        ["" if i == j or count == 0 else str(count) for j, count in enumerate(row)]
        for i, row in enumerate(cm)
    ]
    st.dataframe(pd.DataFrame(cells, index=context.labels, columns=context.labels))

    # One radio option per non-empty off-diagonal cell: ((true, pred), count).
    confusions = [
        ((context.labels[i], context.labels[j]), count)
        for i, row in enumerate(cm)
        for j, count in enumerate(row)
        if i != j and count != 0
    ]
    if not confusions:
        # Nothing to select; indexing into the radio result would fail.
        st.info("No class confusions to display.")
        return

    def format_func(item):
        (true_label, pred_label), count = item
        return f"true: {true_label} <> pred: {pred_label} ||| count: {count}"

    conf = st.radio(
        "Filter by Class Confusion",
        options=confusions,
        format_func=format_func,
    )
    (true_label, pred_label), _ = conf

    # A boolean mask instead of an interpolated query string: works for
    # labels containing quotes and for non-string labels.
    is_selected = (misclassified_samples["labels"] == true_label) & (
        misclassified_samples["preds"] == pred_label
    )
    # De-duplicate so an example with several matching rows is shown once.
    filtered_indices = misclassified_samples[is_selected].index.unique()
    for idx in filtered_indices:
        sample = context.df_tokens_merged.loc[idx]
        st.write(
            htmlify_labeled_example(sample),
            unsafe_allow_html=True,
        )
        st.write("---")
This function renders the page.