src.subpages.raw_data
See the data as seen by your model.
"""See the data as seen by your model."""
import pandas as pd
import streamlit as st

from src.subpages.page import Context, Page
from src.utils import aggrid_interactive_table


@st.cache
def convert_df(df):
    """Serialize *df* to CSV (index included) and return it as UTF-8 bytes.

    Results are memoized through ``st.cache``.
    """
    csv_text = df.to_csv()
    return csv_text.encode("utf-8")


class RawDataPage(Page):
    """Subpage that shows the processed token-level data and the raw dataset split."""

    name = "Raw data"
    icon = "qr-code"

    def render(self, context: Context):
        """Draw the raw-data page for the dataset described by *context*.

        Shows the dataset name/config/split, the token-level dataframe produced by
        processing and inference, and the original split with every column passed
        through ``pd.Series.explode``; each table is followed by a CSV download button.
        """

        def show_table_with_csv(frame, file_name):
            # Interactive grid first, then a button serving the same frame as CSV.
            aggrid_interactive_table(frame)
            st.download_button("Download csv", convert_df(frame), file_name, "text/csv")

        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("See the data as seen by your model.")

        st.subheader("Dataset")
        st.code(
            f"Dataset: {context.ds_name}\n"
            f"Config: {context.ds_config_name}\n"
            f"Split: {context.ds_split_name}"
        )

        st.write("**Data after processing and inference**")
        # Hidden states and attention masks are excluded from the displayed table.
        token_df = context.df_tokens.drop(columns=["hidden_states", "attention_mask"]).round(3)
        # token_type_ids is kept only when present; every other column is required
        # (selecting a missing one raises KeyError).
        shown_columns = [
            column
            for column in (
                "ids",
                "input_ids",
                "token_type_ids",
                "word_ids",
                "losses",
                "tokens",
                "labels",
                "preds",
                "total_loss",
            )
            if column != "token_type_ids" or column in token_df.columns
        ]
        show_table_with_csv(token_df[shown_columns], "processed_data.csv")

        st.write("**Raw data (exploded by tokens)**")
        split_df = context.split.to_pandas().apply(pd.Series.explode)  # type: ignore
        show_table_with_csv(split_df, "raw_data.csv")
@st.cache
def convert_df(df)
class RawDataPage(Page):
    """Subpage that shows the processed token-level data and the raw dataset split."""

    name = "Raw data"
    icon = "qr-code"

    def render(self, context: Context):
        """Draw the raw-data page for the dataset described by *context*.

        Shows the dataset name/config/split, the token-level dataframe produced by
        processing and inference, and the original split with every column passed
        through ``pd.Series.explode``; each table is followed by a CSV download button.
        """

        def show_table_with_csv(frame, file_name):
            # Interactive grid first, then a button serving the same frame as CSV.
            aggrid_interactive_table(frame)
            st.download_button("Download csv", convert_df(frame), file_name, "text/csv")

        st.title(self.name)
        with st.expander("💡", expanded=True):
            st.write("See the data as seen by your model.")

        st.subheader("Dataset")
        st.code(
            f"Dataset: {context.ds_name}\n"
            f"Config: {context.ds_config_name}\n"
            f"Split: {context.ds_split_name}"
        )

        st.write("**Data after processing and inference**")
        # Hidden states and attention masks are excluded from the displayed table.
        token_df = context.df_tokens.drop(columns=["hidden_states", "attention_mask"]).round(3)
        # token_type_ids is kept only when present; every other column is required
        # (selecting a missing one raises KeyError).
        shown_columns = [
            column
            for column in (
                "ids",
                "input_ids",
                "token_type_ids",
                "word_ids",
                "losses",
                "tokens",
                "labels",
                "preds",
                "total_loss",
            )
            if column != "token_type_ids" or column in token_df.columns
        ]
        show_table_with_csv(token_df[shown_columns], "processed_data.csv")

        st.write("**Raw data (exploded by tokens)**")
        split_df = context.split.to_pandas().apply(pd.Series.explode)  # type: ignore
        show_table_with_csv(split_df, "raw_data.csv")
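Page that shows the data as seen by the model: the processed, token-level data after inference and the raw dataset split. It is a subclass of `Page`, the base class for all pages.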
def render(self, context: Context):
    """Draw the raw-data page for the dataset described by *context*.

    Shows the dataset name/config/split, the token-level dataframe produced by
    processing and inference, and the original split with every column passed
    through ``pd.Series.explode``; each table is followed by a CSV download button.
    """

    def show_table_with_csv(frame, file_name):
        # Interactive grid first, then a button serving the same frame as CSV.
        aggrid_interactive_table(frame)
        st.download_button("Download csv", convert_df(frame), file_name, "text/csv")

    st.title(self.name)
    with st.expander("💡", expanded=True):
        st.write("See the data as seen by your model.")

    st.subheader("Dataset")
    st.code(
        f"Dataset: {context.ds_name}\n"
        f"Config: {context.ds_config_name}\n"
        f"Split: {context.ds_split_name}"
    )

    st.write("**Data after processing and inference**")
    # Hidden states and attention masks are excluded from the displayed table.
    token_df = context.df_tokens.drop(columns=["hidden_states", "attention_mask"]).round(3)
    # token_type_ids is kept only when present; every other column is required
    # (selecting a missing one raises KeyError).
    shown_columns = [
        column
        for column in (
            "ids",
            "input_ids",
            "token_type_ids",
            "word_ids",
            "losses",
            "tokens",
            "labels",
            "preds",
            "total_loss",
        )
        if column != "token_type_ids" or column in token_df.columns
    ]
    show_table_with_csv(token_df[shown_columns], "processed_data.csv")

    st.write("**Raw data (exploded by tokens)**")
    split_df = context.split.to_pandas().apply(pd.Series.explode)  # type: ignore
    show_table_with_csv(split_df, "raw_data.csv")
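Render the page. It shows the dataset name, config and split, then the processed token-level data after inference, then the raw split exploded by tokens. Each table has a CSV download button.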