src.subpages.metrics

The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).

  1"""
  2The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts).
  3"""
  4import re
  5
  6import matplotlib.pyplot as plt
  7import numpy as np
  8import pandas as pd
  9import plotly.express as px
 10import streamlit as st
 11from seqeval.metrics import classification_report
 12from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
 13
 14from src.subpages.page import Context, Page
 15
 16
 17def _get_evaluation(df):
 18    y_true = df.apply(lambda row: [lbl for lbl in row.labels if lbl != "IGN"], axis=1)
 19    y_pred = df.apply(
 20        lambda row: [pred for (pred, lbl) in zip(row.preds, row.labels) if lbl != "IGN"],
 21        axis=1,
 22    )
 23    report: str = classification_report(y_true, y_pred, scheme="IOB2", digits=3)  # type: ignore
 24    return report.replace(
 25        "precision    recall  f1-score   support",
 26        "=" * 12 + "  precision    recall  f1-score   support",
 27    )
 28
 29
 30def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True):
 31    cm = confusion_matrix(y_true, y_preds, normalize=normalize, labels=labels)
 32    if zero_diagonal:
 33        np.fill_diagonal(cm, 0)
 34
 35    # st.write(plt.rcParams["font.size"])
 36    # plt.rcParams.update({'font.size': 10.0})
 37    fig, ax = plt.subplots(figsize=(10, 10))
 38    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
 39    fmt = "d" if normalize is None else ".3f"
 40    disp.plot(
 41        cmap="Blues",
 42        include_values=True,
 43        xticks_rotation="vertical",
 44        values_format=fmt,
 45        ax=ax,
 46        colorbar=False,
 47    )
 48    return fig
 49
 50
 51class MetricsPage(Page):
 52    name = "Metrics"
 53    icon = "graph-up-arrow"
 54
 55    def get_widget_defaults(self):
 56        return {
 57            "normalize": True,
 58            "zero_diagonal": False,
 59        }
 60
 61    def render(self, context: Context):
 62        st.title(self.name)
 63        with st.expander("💡", expanded=True):
 64            st.write(
 65                "The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts)."
 66            )
 67            st.write(
 68                "With the confusion matrix, you don't want any of the classes to end up in the bottom right quarter: those are frequent but error-prone."
 69            )
 70
 71        eval_results = _get_evaluation(context.df)
 72        if len(eval_results.splitlines()) < 8:
 73            col1, _, col2 = st.columns([8, 1, 1])
 74        else:
 75            col1 = col2 = st
 76
 77        col1.subheader("🎯 Evaluation Results")
 78        col1.code(eval_results)
 79
 80        results = [re.split(r" +", l.lstrip()) for l in eval_results.splitlines()[2:-4]]
 81        data = [(r[0], int(r[-1]), float(r[-2])) for r in results]
 82        df = pd.DataFrame(data, columns="class support f1".split())
 83        fig = px.scatter(
 84            df,
 85            x="support",
 86            y="f1",
 87            range_y=(0, 1.05),
 88            color="class",
 89        )
 90        # fig.update_layout(title_text="asdf", title_yanchor="bottom")
 91        col1.plotly_chart(fig)
 92
 93        col2.subheader("🔠 Confusion Matrix")
 94        normalize = None if not col2.checkbox("Normalize", key="normalize") else "true"
 95        zero_diagonal = col2.checkbox("Zero Diagonal", key="zero_diagonal")
 96        col2.pyplot(
 97            plot_confusion_matrix(
 98                y_true=context.df_tokens_cleaned["labels"],
 99                y_preds=context.df_tokens_cleaned["preds"],
100                labels=context.labels,
101                normalize=normalize,
102                zero_diagonal=zero_diagonal,
103            ),
104        )
def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True)
def plot_confusion_matrix(y_true, y_preds, labels, normalize=None, zero_diagonal=True):
    """Draw a confusion matrix and return the matplotlib figure.

    Args:
        y_true: Gold labels, forwarded to ``sklearn.metrics.confusion_matrix``.
        y_preds: Predicted labels, forwarded to ``confusion_matrix``.
        labels: Label list passed to ``confusion_matrix`` and used as the
            display labels of the plot.
        normalize: Forwarded to ``confusion_matrix``. ``None`` keeps the raw
            counts, which are then rendered as integers; any other value
            switches the cell format to three decimals.
        zero_diagonal: If true, the diagonal is overwritten with 0 so that
            only the off-diagonal (error) cells remain.

    Returns:
        The 10x10 inch figure holding the plot. The figure has already been
        removed from pyplot's global registry, so the caller does not need to
        close it.
    """
    cm = confusion_matrix(y_true, y_preds, normalize=normalize, labels=labels)
    if zero_diagonal:
        np.fill_diagonal(cm, 0)

    fig, ax = plt.subplots(figsize=(10, 10))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
    fmt = "d" if normalize is None else ".3f"
    disp.plot(
        cmap="Blues",
        include_values=True,
        xticks_rotation="vertical",
        values_format=fmt,
        ax=ax,
        colorbar=False,
    )
    # pyplot keeps a reference to every figure it creates until it is closed.
    # This function runs on each page render, so without this the figures
    # would pile up. The returned Figure object can still be rendered/saved.
    plt.close(fig)
    return fig
