from transformers import BertConfig
class VGCNConfig(BertConfig):
    model_type = "vgcn"

    def __init__(
        self,
        bert_model: str = "readerbench/RoBERT-base",
        gcn_adj_matrix: str = "",
        max_seq_len: int = 256,
        npmi_threshold: float = 0.2,
        tf_threshold: float = 0.0,
        vocab_type: str = "all",
        gcn_embedding_dim: int = 32,
        **kwargs,
    ):
        if vocab_type not in ["all", "pmi", "tf"]:
            raise ValueError(f"`vocab_type` must be 'all', 'pmi' or 'tf', got {vocab_type}.")
        if max_seq_len < 1 or max_seq_len > 512:
            raise ValueError(f"`max_seq_len` must be between 1 and 512, got {max_seq_len}.")
        if npmi_threshold < 0.0 or npmi_threshold > 1.0:
            raise ValueError(f"`npmi_threshold` must be between 0.0 and 1.0, got {npmi_threshold}.")
        if tf_threshold < 0.0 or tf_threshold > 1.0:
            raise ValueError(f"`tf_threshold` must be between 0.0 and 1.0, got {tf_threshold}.")

        self.gcn_adj_matrix = gcn_adj_matrix
        self.max_seq_len = max_seq_len
        self.npmi_threshold = npmi_threshold
        self.tf_threshold = tf_threshold
        self.vocab_type = vocab_type
        self.gcn_embedding_dim = gcn_embedding_dim
        self.bert_model = bert_model

        super().__init__(**kwargs)
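

# Minimal usage sketch (illustrative, not part of the original file): build a
# VGCNConfig and round-trip it through the standard save/load API inherited from
# BertConfig/PretrainedConfig. The adjacency-matrix path "vgcn_adj.pkl" and the
# output directory "vgcn-config" are hypothetical placeholders.
if __name__ == "__main__":
    config = VGCNConfig(
        gcn_adj_matrix="vgcn_adj.pkl",  # hypothetical path to a precomputed vocabulary graph
        max_seq_len=128,
        npmi_threshold=0.3,
        vocab_type="pmi",
    )
    config.save_pretrained("vgcn-config")                 # writes config.json including the extra VGCN fields
    reloaded = VGCNConfig.from_pretrained("vgcn-config")  # custom attributes are restored from the JSON
    assert reloaded.vocab_type == "pmi" and reloaded.gcn_embedding_dim == 32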