src.subpages.random_samples
Show random samples. Simple method, but it often turns up interesting things.
"""Show random samples. Simple method, but it often turns up interesting things."""
import pandas as pd
import streamlit as st

from src.subpages.page import Context, Page
from src.utils import htmlify_labeled_example


class RandomSamplesPage(Page):
    """Page that displays a random subset of samples together with their total loss."""

    name = "Random Samples"
    icon = "shuffle"

    def get_widget_defaults(self):
        """Return the default settings for this page's widgets.

        Returns:
            dict: Widget/session-state keys mapped to their default values.
        """
        return {
            # Cap on the initial sample size; `render` uses the smaller of this
            # value and `context.split_sample_size`.
            "random_sample_size_min": 128,
        }

    def render(self, context: Context):
        """Render the page.

        Shows a sample-size input and a resample button, then draws that many
        random rows from `context.df` and, for each one, displays a
        `[position | index]` badge, the sample's total loss and its labeled
        tokens (looked up in `context.df_tokens_merged`).

        Args:
            context: Page context; `df`, `df_tokens_merged` and
                `split_sample_size` are read from it.
        """
        st.title("🎲 Random Samples")
        with st.expander("💡", expanded=True):
            st.write(
                "Show random samples. Simple method, but it often turns up interesting things."
            )

        random_sample_size = st.number_input(
            "Random sample size:",
            value=min(st.session_state.random_sample_size_min, context.split_sample_size),
            step=16,
            key="random_sample_size",
        )

        # NOTE(review): `st.experimental_rerun` is deprecated in newer Streamlit
        # releases in favour of `st.rerun`; confirm against the pinned version.
        if st.button("🎲 Resample"):
            st.experimental_rerun()

        # The number input is unbounded, but `DataFrame.sample` (without
        # replacement) raises ValueError for n < 0 or n > len(df): clamp first.
        n_samples = max(0, min(int(random_sample_size), len(context.df)))
        random_indices = context.df.sample(n_samples).index
        samples = context.df_tokens_merged.loc[random_indices]

        for i, idx in enumerate(random_indices):
            sample = samples.loc[idx]

            # `.loc[idx]` returns a Series rather than a DataFrame when only one
            # row carries this index; such samples are skipped.
            if isinstance(sample, pd.Series):
                continue

            # Narrow badge column, small spacer, wide column for the example.
            col1, _, col2 = st.columns([0.08, 0.025, 0.8])

            counter = (
                "<span title='#sample | index' style='display: block; "
                "background-color: black; opacity: 1; color: white; padding: 0 5px'>"
                f"[{i+1} | {idx}]</span>"
            )
            loss = (
                "<span title='total loss' style='display: block; "
                "background-color: yellow; color: gray; padding: 0 5px;'>"
                f"𝐿 {sample.losses.sum():.3f}</span>"
            )
            col1.write(f"{counter}{loss}", unsafe_allow_html=True)
            col1.write("")
            col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
