Spaces:
Runtime error
Runtime error
taka-yamakoshi
commited on
Commit
·
a0471c4
1
Parent(s):
1f8519e
model type
Browse files
app.py
CHANGED
|
@@ -47,10 +47,10 @@ def load_css(file_name):
|
|
| 47 |
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
|
| 48 |
|
| 49 |
@st.cache(show_spinner=True,allow_output_mutation=True)
|
| 50 |
-
def load_model():
|
| 51 |
-
tokenizer = AlbertTokenizer.from_pretrained(
|
| 52 |
#model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
|
| 53 |
-
model = AlbertForMaskedLM.from_pretrained(
|
| 54 |
return tokenizer,model
|
| 55 |
|
| 56 |
def clear_data():
|
|
@@ -167,14 +167,23 @@ def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_i
|
|
| 167 |
if __name__=='__main__':
|
| 168 |
wide_setup()
|
| 169 |
load_css('style.css')
|
| 170 |
-
tokenizer,model = load_model()
|
| 171 |
-
num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
| 172 |
-
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
| 173 |
-
|
| 174 |
-
main_area = st.empty()
|
| 175 |
|
| 176 |
if 'page_status' not in st.session_state:
|
| 177 |
-
st.session_state['page_status'] = '
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
if st.session_state['page_status']=='type_in':
|
| 180 |
show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
|
|
|
|
| 47 |
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
|
| 48 |
|
| 49 |
@st.cache(show_spinner=True,allow_output_mutation=True)
|
| 50 |
+
def load_model(model_name):
|
| 51 |
+
tokenizer = AlbertTokenizer.from_pretrained(model_name)
|
| 52 |
#model = CustomFlaxAlbertForMaskedLM.from_pretrained('albert-xxlarge-v2',from_pt=True)
|
| 53 |
+
model = AlbertForMaskedLM.from_pretrained(model_name)
|
| 54 |
return tokenizer,model
|
| 55 |
|
| 56 |
def clear_data():
|
|
|
|
| 167 |
if __name__=='__main__':
|
| 168 |
wide_setup()
|
| 169 |
load_css('style.css')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
if 'page_status' not in st.session_state:
|
| 172 |
+
st.session_state['page_status'] = 'model_selection'
|
| 173 |
+
|
| 174 |
+
if st.session_state['page_status']=='model_selection':
|
| 175 |
+
model_name = st.selectbox('Please select the model from below.',
|
| 176 |
+
('bert-base-uncased','bert-large-cased',
|
| 177 |
+
'roberta-base','roberta-large',
|
| 178 |
+
'albert-base-v2','albert-large-v2','albert-xlarge-v2','albert-xxlarge-v2'),index=3)
|
| 179 |
+
st.sesstion_state['model_name'] = model_name
|
| 180 |
+
if st.button('Confirm',key='model_name'):
|
| 181 |
+
st.session_state['page_status'] = 'type_in'
|
| 182 |
+
st.experimental_rerun()
|
| 183 |
+
|
| 184 |
+
tokenizer,model = load_model(st.session_state['model_name'])
|
| 185 |
+
num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
| 186 |
+
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
| 187 |
|
| 188 |
if st.session_state['page_status']=='type_in':
|
| 189 |
show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
|