src.subpages.metrics
The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
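The per-class numbers come from seqeval's classification_report, which scores whole entities rather than individual tokens. A minimal sketch with made-up IOB2 tags (not data from this project), just to show the kind of report the page displays:

from seqeval.metrics import classification_report

# One toy sentence: gold tags vs. predicted tags, IOB2 scheme.
y_true = [["B-PER", "I-PER", "O", "B-LOC", "O"]]
y_pred = [["B-PER", "O", "O", "B-LOC", "O"]]

# Entity-level precision, recall and f1-score per class, plus averages.
print(classification_report(y_true, y_pred, digits=3))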
1""" 2The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts). 3""" 4import re 5 6import matplotlib.pyplot as plt 7import numpy as np 8import pandas as pd 9import plotly.express as px 10import streamlit as st 11from seqeval.metrics import classification_report 12from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix 13 14from src.subpages.page import Context, Page 15 16 17def _get_evaluation(df): 18 y_true = df.apply(lambda row: [lbl for lbl in row.labels if lbl != "IGN"], axis=1) 19 y_pred = df.apply( 20 lambda row: [pred for (pred, lbl) in zip(row.preds, row.labels) if lbl != "IGN"], 21 axis=1, 22 ) 23 report: str = classification_report(y_true, y_pred, scheme="IOB2", digits=3) # type: ignore 24 return report.replace( 25 "precision recall f1-score support", 26 "=" * 12 + " precision recall f1-score support", 27 ) 28 29 30def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True): 31 cm = confusion_matrix(y_true, y_preds, normalize=normalize, labels=labels) 32 if zero_diagonal: 33 np.fill_diagonal(cm, 0) 34 35 # st.write(plt.rcParams["font.size"]) 36 # plt.rcParams.update({'font.size': 10.0}) 37 fig, ax = plt.subplots(figsize=(10, 10)) 38 disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels) 39 fmt = "d" if normalize is None else ".3f" 40 disp.plot( 41 cmap="Blues", 42 include_values=True, 43 xticks_rotation="vertical", 44 values_format=fmt, 45 ax=ax, 46 colorbar=False, 47 ) 48 return fig 49 50 51class MetricsPage(Page): 52 name = "Metrics" 53 icon = "graph-up-arrow" 54 55 def get_widget_defaults(self): 56 return { 57 "normalize": True, 58 "zero_diagonal": False, 59 } 60 61 def render(self, context: Context): 62 st.title(self.name) 63 with st.expander("💡", expanded=True): 64 st.write( 65 "The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts)." 66 ) 67 st.write( 68 "With the confusion matrix, you don't want any of the classes to end up in the bottom right quarter: those are frequent but error-prone." 
69 ) 70 71 eval_results = _get_evaluation(context.df) 72 if len(eval_results.splitlines()) < 8: 73 col1, _, col2 = st.columns([8, 1, 1]) 74 else: 75 col1 = col2 = st 76 77 col1.subheader("🎯 Evaluation Results") 78 col1.code(eval_results) 79 80 results = [re.split(r" +", l.lstrip()) for l in eval_results.splitlines()[2:-4]] 81 data = [(r[0], int(r[-1]), float(r[-2])) for r in results] 82 df = pd.DataFrame(data, columns="class support f1".split()) 83 fig = px.scatter( 84 df, 85 x="support", 86 y="f1", 87 range_y=(0, 1.05), 88 color="class", 89 ) 90 # fig.update_layout(title_text="asdf", title_yanchor="bottom") 91 col1.plotly_chart(fig) 92 93 col2.subheader("🔠Confusion Matrix") 94 normalize = None if not col2.checkbox("Normalize", key="normalize") else "true" 95 zero_diagonal = col2.checkbox("Zero Diagonal", key="zero_diagonal") 96 col2.pyplot( 97 plot_confusion_matrix( 98 y_true=context.df_tokens_cleaned["labels"], 99 y_preds=context.df_tokens_cleaned["preds"], 100 labels=context.labels, 101 normalize=normalize, 102 zero_diagonal=zero_diagonal, 103 ), 104 )
def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True)
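plot_confusion_matrix computes a confusion matrix over the given label set with scikit-learn, optionally normalizes it, optionally zeroes the diagonal, and returns a matplotlib figure. A rough sketch of standalone use with made-up token-level label lists (the page itself passes columns of context.df_tokens_cleaned, and normalize="true" when normalization is switched on):

labels = ["O", "B-PER", "I-PER"]                      # made-up label set
y_true = ["O", "B-PER", "I-PER", "O", "B-PER", "O"]   # gold tag per token
y_preds = ["O", "B-PER", "O", "O", "B-PER", "B-PER"]  # predicted tag per token

# Raw error counts: no normalization, correct predictions zeroed out.
fig = plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True)
fig.savefig("confusions.png")

# Row-normalized matrix with the diagonal kept.
fig = plot_confusion_matrix(y_true, y_preds, labels, normalize="true", zero_diagonal=False)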
class MetricsPage(Page)
Renders the metrics page. Inherits from Page, the base class for all pages.
def get_widget_defaults(self)
This function returns the default settings for all of the page's widgets.

Returns

dict: A dictionary of widget defaults, where the keys are the widget names and the values are the default values.
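For MetricsPage the values come straight from the source above; assuming the Page base class needs no constructor arguments, a caller would see:

MetricsPage().get_widget_defaults()
# -> {"normalize": True, "zero_diagonal": False}
# i.e. the confusion matrix starts out normalized and with its diagonal intact.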
def render(self, context: Context)
This function renders the page: the seqeval evaluation report, an F1-versus-support scatter plot, and the confusion matrix.
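A hypothetical driver sketch (the real app wiring lives outside this module; all that is known here is that Context exposes df, df_tokens_cleaned and labels):

page = MetricsPage()

# Seed Streamlit's session state with the page's widget defaults before rendering.
for key, value in page.get_widget_defaults().items():
    if key not in st.session_state:
        st.session_state[key] = value

# `context` is a src.subpages.page.Context built elsewhere from the model's
# predictions; render() only reads context.df, context.df_tokens_cleaned and
# context.labels.
page.render(context)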