Readme update

Files changed:
- config.json (+8 -0)
- configuration_vgcn.py (+2 -0)
- modeling_vcgn.py (+40 -17)
config.json
CHANGED
@@ -1,5 +1,12 @@
 {
+  "architectures": [
+    "VCGNModelForTextClassification"
+  ],
   "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoConfig": "configuration_vgcn.VGCNConfig",
+    "AutoModelForSequenceClassification": "modeling_vcgn.VCGNModelForTextClassification"
+  },
   "bert_model": "readerbench/RoBERT-base",
   "classifier_dropout": null,
   "do_lower_case": 1,
@@ -34,6 +41,7 @@
   "pad_token_id": 0,
   "position_embedding_type": "absolute",
   "tf_threshold": 0.0,
+  "torch_dtype": "float32",
   "transformers_version": "4.31.0",
   "type_vocab_size": 2,
   "use_cache": true,
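The new "architectures" and "auto_map" entries wire the custom classes into the transformers Auto API, so the checkpoint can be loaded without importing the repo code by hand. A minimal loading sketch, assuming a placeholder repo id:

# Hedged sketch: "your-namespace/vgcn-bert" is an illustrative repo id, not the
# actual model. trust_remote_code=True is required because auto_map points at
# code files (configuration_vgcn.py, modeling_vcgn.py) shipped with the checkpoint.
from transformers import AutoConfig, AutoModelForSequenceClassification

config = AutoConfig.from_pretrained("your-namespace/vgcn-bert", trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    "your-namespace/vgcn-bert", trust_remote_code=True
)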
configuration_vgcn.py
CHANGED
@@ -6,6 +6,7 @@ class VGCNConfig(BertConfig):
 
     def __init__(
         self,
+        bert_model='readerbench/RoBERT-base',
         gcn_adj_matrix: str ='',
         max_seq_len: int = 256,
         npmi_threshold: float = 0.2,
@@ -29,5 +30,6 @@ class VGCNConfig(BertConfig):
         self.tf_threshold = tf_threshold
         self.vocab_type = vocab_type
         self.gcn_embedding_dim = gcn_embedding_dim
+        self.bert_model = bert_model
 
         super().__init__(**kwargs)
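With bert_model promoted to a constructor argument and persisted as self.bert_model, the backbone checkpoint name now travels with the config instead of being hard-coded in the model. A sketch of building the config directly; it assumes configuration_vgcn.py is importable locally, and the adjacency-matrix path is a hypothetical file:

# Hedged sketch: argument values are illustrative, not the shipped defaults
# except where noted.
from configuration_vgcn import VGCNConfig

config = VGCNConfig(
    bert_model="readerbench/RoBERT-base",  # default backbone, now overridable
    gcn_adj_matrix="vgcn_adj.pkl",         # hypothetical path to the pickled adjacency data
    max_seq_len=256,
    npmi_threshold=0.2,
)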
modeling_vcgn.py
CHANGED
@@ -64,27 +64,51 @@ def get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, gcn_config: VGCNConfig):
 class VCGNModelForTextClassification(PreTrainedModel):
     config_class = VGCNConfig
 
-    def __init__(self, config):
+    def __init__(self, config, load_adjacency_matrix=True):
         super().__init__(config)
 
-        self.pre_trained_model_name = ''
-        self.remove_stop_words = False
-        self.tokenizer = None
-        self.norm_gcn_vocab_adj_list = None
-        self.gcn_vocab_size = config.vocab_size
+        self.tokenizer = BertTokenizer.from_pretrained(config.bert_model)
 
+        if load_adjacency_matrix:
+            norm_gcn_vocab_adj_list = self.load_adj_matrix(config.gcn_adj_matrix)
+        else:
+            norm_gcn_vocab_adj_list = []
+            for _ in range(2 if config.vocab_type=='all' else 1):
+                norm_gcn_vocab_adj_list.append(torch.sparse.FloatTensor(torch.LongTensor([[0],[0]]), torch.Tensor([0]), (config.vocab_size, config.vocab_size)))
 
         self.model = VGCN_Bert(
             config,
-            gcn_adj_matrix=self.norm_gcn_vocab_adj_list,
+            gcn_adj_matrix=norm_gcn_vocab_adj_list,
             gcn_adj_dim=config.vocab_size,
-            gcn_adj_num=len(self.norm_gcn_vocab_adj_list),
+            gcn_adj_num=len(norm_gcn_vocab_adj_list),
             gcn_embedding_dim=config.gcn_embedding_dim,
         )
 
+    @classmethod
+    def from_pretrained(cls, *model_args, reload_adjacency_matrix=False, **kwargs):
+        model = super().from_pretrained(*model_args, **kwargs, load_adjacency_matrix=False)
+
+        if reload_adjacency_matrix:
+            norm_gcn_vocab_adj_list = model.load_adj_matrix(model.config.gcn_adj_matrix)
+            model.model.embeddings.vocab_gcn.adj_matrix=torch.nn.ParameterList([torch.nn.Parameter(x) for x in norm_gcn_vocab_adj_list])
+            for p in model.model.embeddings.vocab_gcn.adj_matrix:
+                p.requires_grad=False
+
+        return model
+
+    def set_adjacency_matrix(self, adj_matrix:Union[List, np.ndarray, sp.csr_matrix, torch.Tensor]):
+
+        if isinstance(adj_matrix, np.ndarray):
+            adj_matrix = [torch.from_numpy(adj_matrix)]
+        else:
+            raise ValueError(f"adjacency matrix must be a list of torch.Tensor or torch.nn.Parameter, got {type(adj_matrix)}")
+
+        self.model.embeddings.vocab_gcn.adj_matrix=torch.nn.ParameterList([torch.nn.Parameter(x) for x in adj_matrix])
+        for p in self.model.embeddings.vocab_gcn.adj_matrix:
+            p.requires_grad=False
+
     def load_adj_matrix(self, adj_matrix):
         filename = None
         if Path(adj_matrix).is_file():

@@ -98,11 +122,8 @@ class VCGNModelForTextClassification(PreTrainedModel):
 
         gcn_vocab_adj_tf, gcn_vocab_adj, adj_config = pkl.load(open(filename, 'rb'))
 
-        self.pre_trained_model_name = adj_config['bert_model']
-        self.remove_stop_words = adj_config['remove_stop_words']
-        self.tokenizer = BertTokenizer.from_pretrained(self.pre_trained_model_name)
-        self.norm_gcn_vocab_adj_list = get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, self.config)
+        self.tokenizer = BertTokenizer.from_pretrained(adj_config['bert_model'])
+        return get_torch_gcn(gcn_vocab_adj_tf, gcn_vocab_adj, self.config)
 
     def _prep_batch(self, batch: torch.Tensor):

@@ -207,12 +228,14 @@ class VocabGraphConvolution(nn.Module):
     """
    def __init__(self,adj_matrix,voc_dim, num_adj, hid_dim, out_dim, dropout_rate=0.2):
         super(VocabGraphConvolution, self).__init__()
-        if isinstance(adj_matrix, nn.Parameter):
+        if isinstance(adj_matrix, nn.Parameter) or isinstance(adj_matrix, nn.ParameterList):
             self.adj_matrix=adj_matrix
-        else:
+        elif isinstance(adj_matrix, list):
             self.adj_matrix=torch.nn.ParameterList([torch.nn.Parameter(x) for x in adj_matrix])
             for p in self.adj_matrix:
                 p.requires_grad=False
+        else:
+            raise ValueError(f"adjacency matrix must be a list of torch.Tensor or torch.nn.Parameter, got {type(adj_matrix)}")
 
         self.voc_dim=voc_dim
         self.num_adj=num_adj
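Taken together, these changes let from_pretrained build the model with empty sparse placeholder adjacency matrices and rebuild the real ones on demand from the pickled data named in config.gcn_adj_matrix. A usage sketch; the repo id is a placeholder, and reloading assumes the adjacency pickle is actually resolvable on disk:

# Hedged sketch: reload_adjacency_matrix=True re-reads the pickled adjacency
# data referenced by config.gcn_adj_matrix and swaps it into the GCN embedding
# layer after the weights are loaded.
import numpy as np
from modeling_vcgn import VCGNModelForTextClassification

model = VCGNModelForTextClassification.from_pretrained(
    "your-namespace/vgcn-bert", reload_adjacency_matrix=True
)

# Alternatively, inject a dense NumPy adjacency matrix after loading. Note that,
# as committed, set_adjacency_matrix accepts only np.ndarray and raises for the
# other types in its annotation.
adj = np.eye(model.config.vocab_size, dtype=np.float32)  # illustrative matrix only
model.set_adjacency_matrix(adj)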