class RandomSamplesPage(Page):
    """Page that displays a random subset of samples together with their total loss."""

    name = "Random Samples"
    icon = "shuffle"

    def get_widget_defaults(self):
        """Return the default settings for this page's widgets.

        Returns:
            dict: Widget/session-state keys mapped to their default values.
        """
        return {
            # Cap on the initial sample size; `render` uses the smaller of this
            # value and `context.split_sample_size`.
            "random_sample_size_min": 128,
        }

    def render(self, context: Context):
        """Render the page.

        Shows a sample-size input and a resample button, then draws that many
        random rows from `context.df` and, for each one, displays a
        `[position | index]` badge, the sample's total loss and its labeled
        tokens (looked up in `context.df_tokens_merged`).

        Args:
            context: Page context; `df`, `df_tokens_merged` and
                `split_sample_size` are read from it.
        """
        st.title("🎲 Random Samples")
        with st.expander("💡", expanded=True):
            st.write(
                "Show random samples. Simple method, but it often turns up interesting things."
            )

        random_sample_size = st.number_input(
            "Random sample size:",
            value=min(st.session_state.random_sample_size_min, context.split_sample_size),
            step=16,
            key="random_sample_size",
        )

        # NOTE(review): `st.experimental_rerun` is deprecated in newer Streamlit
        # releases in favour of `st.rerun`; confirm against the pinned version.
        if st.button("🎲 Resample"):
            st.experimental_rerun()

        # The number input is unbounded, but `DataFrame.sample` (without
        # replacement) raises ValueError for n < 0 or n > len(df): clamp first.
        n_samples = max(0, min(int(random_sample_size), len(context.df)))
        random_indices = context.df.sample(n_samples).index
        samples = context.df_tokens_merged.loc[random_indices]

        for i, idx in enumerate(random_indices):
            sample = samples.loc[idx]

            # `.loc[idx]` returns a Series rather than a DataFrame when only one
            # row carries this index; such samples are skipped.
            if isinstance(sample, pd.Series):
                continue

            # Narrow badge column, small spacer, wide column for the example.
            col1, _, col2 = st.columns([0.08, 0.025, 0.8])

            counter = (
                "<span title='#sample | index' style='display: block; "
                "background-color: black; opacity: 1; color: white; padding: 0 5px'>"
                f"[{i+1} | {idx}]</span>"
            )
            loss = (
                "<span title='total loss' style='display: block; "
                "background-color: yellow; color: gray; padding: 0 5px;'>"
                f"𝐿 {sample.losses.sum():.3f}</span>"
            )
            col1.write(f"{counter}{loss}", unsafe_allow_html=True)
            col1.write("")
            col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
Page that shows random samples from the dataset (inherits from `Page`, the base class for all pages).
def get_widget_defaults(self)
Return the default settings for all of the page's widgets.
Returns
dict: A dictionary of widget defaults, where the keys are the widget names and the values are their default values.
def render(self, context: Context):
    """Render the page.

    Shows a sample-size input and a resample button, then draws that many
    random rows from `context.df` and, for each one, displays a
    `[position | index]` badge, the sample's total loss and its labeled
    tokens (looked up in `context.df_tokens_merged`).

    Args:
        context: Page context; `df`, `df_tokens_merged` and
            `split_sample_size` are read from it.
    """
    st.title("🎲 Random Samples")
    with st.expander("💡", expanded=True):
        st.write(
            "Show random samples. Simple method, but it often turns up interesting things."
        )

    random_sample_size = st.number_input(
        "Random sample size:",
        value=min(st.session_state.random_sample_size_min, context.split_sample_size),
        step=16,
        key="random_sample_size",
    )

    # NOTE(review): `st.experimental_rerun` is deprecated in newer Streamlit
    # releases in favour of `st.rerun`; confirm against the pinned version.
    if st.button("🎲 Resample"):
        st.experimental_rerun()

    # The number input is unbounded, but `DataFrame.sample` (without
    # replacement) raises ValueError for n < 0 or n > len(df): clamp first.
    n_samples = max(0, min(int(random_sample_size), len(context.df)))
    random_indices = context.df.sample(n_samples).index
    samples = context.df_tokens_merged.loc[random_indices]

    for i, idx in enumerate(random_indices):
        sample = samples.loc[idx]

        # `.loc[idx]` returns a Series rather than a DataFrame when only one
        # row carries this index; such samples are skipped.
        if isinstance(sample, pd.Series):
            continue

        # Narrow badge column, small spacer, wide column for the example.
        col1, _, col2 = st.columns([0.08, 0.025, 0.8])

        counter = (
            "<span title='#sample | index' style='display: block; "
            "background-color: black; opacity: 1; color: white; padding: 0 5px'>"
            f"[{i+1} | {idx}]</span>"
        )
        loss = (
            "<span title='total loss' style='display: block; "
            "background-color: yellow; color: gray; padding: 0 5px;'>"
            f"𝐿 {sample.losses.sum():.3f}</span>"
        )
        col1.write(f"{counter}{loss}", unsafe_allow_html=True)
        col1.write("")
        col2.write(htmlify_labeled_example(sample), unsafe_allow_html=True)
This function renders the page.