src.subpages.raw_data

See the data as seen by your model.

 1"""See the data as seen by your model."""
 2import pandas as pd
 3import streamlit as st
 4
 5from src.subpages.page import Context, Page
 6from src.utils import aggrid_interactive_table
 7
 8
 9@st.cache
10def convert_df(df):
11    return df.to_csv().encode("utf-8")
12
13
class RawDataPage(Page):
    """Page listing the processed token-level data and the raw dataset split."""

    name = "Raw data"
    icon = "qr-code"

    def render(self, context: Context):
        """Draw the page: dataset identifiers, then two tables with CSV downloads.

        Args:
            context: Provides ``ds_name``, ``ds_config_name`` and ``ds_split_name``
                for the summary, ``df_tokens`` for the processed table and
                ``split`` for the raw table.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("See the data as seen by your model.")

        st.subheader("Dataset")
        dataset_summary = f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
        st.code(dataset_summary)

        st.write("**Data after processing and inference**")

        # Drop the two columns that are not listed, then round to 3 decimals.
        tokens_df = context.df_tokens.drop("hidden_states", axis=1)
        tokens_df = tokens_df.drop("attention_mask", axis=1).round(3)

        # token_type_ids is optional: select it only when the frame carries it.
        available = tokens_df.columns
        column_order = [
            col
            for col in "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
            if col != "token_type_ids" or col in available
        ]
        tokens_df = tokens_df[column_order]

        aggrid_interactive_table(tokens_df)
        st.download_button(
            label="Download csv",
            data=convert_df(tokens_df),
            file_name="processed_data.csv",
            mime="text/csv",
        )

        st.write("**Raw data (exploded by tokens)**")
        # Every column of the split is exploded.
        exploded_df = context.split.to_pandas().apply(pd.Series.explode)  # type: ignore
        aggrid_interactive_table(exploded_df)
        st.download_button(
            label="Download csv",
            data=convert_df(exploded_df),
            file_name="raw_data.csv",
            mime="text/csv",
        )
@st.cache
def convert_df(df)
@st.cache
def convert_df(df):
    """Return ``df`` serialized via ``to_csv()`` and encoded as UTF-8 bytes."""
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")
class RawDataPage(src.subpages.page.Page):
class RawDataPage(Page):
    """Page listing the processed token-level data and the raw dataset split."""

    name = "Raw data"
    icon = "qr-code"

    def render(self, context: Context):
        """Draw the page: dataset identifiers, then two tables with CSV downloads.

        Args:
            context: Provides ``ds_name``, ``ds_config_name`` and ``ds_split_name``
                for the summary, ``df_tokens`` for the processed table and
                ``split`` for the raw table.
        """
        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("See the data as seen by your model.")

        st.subheader("Dataset")
        dataset_summary = f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
        st.code(dataset_summary)

        st.write("**Data after processing and inference**")

        # Drop the two columns that are not listed, then round to 3 decimals.
        tokens_df = context.df_tokens.drop("hidden_states", axis=1)
        tokens_df = tokens_df.drop("attention_mask", axis=1).round(3)

        # token_type_ids is optional: select it only when the frame carries it.
        available = tokens_df.columns
        column_order = [
            col
            for col in "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
            if col != "token_type_ids" or col in available
        ]
        tokens_df = tokens_df[column_order]

        aggrid_interactive_table(tokens_df)
        st.download_button(
            label="Download csv",
            data=convert_df(tokens_df),
            file_name="processed_data.csv",
            mime="text/csv",
        )

        st.write("**Raw data (exploded by tokens)**")
        # Every column of the split is exploded.
        exploded_df = context.split.to_pandas().apply(pd.Series.explode)  # type: ignore
        aggrid_interactive_table(exploded_df)
        st.download_button(
            label="Download csv",
            data=convert_df(exploded_df),
            file_name="raw_data.csv",
            mime="text/csv",
        )

Base class for all pages.

RawDataPage()
name: str = 'Raw data'
icon: str = 'qr-code'
def render(self, context: src.subpages.page.Context)
19    def render(self, context: Context):
20        st.title(self.name)
21        with st.expander("💡", expanded=True):
22            st.write("See the data as seen by your model.")
23
24        st.subheader("Dataset")
25        st.code(
26            f"Dataset: {context.ds_name}\nConfig: {context.ds_config_name}\nSplit: {context.ds_split_name}"
27        )
28
29        st.write("**Data after processing and inference**")
30
31        processed_df = (
32            context.df_tokens.drop("hidden_states", axis=1).drop("attention_mask", axis=1).round(3)
33        )
34        cols = (
35            "ids input_ids token_type_ids word_ids losses tokens labels preds total_loss".split()
36        )
37        if "token_type_ids" not in processed_df.columns:
38            cols.remove("token_type_ids")
39        processed_df = processed_df[cols]
40        aggrid_interactive_table(processed_df)
41        processed_df_csv = convert_df(processed_df)
42        st.download_button(
43            "Download csv",
44            processed_df_csv,
45            "processed_data.csv",
46            "text/csv",
47        )
48
49        st.write("**Raw data (exploded by tokens)**")
50        raw_data_df = context.split.to_pandas().apply(pd.Series.explode)  # type: ignore
51        aggrid_interactive_table(raw_data_df)
52        raw_data_df_csv = convert_df(raw_data_df)
53        st.download_button(
54            "Download csv",
55            raw_data_df_csv,
56            "raw_data.csv",
57            "text/csv",
58        )

This function renders the page.