taka-yamakoshi committed
Commit 77d2a77 · 1 Parent(s): 3052d18
debug

app.py CHANGED
@@ -117,15 +117,13 @@ def show_instruction(sent,fontsize=20):
     suffix = '</span></p>'
     return st.markdown(prefix + sent + suffix, unsafe_allow_html = True)
 
-def create_interventions(token_id,
+def create_interventions(token_id,interv_types,num_heads):
     interventions = {}
-    for
-
-
-        for rep in ['lay','qry','key','val']:
-            interventions[layer_id][rep] = [(head_id,token_id,[head_id,head_id+num_heads]) for head_id in range(num_heads)]
+    for rep in ['lay','qry','key','val']:
+        if rep in interv_types:
+            interventions[rep] = [(head_id,token_id,[head_id,head_id+num_heads]) for head_id in range(num_heads)]
         else:
-            interventions[
+            interventions[rep] = []
     return interventions
 
 def separate_options(option_locs):
@@ -195,7 +193,7 @@ if __name__=='__main__':
                                             mask_locs=st.session_state['mask_locs_2'])
 
     option_1_locs, option_2_locs = {}, {}
-
+    pron_locs = {}
    input_ids_dict = {}
    masked_ids_option_1 = {}
    masked_ids_option_2 = {}
@@ -215,14 +213,15 @@ if __name__=='__main__':
         st.write(' '.join([tokenizer.decode([token]) for toke in token_ids]))
 
     if st.session_state['page_status'] == 'finish_debug':
-
-
-
-
-
-
-
-
+        for layer_id in range(num_layers):
+            interventions = [create_interventions(16,['lay','qry','key','val'],num_heads) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
+            for masked_ids in [masked_ids_option_1, masked_ids_option_2]:
+                input_ids = torch.tensor([
+                    *[masked_ids['sent_1'] for _ in range(num_heads)],
+                    *[masked_ids['sent_2'] for _ in range(num_heads)]
+                ])
+                outputs = SkeletonAlbertForMaskedLM(model,input_ids,interventions=interventions)
+                logprobs = F.log_softmax(outputs['logits'], dim = -1)
 
 
         preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1) for probs in logprobs[0][1:-1]]
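
The rewritten create_interventions no longer nests its output by layer_id: each representation type ('lay', 'qry', 'key', 'val') now maps either to one (head_id, token_id, [head_id, head_id + num_heads]) entry per head or to an empty list. Below is a minimal standalone sketch; the function body is copied from the hunk above, token_id=16 matches the value used in the debug loop, and num_heads=2 is purely illustrative.

```python
# Sketch: the create_interventions added in this commit, runnable on its own.
# token_id=16 is the value used in the debug loop; num_heads=2 is illustrative.
def create_interventions(token_id, interv_types, num_heads):
    interventions = {}
    for rep in ['lay', 'qry', 'key', 'val']:
        if rep in interv_types:
            # one (head_id, token_id, [head_id, head_id + num_heads]) entry per head
            interventions[rep] = [(head_id, token_id, [head_id, head_id + num_heads])
                                  for head_id in range(num_heads)]
        else:
            interventions[rep] = []
    return interventions

print(create_interventions(16, ['qry', 'key'], num_heads=2))
# {'lay': [], 'qry': [(0, 16, [0, 2]), (1, 16, [1, 3])],
#  'key': [(0, 16, [0, 2]), (1, 16, [1, 3])], 'val': []}
```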
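
The debug block added in the last hunk builds, for every layer, a batch that repeats each masked sentence once per attention head, feeds it to the Space's SkeletonAlbertForMaskedLM, and takes a log-softmax of the logits; the pre-existing preds_0 line then samples one token per inner position of the first sequence. The sketch below reproduces only the tensor shapes of that flow: random logits stand in for the model output, and num_heads, seq_len and vocab_size are illustrative values, not the app's actual configuration.

```python
import torch
import torch.nn.functional as F

# Shape-level sketch of the new debug loop; random logits stand in for the
# output of SkeletonAlbertForMaskedLM, which is defined in this Space.
num_heads, seq_len, vocab_size = 2, 8, 30000      # illustrative sizes only
masked_ids = {'sent_1': list(range(seq_len)),     # dummy token ids
              'sent_2': list(range(seq_len))}

# Same batching as in the diff: each sentence repeated once per head.
input_ids = torch.tensor([
    *[masked_ids['sent_1'] for _ in range(num_heads)],
    *[masked_ids['sent_2'] for _ in range(num_heads)],
])                                                # shape (2 * num_heads, seq_len)

logits = torch.randn(input_ids.shape[0], seq_len, vocab_size)  # stand-in for outputs['logits']
logprobs = F.log_softmax(logits, dim=-1)

# Same sampling expression as the existing preds_0 line: exp(logprobs) recovers
# probabilities, and multinomial draws one token id per non-special position of row 0.
preds_0 = [torch.multinomial(torch.exp(probs), num_samples=1).squeeze(dim=-1)
           for probs in logprobs[0][1:-1]]
print(input_ids.shape, len(preds_0))              # torch.Size([4, 8]) 6
```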