Spaces:
Runtime error
Runtime error
taka-yamakoshi
committed on
Commit
·
ebfe870
1
Parent(s):
5958ae4
add model options
Browse files- app.py +26 -15
- skeleton_modeling_bert.py +73 -0
- skeleton_modeling_roberta.py +73 -0
app.py
CHANGED
|
@@ -10,10 +10,7 @@ import seaborn as sns
|
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
| 12 |
|
| 13 |
-
from transformers import AlbertTokenizer, AlbertForMaskedLM
|
| 14 |
-
|
| 15 |
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
|
| 16 |
-
from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
|
| 17 |
|
| 18 |
def wide_setup():
|
| 19 |
max_width = 1500
|
|
@@ -48,10 +45,23 @@ def load_css(file_name):
|
|
| 48 |
|
| 49 |
@st.cache(show_spinner=True,allow_output_mutation=True)
|
| 50 |
def load_model(model_name):
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def clear_data():
|
| 57 |
for key in st.session_state:
|
|
@@ -147,14 +157,14 @@ def mask_out(input_ids,pron_locs,option_locs,mask_id):
|
|
| 147 |
return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
|
| 148 |
|
| 149 |
|
| 150 |
-
def run_intervention(interventions,batch_size,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
|
| 151 |
probs = []
|
| 152 |
for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
|
| 153 |
input_ids = torch.tensor([
|
| 154 |
*[masked_ids['sent_1'] for _ in range(batch_size)],
|
| 155 |
*[masked_ids['sent_2'] for _ in range(batch_size)]
|
| 156 |
])
|
| 157 |
-
outputs =
|
| 158 |
logprobs = F.log_softmax(outputs['logits'], dim = -1)
|
| 159 |
logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
|
| 160 |
evals_1 = [logprobs_1[:,pron_locs['sent_1'][0]+1+i,token].numpy() for i,token in enumerate(option_tokens)]
|
|
@@ -181,9 +191,10 @@ if __name__=='__main__':
|
|
| 181 |
st.session_state['page_status'] = 'type_in'
|
| 182 |
st.experimental_rerun()
|
| 183 |
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
| 187 |
|
| 188 |
if st.session_state['page_status']=='type_in':
|
| 189 |
show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
|
|
@@ -263,7 +274,7 @@ if __name__=='__main__':
|
|
| 263 |
option_2_tokens = option_2_tokens_1
|
| 264 |
|
| 265 |
interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
| 266 |
-
probs_original = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
| 267 |
df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
|
| 268 |
[probs_original[0,1][0],probs_original[1,1][0]]],
|
| 269 |
columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
|
|
@@ -292,9 +303,9 @@ if __name__=='__main__':
|
|
| 292 |
for layer_id in range(num_layers):
|
| 293 |
interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
| 294 |
if multihead:
|
| 295 |
-
probs = run_intervention(interventions,1,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
| 296 |
else:
|
| 297 |
-
probs = run_intervention(interventions,num_heads,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
| 298 |
effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
|
| 299 |
effect_list.append(effect)
|
| 300 |
effect_array.append(effect_list)
|
|
|
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
| 12 |
|
|
|
|
|
|
|
| 13 |
#from custom_modeling_albert_flax import CustomFlaxAlbertForMaskedLM
|
|
|
|
| 14 |
|
| 15 |
def wide_setup():
|
| 16 |
max_width = 1500
|
|
|
|
| 45 |
|
| 46 |
@st.cache(show_spinner=True,allow_output_mutation=True)
def load_model(model_name):
    """Load the tokenizer, masked-LM model and skeleton forward function.

    The model family is chosen from the prefix of ``model_name``
    ('albert', 'bert' or 'roberta'); the matching ``transformers`` classes
    and skeleton module are imported lazily so that only the selected
    family is loaded.

    Args:
        model_name: Name passed to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A ``(tokenizer, model, skeleton_model)`` tuple, where
        ``skeleton_model`` is the function invoked as
        ``skeleton_model(model, input_ids, interventions=...)``.

    Raises:
        ValueError: If ``model_name`` does not start with a supported prefix.
    """
    if model_name.startswith('albert'):
        from transformers import AlbertTokenizer, AlbertForMaskedLM
        from skeleton_modeling_albert import SkeletonAlbertForMaskedLM
        tokenizer = AlbertTokenizer.from_pretrained(model_name)
        model = AlbertForMaskedLM.from_pretrained(model_name)
        skeleton_model = SkeletonAlbertForMaskedLM
    elif model_name.startswith('bert'):
        from transformers import BertTokenizer, BertForMaskedLM
        from skeleton_modeling_bert import SkeletonBertForMaskedLM
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertForMaskedLM.from_pretrained(model_name)
        # Previously left unassigned, which made the return below fail.
        skeleton_model = SkeletonBertForMaskedLM
    elif model_name.startswith('roberta'):
        from transformers import RobertaTokenizer, RobertaForMaskedLM
        try:
            from skeleton_modeling_roberta import SkeletonRobertaForMaskedLM
        except ImportError:
            # skeleton_modeling_roberta.py exposes its entry point under the
            # BERT name; accept that spelling so this branch still loads.
            from skeleton_modeling_roberta import SkeletonBertForMaskedLM as SkeletonRobertaForMaskedLM
        tokenizer = RobertaTokenizer.from_pretrained(model_name)
        model = RobertaForMaskedLM.from_pretrained(model_name)
        # Previously left unassigned, which made the return below fail.
        skeleton_model = SkeletonRobertaForMaskedLM
    else:
        # Fail with a clear message instead of an UnboundLocalError at return.
        raise ValueError(f'Unsupported model name: {model_name!r}')
    return tokenizer,model,skeleton_model
|
| 65 |
|
| 66 |
def clear_data():
|
| 67 |
for key in st.session_state:
|
|
|
|
| 157 |
return input_ids[:pron_locs[0]+1] + [mask_id for _ in range(len(option_locs))] + input_ids[pron_locs[-1]+2:]
|
| 158 |
|
| 159 |
|
| 160 |
+
def run_intervention(interventions,batch_size,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs):
|
| 161 |
probs = []
|
| 162 |
for masked_ids, option_tokens in zip([masked_ids_option_1, masked_ids_option_2],[option_1_tokens,option_2_tokens]):
|
| 163 |
input_ids = torch.tensor([
|
| 164 |
*[masked_ids['sent_1'] for _ in range(batch_size)],
|
| 165 |
*[masked_ids['sent_2'] for _ in range(batch_size)]
|
| 166 |
])
|
| 167 |
+
outputs = skeleton_model(model,input_ids,interventions=interventions)
|
| 168 |
logprobs = F.log_softmax(outputs['logits'], dim = -1)
|
| 169 |
logprobs_1, logprobs_2 = logprobs[:batch_size], logprobs[batch_size:]
|
| 170 |
evals_1 = [logprobs_1[:,pron_locs['sent_1'][0]+1+i,token].numpy() for i,token in enumerate(option_tokens)]
|
|
|
|
| 191 |
st.session_state['page_status'] = 'type_in'
|
| 192 |
st.experimental_rerun()
|
| 193 |
|
| 194 |
+
if st.session_state['page_status']!='model_selection':
|
| 195 |
+
tokenizer,model,skeleton_model = load_model(st.session_state['model_name'])
|
| 196 |
+
num_layers, num_heads = model.config.num_hidden_layers, model.config.num_attention_heads
|
| 197 |
+
mask_id = tokenizer('[MASK]').input_ids[1:-1][0]
|
| 198 |
|
| 199 |
if st.session_state['page_status']=='type_in':
|
| 200 |
show_instruction('1. Type in the sentences and click "Tokenize"',fontsize=16)
|
|
|
|
| 274 |
option_2_tokens = option_2_tokens_1
|
| 275 |
|
| 276 |
interventions = [{'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
| 277 |
+
probs_original = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
| 278 |
df = pd.DataFrame(data=[[probs_original[0,0][0],probs_original[1,0][0]],
|
| 279 |
[probs_original[0,1][0],probs_original[1,1][0]]],
|
| 280 |
columns=[tokenizer.decode(option_1_tokens),tokenizer.decode(option_2_tokens)],
|
|
|
|
| 303 |
for layer_id in range(num_layers):
|
| 304 |
interventions = [create_interventions(token_id,['lay','qry','key','val'],num_heads,multihead) if i==layer_id else {'lay':[],'qry':[],'key':[],'val':[]} for i in range(num_layers)]
|
| 305 |
if multihead:
|
| 306 |
+
probs = run_intervention(interventions,1,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
| 307 |
else:
|
| 308 |
+
probs = run_intervention(interventions,num_heads,skeleton_model,model,masked_ids_option_1,masked_ids_option_2,option_1_tokens,option_2_tokens,pron_locs)
|
| 309 |
effect = ((probs_original-probs)[0,0] + (probs_original-probs)[1,1] + (probs-probs_original)[0,1] + (probs-probs_original)[1,0])/4
|
| 310 |
effect_list.append(effect)
|
| 311 |
effect_array.append(effect_list)
|
skeleton_modeling_bert.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
@torch.no_grad()
def SkeletonBertLayer(layer_id,layer,hidden,interventions):
    """Run one BERT encoder layer by hand, swapping representations on request.

    Query/key/value projections are computed from the incoming ``hidden``
    *before* any swap is applied, so a 'lay' swap only affects the residual
    path while 'qry'/'key'/'val' swaps affect the attention computation.
    No attention mask and no dropout are applied in this function.

    Args:
        layer_id: Index into ``interventions`` selecting this layer's entry.
        layer: Encoder layer exposing ``attention.self`` (query/key/value,
            ``num_attention_heads``, ``attention_head_size``),
            ``attention.output`` (dense, LayerNorm), ``intermediate`` and
            ``output`` (dense, LayerNorm).
        hidden: 3-D tensor whose last axis has size
            ``num_heads * head_dim``; the first axis is indexed by
            ``swap_ids`` and the second by ``pos``.
        interventions: Sequence of dicts with keys 'lay', 'qry', 'key',
            'val', each holding ``(head_id, pos, swap_ids)`` triples.

    Returns:
        The layer output, same shape as ``hidden``.
    """
    self_attn = layer.attention.self
    n_heads = self_attn.num_attention_heads
    d_head = self_attn.attention_head_size
    assert n_heads*d_head == hidden.shape[2]

    projected = {
        'qry': self_attn.query(hidden),
        'key': self_attn.key(hidden),
        'val': self_attn.value(hidden),
    }
    for name in ('qry','key','val'):
        assert projected[name].shape == hidden.shape

    def head_cols(head_id):
        # Columns of the last axis that belong to a single attention head.
        return slice(d_head*head_id, d_head*(head_id+1))

    def split_heads(rep):
        # (batch, pos, hidden) -> (batch, head, pos, head_dim)
        return rep.view(*(rep.size()[:-1]+(n_heads,d_head))).permute(0,2,1,3)

    # Swap representations between the two rows named in each intervention.
    # Reads always come from the untouched source, writes go to the copy, so
    # the order of the two assignments does not matter.
    sources = {'lay': hidden, **projected}
    swapped = {}
    for rep_type in ['lay','qry','key','val']:
        source = sources[rep_type]
        target = source.clone()
        for head_id, pos, swap_ids in interventions[layer_id][rep_type]:
            row_a, row_b = swap_ids[0], swap_ids[1]
            cols = head_cols(head_id)
            target[row_a,:,cols][pos,:] = source[row_b,:,cols][pos,:]
            target[row_b,:,cols][pos,:] = source[row_a,:,cols][pos,:]
        swapped[rep_type] = target

    hidden = swapped['lay']
    heads_q = split_heads(swapped['qry'])
    heads_k = split_heads(swapped['key'])
    heads_v = split_heads(swapped['val'])

    # Scaled dot-product attention over all positions.
    scores = heads_q@heads_k.permute(0,1,3,2)/math.sqrt(d_head)
    attn = F.softmax(scores,dim=-1)
    context = (attn@heads_v).permute(0,2,1,3).reshape(*hidden.size())

    # Attention output projection, residual connection, layer norm.
    attn_out = layer.attention.output.LayerNorm(layer.attention.output.dense(context)+hidden)

    # Feed-forward block, residual connection, layer norm.
    ffn_out = layer.output.dense(layer.intermediate(attn_out))+attn_out
    return layer.output.LayerNorm(ffn_out)
|
| 60 |
+
|
| 61 |
+
def SkeletonBertForMaskedLM(model,input_ids,interventions):
    """Run a BERT masked-LM forward pass layer by layer with interventions.

    Args:
        model: Object exposing ``bert`` (with ``embeddings`` and
            ``encoder.layer``), ``cls`` (the LM head) and
            ``config.num_hidden_layers``.
        input_ids: Token ids passed to the embedding module.
        interventions: Per-layer intervention dicts, forwarded unchanged to
            ``SkeletonBertLayer``.

    Returns:
        Dict with 'logits' (LM head output for the last hidden state) and
        'hidden_states' (embedding output followed by each layer's output).
    """
    encoder_stack = model.bert
    head = model.cls
    states = []
    # Everything, including the LM head, runs without gradient tracking.
    with torch.no_grad():
        hidden = encoder_stack.embeddings(input_ids)
        states.append(hidden)
        for layer_id in range(model.config.num_hidden_layers):
            hidden = SkeletonBertLayer(layer_id,encoder_stack.encoder.layer[layer_id],hidden,interventions)
            states.append(hidden)
        logits = head(hidden)
    return {'logits':logits,'hidden_states':states}
|
skeleton_modeling_roberta.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
@torch.no_grad()
def SkeletonRobertaLayer(layer_id,layer,hidden,interventions):
    """Run one RoBERTa encoder layer by hand, swapping representations on request.

    Query/key/value projections are computed from the incoming ``hidden``
    *before* any swap is applied, so a 'lay' swap only affects the residual
    path while 'qry'/'key'/'val' swaps affect the attention computation.
    No attention mask and no dropout are applied in this function.

    Args:
        layer_id: Index into ``interventions`` selecting this layer's entry.
        layer: Encoder layer exposing ``attention.self`` (query/key/value,
            ``num_attention_heads``, ``attention_head_size``),
            ``attention.output`` (dense, LayerNorm), ``intermediate`` and
            ``output`` (dense, LayerNorm).
        hidden: 3-D tensor whose last axis has size
            ``num_heads * head_dim``; the first axis is indexed by
            ``swap_ids`` and the second by ``pos``.
        interventions: Sequence of dicts with keys 'lay', 'qry', 'key',
            'val', each holding ``(head_id, pos, swap_ids)`` triples.

    Returns:
        The layer output, same shape as ``hidden``.
    """
    self_attn = layer.attention.self
    n_heads = self_attn.num_attention_heads
    d_head = self_attn.attention_head_size
    assert n_heads*d_head == hidden.shape[2]

    projected = {
        'qry': self_attn.query(hidden),
        'key': self_attn.key(hidden),
        'val': self_attn.value(hidden),
    }
    for name in ('qry','key','val'):
        assert projected[name].shape == hidden.shape

    def head_cols(head_id):
        # Columns of the last axis that belong to a single attention head.
        return slice(d_head*head_id, d_head*(head_id+1))

    def split_heads(rep):
        # (batch, pos, hidden) -> (batch, head, pos, head_dim)
        return rep.view(*(rep.size()[:-1]+(n_heads,d_head))).permute(0,2,1,3)

    # Swap representations between the two rows named in each intervention.
    # Reads always come from the untouched source, writes go to the copy, so
    # the order of the two assignments does not matter.
    sources = {'lay': hidden, **projected}
    swapped = {}
    for rep_type in ['lay','qry','key','val']:
        source = sources[rep_type]
        target = source.clone()
        for head_id, pos, swap_ids in interventions[layer_id][rep_type]:
            row_a, row_b = swap_ids[0], swap_ids[1]
            cols = head_cols(head_id)
            target[row_a,:,cols][pos,:] = source[row_b,:,cols][pos,:]
            target[row_b,:,cols][pos,:] = source[row_a,:,cols][pos,:]
        swapped[rep_type] = target

    hidden = swapped['lay']
    heads_q = split_heads(swapped['qry'])
    heads_k = split_heads(swapped['key'])
    heads_v = split_heads(swapped['val'])

    # Scaled dot-product attention over all positions.
    scores = heads_q@heads_k.permute(0,1,3,2)/math.sqrt(d_head)
    attn = F.softmax(scores,dim=-1)
    context = (attn@heads_v).permute(0,2,1,3).reshape(*hidden.size())

    # Attention output projection, residual connection, layer norm.
    attn_out = layer.attention.output.LayerNorm(layer.attention.output.dense(context)+hidden)

    # Feed-forward block, residual connection, layer norm.
    ffn_out = layer.output.dense(layer.intermediate(attn_out))+attn_out
    return layer.output.LayerNorm(ffn_out)
|
| 60 |
+
|
| 61 |
+
def SkeletonRobertaForMaskedLM(model,input_ids,interventions):
    """Run a RoBERTa masked-LM forward pass layer by layer with interventions.

    Args:
        model: Object exposing ``roberta`` (with ``embeddings`` and
            ``encoder.layer``), ``lm_head`` and
            ``config.num_hidden_layers``.
        input_ids: Token ids passed to the embedding module.
        interventions: Per-layer intervention dicts, forwarded unchanged to
            ``SkeletonRobertaLayer``.

    Returns:
        Dict with 'logits' (LM head output for the last hidden state) and
        'hidden_states' (embedding output followed by each layer's output).
    """
    core_model = model.roberta
    lm_head = model.lm_head
    output_hidden = []
    # Everything, including the LM head, runs without gradient tracking.
    with torch.no_grad():
        hidden = core_model.embeddings(input_ids)
        output_hidden.append(hidden)
        for layer_id in range(model.config.num_hidden_layers):
            layer = core_model.encoder.layer[layer_id]
            hidden = SkeletonRobertaLayer(layer_id,layer,hidden,interventions)
            output_hidden.append(hidden)
        logits = lm_head(hidden)
    return {'logits':logits,'hidden_states':output_hidden}


# Backward-compatible alias: this function was originally published under the
# BERT name, which did not match the name callers import for RoBERTa.
SkeletonBertForMaskedLM = SkeletonRobertaForMaskedLM
|