class MetricsPage(src.subpages.page.Page):
 52class MetricsPage(Page):
 53    name = "Metrics"
 54    icon = "graph-up-arrow"
 55
 56    def get_widget_defaults(self):
 57        return {
 58            "normalize": True,
 59            "zero_diagonal": False,
 60        }
 61
 62    def render(self, context: Context):
 63        st.title(self.name)
 64        with st.expander("💡", expanded=True):
 65            st.write(
 66                "The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts)."
 67            )
 68            st.write(
 69                "With the confusion matrix, you don't want any of the classes to end up in the bottom right quarter: those are frequent but error-prone."
 70            )
 71
 72        eval_results = _get_evaluation(context.df)
 73        if len(eval_results.splitlines()) < 8:
 74            col1, _, col2 = st.columns([8, 1, 1])
 75        else:
 76            col1 = col2 = st
 77
 78        col1.subheader("🎯 Evaluation Results")
 79        col1.code(eval_results)
 80
 81        results = [re.split(r" +", l.lstrip()) for l in eval_results.splitlines()[2:-4]]
 82        data = [(r[0], int(r[-1]), float(r[-2])) for r in results]
 83        df = pd.DataFrame(data, columns="class support f1".split())
 84        fig = px.scatter(
 85            df,
 86            x="support",
 87            y="f1",
 88            range_y=(0, 1.05),
 89            color="class",
 90        )
 91        # fig.update_layout(title_text="asdf", title_yanchor="bottom")
 92        col1.plotly_chart(fig)
 93
 94        col2.subheader("🔠 Confusion Matrix")
 95        normalize = None if not col2.checkbox("Normalize", key="normalize") else "true"
 96        zero_diagonal = col2.checkbox("Zero Diagonal", key="zero_diagonal")
 97        col2.pyplot(
 98            plot_confusion_matrix(
 99                y_true=context.df_tokens_cleaned["labels"],
100                y_preds=context.df_tokens_cleaned["preds"],
101                labels=context.labels,
102                normalize=normalize,
103                zero_diagonal=zero_diagonal,
104            ),
105        )

Base class for all pages.

MetricsPage()
name: str = 'Metrics'
icon: str = 'graph-up-arrow'
def get_widget_defaults(self)
56    def get_widget_defaults(self):
57        return {
58            "normalize": True,
59            "zero_diagonal": False,
60        }

This function holds the default settings for all the page's widgets.

Returns

dict: A dictionary of widget defaults, where the keys are the widget names and the values are their default values.

def render(self, context: src.subpages.page.Context)
 62    def render(self, context: Context):
 63        st.title(self.name)
 64        with st.expander("💡", expanded=True):
 65            st.write(
 66                "The metrics page contains precision, recall and f-score metrics as well as a confusion matrix over all the classes. By default, the confusion matrix is normalized. There's an option to zero out the diagonal, leaving only prediction errors (here it makes sense to turn off normalization, so you get raw error counts)."
 67            )
 68            st.write(
 69                "With the confusion matrix, you don't want any of the classes to end up in the bottom right quarter: those are frequent but error-prone."
 70            )
 71
 72        eval_results = _get_evaluation(context.df)
 73        if len(eval_results.splitlines()) < 8:
 74            col1, _, col2 = st.columns([8, 1, 1])
 75        else:
 76            col1 = col2 = st
 77
 78        col1.subheader("🎯 Evaluation Results")
 79        col1.code(eval_results)
 80
 81        results = [re.split(r" +", l.lstrip()) for l in eval_results.splitlines()[2:-4]]
 82        data = [(r[0], int(r[-1]), float(r[-2])) for r in results]
 83        df = pd.DataFrame(data, columns="class support f1".split())
 84        fig = px.scatter(
 85            df,
 86            x="support",
 87            y="f1",
 88            range_y=(0, 1.05),
 89            color="class",
 90        )
 91        # fig.update_layout(title_text="asdf", title_yanchor="bottom")
 92        col1.plotly_chart(fig)
 93
 94        col2.subheader("🔠 Confusion Matrix")
 95        normalize = None if not col2.checkbox("Normalize", key="normalize") else "true"
 96        zero_diagonal = col2.checkbox("Zero Diagonal", key="zero_diagonal")
 97        col2.pyplot(
 98            plot_confusion_matrix(
 99                y_true=context.df_tokens_cleaned["labels"],
100                y_preds=context.df_tokens_cleaned["preds"],
101                labels=context.labels,
102                normalize=normalize,
103                zero_diagonal=zero_diagonal,
104            ),
105        )

This function renders the page.