src.subpages.hidden_states

For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples marked by a small black border.
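Stripped of the Streamlit scaffolding, the idea reduces to three steps: stack the per-token hidden states into a matrix, reduce it to two dimensions, and overlay a border-only trace for the mislabeled points. Below is a minimal, framework-free sketch of that idea; the toy dataframe and its column names are illustrative stand-ins, not the page's real Context.

    import numpy as np
    import pandas as pd
    import plotly.express as px
    import plotly.graph_objects as go
    from sklearn.decomposition import PCA

    # Toy stand-in for the per-token dataframe: one row per token.
    rng = np.random.default_rng(0)
    df = pd.DataFrame(
        {
            "tokens": [f"tok{i}" for i in range(200)],
            "labels": rng.choice(["O", "PER", "LOC"], size=200),
            "preds": rng.choice(["O", "PER", "LOC"], size=200),
            "hidden_states": list(rng.normal(size=(200, 768))),
        }
    )

    X = np.array(df["hidden_states"].tolist())  # (n_tokens, hidden_dim)
    xy = PCA(n_components=2).fit_transform(X)
    df["x"], df["y"] = xy[:, 0], xy[:, 1]
    df["mislabeled"] = df["labels"] != df["preds"]

    fig = px.scatter(df, x="x", y="y", color="labels", hover_name="tokens")
    # Transparent markers with a thin black border, drawn on top of the colored points.
    fig.add_trace(
        go.Scatter(
            x=df.loc[df["mislabeled"], "x"],
            y=df.loc[df["mislabeled"], "y"],
            mode="markers",
            marker=dict(size=6, color="rgba(0,0,0,0)", line=dict(width=1, color="black")),
            hoverinfo="skip",
        )
    )
    fig.show()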

  1"""
  2For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples marked by a small black border.
  3"""
  4import numpy as np
  5import plotly.express as px
  6import plotly.graph_objects as go
  7import streamlit as st
  8
  9from src.subpages.page import Context, Page
 10
 11
 12@st.cache
 13def reduce_dim_svd(X, n_iter: int, random_state=42):
 14    """Dimensionality reduction using truncated SVD (aka LSA).
 15
 16    This transformer performs linear dimensionality reduction by means of truncated singular value decomposition (SVD). Contrary to PCA, this estimator does not center the data before computing the singular value decomposition. This means it can work with sparse matrices efficiently.
 17
 18        Args:
 19            X: Training data
 20            n_iter (int): Desired dimensionality of output data. Must be strictly less than the number of features.
 21            random_state (int, optional): Used during randomized svd. Pass an int for reproducible results across multiple function calls. Defaults to 42.
 22
 23        Returns:
 24            ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
 25    """
 26    from sklearn.decomposition import TruncatedSVD
 27
 28    svd = TruncatedSVD(n_components=2, n_iter=n_iter, random_state=random_state)
 29    return svd.fit_transform(X)
 30
 31
 32@st.cache
 33def reduce_dim_pca(X, random_state=42):
 34    """Principal component analysis (PCA).
 35
 36    Linear dimensionality reduction using Singular Value Decomposition of the data to project it to a lower dimensional space. The input data is centered but not scaled for each feature before applying the SVD.
 37
 38        Args:
 39            X: Training data
 40            random_state (int, optional): Used when the 'arpack' or 'randomized' solvers are used. Pass an int for reproducible results across multiple function calls.
 41
 42        Returns:
 43            ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
 44    """
 45    from sklearn.decomposition import PCA
 46
 47    return PCA(n_components=2, random_state=random_state).fit_transform(X)
 48
 49
 50@st.cache
 51def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric="euclidean"):
 52    """Uniform Manifold Approximation and Projection
 53
 54    Finds a low dimensional embedding of the data that approximates an underlying manifold.
 55
 56        Args:
 57            X: Training data
 58            n_neighbors (int, optional): The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation. Larger values result in more global views of the manifold, while smaller values result in more local data being preserved. In general values should be in the range 2 to 100. Defaults to 5.
 59            min_dist (float, optional): The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where nearby points on the manifold are drawn closer together, while larger values will result on a more even dispersal of points. The value should be set relative to the `spread` value, which determines the scale at which embedded points will be spread out. Defaults to 0.1.
 60            metric (str, optional): The metric to use to compute distances in high dimensional space (see UMAP docs for options). Defaults to "euclidean".
 61
 62        Returns:
 63            ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
 64    """
 65    from umap import UMAP
 66
 67    return UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric=metric).fit_transform(X)
 68
 69
 70class HiddenStatesPage(Page):
 71    name = "Hidden States"
 72    icon = "grid-3x3"
 73
 74    def get_widget_defaults(self):
 75        return {
 76            "n_tokens": 1_000,
 77            "svd_n_iter": 5,
 78            "svd_random_state": 42,
 79            "umap_n_neighbors": 15,
 80            "umap_metric": "euclidean",
 81            "umap_min_dist": 0.1,
 82        }
 83
 84    def render(self, context: Context):
 85        st.title("Embeddings")
 86
 87        with st.expander("💡", expanded=True):
 88            st.write(
 89                "For every token in the dataset, we take its hidden state and project it onto a two-dimensional plane. Data points are colored by label/prediction, with mislabeled examples signified by a small black border."
 90            )
 91
 92        col1, _, col2 = st.columns([9 / 32, 1 / 32, 22 / 32])
 93        df = context.df_tokens_merged.copy()
 94        dim_algo = "SVD"
 95        n_tokens = 100
 96
 97        with col1:
 98            st.subheader("Settings")
 99            n_tokens = st.slider(
100                "#tokens",
101                key="n_tokens",
102                min_value=100,
103                max_value=len(df["tokens"].unique()),
104                step=100,
105            )
106
107            dim_algo = st.selectbox("Dimensionality reduction algorithm", ["SVD", "PCA", "UMAP"])
108            if dim_algo == "SVD":
109                svd_n_iter = st.slider(
110                    "#iterations",
111                    key="svd_n_iter",
112                    min_value=1,
113                    max_value=10,
114                    step=1,
115                )
116            elif dim_algo == "UMAP":
117                umap_n_neighbors = st.slider(
118                    "#neighbors",
119                    key="umap_n_neighbors",
120                    min_value=2,
121                    max_value=100,
122                    step=1,
123                )
124                umap_min_dist = st.number_input(
125                    "Min distance", key="umap_min_dist", value=0.1, min_value=0.0, max_value=1.0
126                )
127                umap_metric = st.selectbox(
128                    "Metric", ["euclidean", "manhattan", "chebyshev", "minkowski"]
129                )
130            else:
131                pass
132
133        with col2:
134            sents = df.groupby("ids").apply(lambda x: " ".join(x["tokens"].tolist()))
135
136            X = np.array(df["hidden_states"].tolist())
137            transformed_hidden_states = None
138            if dim_algo == "SVD":
139                transformed_hidden_states = reduce_dim_svd(X, n_iter=svd_n_iter)  # type: ignore
140            elif dim_algo == "PCA":
141                transformed_hidden_states = reduce_dim_pca(X)
142            elif dim_algo == "UMAP":
143                transformed_hidden_states = reduce_dim_umap(
144                    X, n_neighbors=umap_n_neighbors, min_dist=umap_min_dist, metric=umap_metric  # type: ignore
145                )
146
147            assert isinstance(transformed_hidden_states, np.ndarray)
148            df["x"] = transformed_hidden_states[:, 0]
149            df["y"] = transformed_hidden_states[:, 1]
150            df["sent0"] = df["ids"].map(lambda x: " ".join(sents[x][0:50].split()))
151            df["sent1"] = df["ids"].map(lambda x: " ".join(sents[x][50:100].split()))
152            df["sent2"] = df["ids"].map(lambda x: " ".join(sents[x][100:150].split()))
153            df["sent3"] = df["ids"].map(lambda x: " ".join(sents[x][150:200].split()))
154            df["sent4"] = df["ids"].map(lambda x: " ".join(sents[x][200:250].split()))
155            df["mislabeled"] = df["labels"] != df["preds"]
156
157            subset = df[:n_tokens]
158            mislabeled_examples_trace = go.Scatter(
159                x=subset[subset["mislabeled"]]["x"],
160                y=subset[subset["mislabeled"]]["y"],
161                mode="markers",
162                marker=dict(
163                    size=6,
164                    color="rgba(0,0,0,0)",
165                    line=dict(width=1),
166                ),
167                hoverinfo="skip",
168            )
169
170            st.subheader("Projection Results")
171
172            fig = px.scatter(
173                subset,
174                x="x",
175                y="y",
176                color="labels",
177                hover_data=["ids", "preds", "sent0", "sent1", "sent2", "sent3", "sent4"],
178                hover_name="tokens",
179                title="Colored by label",
180            )
181            fig.add_trace(mislabeled_examples_trace)
182            st.plotly_chart(fig)
183
184            fig = px.scatter(
185                subset,
186                x="x",
187                y="y",
188                color="preds",
189                hover_data=["ids", "labels", "sent0", "sent1", "sent2", "sent3", "sent4"],
190                hover_name="tokens",
191                title="Colored by prediction",
192            )
193            fig.add_trace(mislabeled_examples_trace)
194            st.plotly_chart(fig)
@st.cache
def reduce_dim_svd(X, n_iter: int, random_state=42)

Dimensionality reduction using truncated SVD (aka LSA).

This transformer performs linear dimensionality reduction by means of truncated singular value decomposition (SVD). Contrary to PCA, this estimator does not center the data before computing the singular value decomposition. This means it can work with sparse matrices efficiently.

Args:
    X: Training data
    n_iter (int): Number of iterations for the randomized SVD solver.
    random_state (int, optional): Used during randomized svd. Pass an int for reproducible results across multiple function calls. Defaults to 42.

Returns:
    ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
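
For reference, a minimal call with a toy matrix (the random data is a stand-in for real hidden states; the function is wrapped in st.cache, so it is normally invoked from within the running Streamlit app):

    import numpy as np
    from src.subpages.hidden_states import reduce_dim_svd

    X = np.random.default_rng(0).normal(size=(1_000, 768))  # (n_samples, hidden_dim)
    X_2d = reduce_dim_svd(X, n_iter=5)
    print(X_2d.shape)  # (1000, 2)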
@st.cache
def reduce_dim_pca(X, random_state=42)

Principal component analysis (PCA).

Linear dimensionality reduction using Singular Value Decomposition of the data to project it to a lower dimensional space. The input data is centered but not scaled for each feature before applying the SVD.

Args:
    X: Training data
    random_state (int, optional): Used when the 'arpack' or 'randomized' solvers are used. Pass an int for reproducible results across multiple function calls.

Returns:
    ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
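
An analogous sketch with a toy matrix; unlike the truncated-SVD variant, the data is centered before projection:

    import numpy as np
    from src.subpages.hidden_states import reduce_dim_pca

    X = np.random.default_rng(0).normal(size=(1_000, 768))
    X_2d = reduce_dim_pca(X)  # centered, then projected onto the top 2 principal components
    print(X_2d.shape)  # (1000, 2)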
@st.cache
def reduce_dim_umap(X, n_neighbors=5, min_dist=0.1, metric='euclidean')

Uniform Manifold Approximation and Projection

Finds a low dimensional embedding of the data that approximates an underlying manifold.

Args:
    X: Training data
    n_neighbors (int, optional): The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation. Larger values result in more global views of the manifold, while smaller values result in more local data being preserved. In general values should be in the range 2 to 100. Defaults to 5.
    min_dist (float, optional): The effective minimum distance between embedded points. Smaller values will result in a more clustered/clumped embedding where nearby points on the manifold are drawn closer together, while larger values will result in a more even dispersal of points. The value should be set relative to the `spread` value, which determines the scale at which embedded points will be spread out. Defaults to 0.1.
    metric (str, optional): The metric to use to compute distances in high dimensional space (see UMAP docs for options). Defaults to "euclidean".

Returns:
    ndarray: Reduced version of X, ndarray of shape (n_samples, 2).
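
And the same kind of sketch using the page's default UMAP settings (15 neighbors, min_dist 0.1, euclidean metric):

    import numpy as np
    from src.subpages.hidden_states import reduce_dim_umap

    X = np.random.default_rng(0).normal(size=(1_000, 768))
    X_2d = reduce_dim_umap(X, n_neighbors=15, min_dist=0.1, metric="euclidean")
    print(X_2d.shape)  # (1000, 2)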
class HiddenStatesPage(src.subpages.page.Page):

Page that projects token hidden states onto a two-dimensional plane; subclasses Page, the base class for all pages.

HiddenStatesPage()
name: str = 'Hidden States'
icon: str = 'grid-3x3'
def get_widget_defaults(self)

This function holds the default settings for all the page's widgets.

Returns

dict: A dictionary of widget defaults, where the keys are the widget names and the values are the defaults.
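
The keys mirror the key= arguments passed to the widgets in render, so one plausible way the surrounding app consumes these defaults is to seed st.session_state before the widgets are created. A sketch under that assumption (the actual wiring lives in the Page framework, not in this module):

    import streamlit as st
    from src.subpages.hidden_states import HiddenStatesPage

    page = HiddenStatesPage()
    for key, default in page.get_widget_defaults().items():
        if key not in st.session_state:
            st.session_state[key] = default  # widgets created with key="n_tokens" etc. pick this up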

def render(self, context: src.subpages.page.Context)

This function renders the page: a settings column (number of tokens, choice of dimensionality-reduction algorithm and its parameters) next to a results column with two scatter plots of the projected hidden states, one colored by label and one by prediction, each overlaid with a border-only trace marking mislabeled tokens.
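
render only reads a handful of columns from context.df_tokens_merged, so a toy frame with just those columns is enough to drive the page. The Context construction at the end is a hypothetical, since src.subpages.page is not shown in this module:

    import numpy as np
    import pandas as pd

    # Columns that render() actually reads: ids, tokens, hidden_states, labels, preds.
    rng = np.random.default_rng(0)
    df_tokens_merged = pd.DataFrame(
        {
            "ids": [i // 10 for i in range(200)],  # sentence id shared by that sentence's tokens
            "tokens": [f"tok{i}" for i in range(200)],
            "hidden_states": list(rng.normal(size=(200, 768))),
            "labels": rng.choice(["O", "PER", "LOC"], size=200),
            "preds": rng.choice(["O", "PER", "LOC"], size=200),
        }
    )

    # Hypothetical wiring -- the real Context type may be constructed differently:
    # context = Context(df_tokens_merged=df_tokens_merged, ...)
    # HiddenStatesPage().render(context)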