src.subpages.home
"""Home / setup page: model and dataset selection plus per-class color and icon pickers."""

import json
import random
from typing import Optional

import streamlit as st

from src.data import get_data
from src.subpages.page import Context, Page
from src.utils import PROJ, classmap, color_map_color

# Each constant below lists candidate values; the trailing index selects the
# one that is used as the widget default.
_SENTENCE_ENCODER_MODEL = (
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
)[0]
_MODEL_NAME = (
    "elastic/distilbert-base-uncased-finetuned-conll03-english",
    "gagan3012/bert-tiny-finetuned-ner",
    "socialmediaie/bertweet-base_wnut17_ner",
    "sberbank-ai/bert-base-NER-reptile-5-datasets",
    "aseifert/comma-xlm-roberta-base",
    "dslim/bert-base-NER",
    "aseifert/distilbert-base-german-cased-comma-derstandard",
)[0]
_DATASET_NAME = (
    "conll2003",
    "wnut_17",
    "aseifert/comma",
)[0]
_CONFIG_NAME = (
    "conll2003",
    "wnut_17",
    "seifertverlag",
)[0]


class HomePage(Page):
    """Landing page on which the user picks model, dataset and class styling."""

    name = "Home / Setup"
    icon = "house"

    def get_widget_defaults(self):
        """Return the default value for every keyed widget on this page.

        Returns:
            dict: Maps each widget ``key`` used in :meth:`render` to its
            default value.
        """
        return {
            "encoder_model_name": _SENTENCE_ENCODER_MODEL,
            "model_name": _MODEL_NAME,
            "ds_name": _DATASET_NAME,
            "ds_split_name": "validation",
            "ds_config_name": _CONFIG_NAME,
            "split_sample_size": 512,
        }

    def render(self, context: Optional[Context] = None):
        """Render the intro text, the settings form and the class color/icon pickers.

        Side effects: seeds ``color_<label>`` and ``icon_<label>`` entries in
        ``st.session_state`` when they are missing and writes the chosen icon
        back into ``classmap``.

        Args:
            context: Accepted for interface compatibility; not read here.
        """
        st.title("ExplaiNER")

        with st.expander("💡", expanded=True):
            st.write(
                "**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further **improving both model AND dataset**."
            )
            st.write(
                "**Note:** This Space requires a fair amount of computation, so please be patient with the loading animations. 🙏 I am caching as much as possible, so after the first wait most things should be precomputed."
            )
            st.write(
                "_Caveat: Even though everything is customizable here, I haven't tested this app much with different models/datasets._"
            )

        col1, _, col2a, col2b = st.columns([0.8, 0.05, 0.15, 0.15])

        with col1:
            # FIXME: the form key is randomized on every run because, with a
            # fixed key, st.form_submit_button("Load Model & Data") raised:
            #   streamlit.errors.StreamlitAPIException: Values for st.button,
            #   st.download_button, st.file_uploader, and st.form cannot be
            #   set using st.session_state.
            random_form_key = f"settings-{random.randint(0, 100000)}"
            with st.form(key=random_form_key):
                st.subheader("Model & Data Selection")
                st.text_input(
                    label="NER Model:",
                    key="model_name",
                    help="Path or name of the model to use",
                )
                st.text_input(
                    label="Encoder Model:",
                    key="encoder_model_name",
                    help="Path or name of the encoder to use for duplicate detection",
                )
                ds_name = st.text_input(
                    label="Dataset:",
                    key="ds_name",
                    help="Path or name of the dataset to use",
                )
                ds_config_name = st.text_input(
                    label="Config (optional):",
                    key="ds_config_name",
                )
                ds_split_name = st.selectbox(
                    label="Split:",
                    options=["train", "validation", "test"],
                    key="ds_split_name",
                )
                split_sample_size = st.number_input(
                    "Sample size:",
                    step=16,
                    key="split_sample_size",
                    help="Sample size for the split, speeds up processing inside streamlit",
                )
                st.form_submit_button("Load Model & Data")

        split = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
        tag_names = split.features["ner_tags"].feature.names
        # Strip the prefix before the first "-" (e.g. "B-PER" -> "PER"); a tag
        # without a hyphen is kept whole instead of raising IndexError.
        # Sorting keeps label order, and therefore the default colors, stable
        # across processes (plain set iteration order is not).
        # NOTE(review): a tag with several hyphens still yields only its second
        # segment, as before — confirm whether other modules rely on that.
        labels = sorted(
            {tag.split("-")[1] if "-" in tag else tag for tag in tag_names if tag != "O"}
        )

        with col2a:
            st.subheader("Classes")
            st.write("**Color**")
            colors = {label: color_map_color(i / len(labels)) for i, label in enumerate(labels)}
            for label in labels:
                # Seed the picker only once so later user choices survive reruns.
                if f"color_{label}" not in st.session_state:
                    st.session_state[f"color_{label}"] = colors[label]
                st.color_picker(label, key=f"color_{label}")
        with col2b:
            st.subheader("—")
            st.write("**Icon**")
            # Close the file deterministically and decode it as UTF-8.
            with open(PROJ / "subpages/emoji-en-US.json", encoding="utf-8") as emoji_file:
                emojis = list(json.load(emoji_file).keys())
            for label in labels:
                if f"icon_{label}" not in st.session_state:
                    st.session_state[f"icon_{label}"] = classmap[label]
                st.selectbox(label, key=f"icon_{label}", options=emojis)
                # Propagate the selected icon to the shared class map.
                classmap[label] = st.session_state[f"icon_{label}"]

        # TODO: add a "Reset to defaults" button that restores the widget
        # defaults in st.session_state and reruns the app.
class HomePage(Page):
    """Landing page on which the user picks model, dataset and class styling."""

    name = "Home / Setup"
    icon = "house"

    def get_widget_defaults(self):
        """Return the default value for every keyed widget on this page.

        Returns:
            dict: Maps each widget ``key`` used in :meth:`render` to its
            default value.
        """
        return {
            "encoder_model_name": _SENTENCE_ENCODER_MODEL,
            "model_name": _MODEL_NAME,
            "ds_name": _DATASET_NAME,
            "ds_split_name": "validation",
            "ds_config_name": _CONFIG_NAME,
            "split_sample_size": 512,
        }

    def render(self, context: Optional[Context] = None):
        """Render the intro text, the settings form and the class color/icon pickers.

        Side effects: seeds ``color_<label>`` and ``icon_<label>`` entries in
        ``st.session_state`` when they are missing and writes the chosen icon
        back into ``classmap``.

        Args:
            context: Accepted for interface compatibility; not read here.
        """
        st.title("ExplaiNER")

        with st.expander("💡", expanded=True):
            st.write(
                "**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further **improving both model AND dataset**."
            )
            st.write(
                "**Note:** This Space requires a fair amount of computation, so please be patient with the loading animations. 🙏 I am caching as much as possible, so after the first wait most things should be precomputed."
            )
            st.write(
                "_Caveat: Even though everything is customizable here, I haven't tested this app much with different models/datasets._"
            )

        col1, _, col2a, col2b = st.columns([0.8, 0.05, 0.15, 0.15])

        with col1:
            # FIXME: the form key is randomized on every run because, with a
            # fixed key, st.form_submit_button("Load Model & Data") raised:
            #   streamlit.errors.StreamlitAPIException: Values for st.button,
            #   st.download_button, st.file_uploader, and st.form cannot be
            #   set using st.session_state.
            random_form_key = f"settings-{random.randint(0, 100000)}"
            with st.form(key=random_form_key):
                st.subheader("Model & Data Selection")
                st.text_input(
                    label="NER Model:",
                    key="model_name",
                    help="Path or name of the model to use",
                )
                st.text_input(
                    label="Encoder Model:",
                    key="encoder_model_name",
                    help="Path or name of the encoder to use for duplicate detection",
                )
                ds_name = st.text_input(
                    label="Dataset:",
                    key="ds_name",
                    help="Path or name of the dataset to use",
                )
                ds_config_name = st.text_input(
                    label="Config (optional):",
                    key="ds_config_name",
                )
                ds_split_name = st.selectbox(
                    label="Split:",
                    options=["train", "validation", "test"],
                    key="ds_split_name",
                )
                split_sample_size = st.number_input(
                    "Sample size:",
                    step=16,
                    key="split_sample_size",
                    help="Sample size for the split, speeds up processing inside streamlit",
                )
                st.form_submit_button("Load Model & Data")

        split = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
        tag_names = split.features["ner_tags"].feature.names
        # Strip the prefix before the first "-" (e.g. "B-PER" -> "PER"); a tag
        # without a hyphen is kept whole instead of raising IndexError.
        # Sorting keeps label order, and therefore the default colors, stable
        # across processes (plain set iteration order is not).
        # NOTE(review): a tag with several hyphens still yields only its second
        # segment, as before — confirm whether other modules rely on that.
        labels = sorted(
            {tag.split("-")[1] if "-" in tag else tag for tag in tag_names if tag != "O"}
        )

        with col2a:
            st.subheader("Classes")
            st.write("**Color**")
            colors = {label: color_map_color(i / len(labels)) for i, label in enumerate(labels)}
            for label in labels:
                # Seed the picker only once so later user choices survive reruns.
                if f"color_{label}" not in st.session_state:
                    st.session_state[f"color_{label}"] = colors[label]
                st.color_picker(label, key=f"color_{label}")
        with col2b:
            st.subheader("—")
            st.write("**Icon**")
            # Close the file deterministically and decode it as UTF-8.
            with open(PROJ / "subpages/emoji-en-US.json", encoding="utf-8") as emoji_file:
                emojis = list(json.load(emoji_file).keys())
            for label in labels:
                if f"icon_{label}" not in st.session_state:
                    st.session_state[f"icon_{label}"] = classmap[label]
                st.selectbox(label, key=f"icon_{label}", options=emojis)
                # Propagate the selected icon to the shared class map.
                classmap[label] = st.session_state[f"icon_{label}"]

        # TODO: add a "Reset to defaults" button that restores the widget
        # defaults in st.session_state and reruns the app.
Base class for all pages.
def
get_widget_defaults(self)
def get_widget_defaults(self):
    """Return the default value for every keyed widget on this page.

    Returns:
        dict: Maps each widget key to the value it starts out with.
    """
    defaults = dict(
        encoder_model_name=_SENTENCE_ENCODER_MODEL,
        model_name=_MODEL_NAME,
        ds_name=_DATASET_NAME,
        ds_split_name="validation",
        ds_config_name=_CONFIG_NAME,
        split_sample_size=512,
    )
    return defaults
This function holds the default settings for all the page's widgets.
Returns
dict: A dictionary of widget defaults, where the keys are the widget names and the values are the defaults.
def render(self, context: Optional[Context] = None):
    """Render the intro text, the settings form and the class color/icon pickers.

    Side effects: seeds ``color_<label>`` and ``icon_<label>`` entries in
    ``st.session_state`` when they are missing and writes the chosen icon
    back into ``classmap``.

    Args:
        context: Accepted for interface compatibility; not read here.
    """
    st.title("ExplaiNER")

    with st.expander("💡", expanded=True):
        st.write(
            "**Error Analysis is an important but often overlooked part of the data science project lifecycle**, for which there is still very little tooling available. Practitioners tend to write throwaway code or, worse, skip this crucial step of understanding their models' errors altogether. This project tries to provide an **extensive toolkit to probe any NER model/dataset combination**, find labeling errors and understand the models' and datasets' limitations, leading the user on her way to further **improving both model AND dataset**."
        )
        st.write(
            "**Note:** This Space requires a fair amount of computation, so please be patient with the loading animations. 🙏 I am caching as much as possible, so after the first wait most things should be precomputed."
        )
        st.write(
            "_Caveat: Even though everything is customizable here, I haven't tested this app much with different models/datasets._"
        )

    col1, _, col2a, col2b = st.columns([0.8, 0.05, 0.15, 0.15])

    with col1:
        # FIXME: the form key is randomized on every run because, with a
        # fixed key, st.form_submit_button("Load Model & Data") raised:
        #   streamlit.errors.StreamlitAPIException: Values for st.button,
        #   st.download_button, st.file_uploader, and st.form cannot be
        #   set using st.session_state.
        random_form_key = f"settings-{random.randint(0, 100000)}"
        with st.form(key=random_form_key):
            st.subheader("Model & Data Selection")
            st.text_input(
                label="NER Model:",
                key="model_name",
                help="Path or name of the model to use",
            )
            st.text_input(
                label="Encoder Model:",
                key="encoder_model_name",
                help="Path or name of the encoder to use for duplicate detection",
            )
            ds_name = st.text_input(
                label="Dataset:",
                key="ds_name",
                help="Path or name of the dataset to use",
            )
            ds_config_name = st.text_input(
                label="Config (optional):",
                key="ds_config_name",
            )
            ds_split_name = st.selectbox(
                label="Split:",
                options=["train", "validation", "test"],
                key="ds_split_name",
            )
            split_sample_size = st.number_input(
                "Sample size:",
                step=16,
                key="split_sample_size",
                help="Sample size for the split, speeds up processing inside streamlit",
            )
            st.form_submit_button("Load Model & Data")

    split = get_data(ds_name, ds_config_name, ds_split_name, split_sample_size)
    tag_names = split.features["ner_tags"].feature.names
    # Strip the prefix before the first "-" (e.g. "B-PER" -> "PER"); a tag
    # without a hyphen is kept whole instead of raising IndexError.
    # Sorting keeps label order, and therefore the default colors, stable
    # across processes (plain set iteration order is not).
    # NOTE(review): a tag with several hyphens still yields only its second
    # segment, as before — confirm whether other modules rely on that.
    labels = sorted(
        {tag.split("-")[1] if "-" in tag else tag for tag in tag_names if tag != "O"}
    )

    with col2a:
        st.subheader("Classes")
        st.write("**Color**")
        colors = {label: color_map_color(i / len(labels)) for i, label in enumerate(labels)}
        for label in labels:
            # Seed the picker only once so later user choices survive reruns.
            if f"color_{label}" not in st.session_state:
                st.session_state[f"color_{label}"] = colors[label]
            st.color_picker(label, key=f"color_{label}")
    with col2b:
        st.subheader("—")
        st.write("**Icon**")
        # Close the file deterministically and decode it as UTF-8.
        with open(PROJ / "subpages/emoji-en-US.json", encoding="utf-8") as emoji_file:
            emojis = list(json.load(emoji_file).keys())
        for label in labels:
            if f"icon_{label}" not in st.session_state:
                st.session_state[f"icon_{label}"] = classmap[label]
            st.selectbox(label, key=f"icon_{label}", options=emojis)
            # Propagate the selected icon to the shared class map.
            classmap[label] = st.session_state[f"icon_{label}"]

    # TODO: add a "Reset to defaults" button that restores the widget
    # defaults in st.session_state and reruns the app.
This function renders the page.