import torch
from torch.nn import functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch import nn
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from torch.nn import Identity
from transformers.activations import get_activation
import numpy as np
from torch_scatter import scatter_add  # required by the bertsum branch and GATpooling below
from .utils import input_check, pos_encoding

class classification_model(torch.nn.Module):
    """Classification head on top of a pretrained transformer encoder.

    Pools the encoder output either with a SequenceSummary or, when
    ``bertsum=True``, with BERTSUM-style sentence aggregation and gated pooling,
    then adds a paragraph positional encoding and applies a stack of linear
    classifier layers.
    """

    def __init__(self, pretrained_model, config, num_classifier=1, num_pos_emb_layer=1, bertsum=False, device=None):
        super(classification_model, self).__init__()
        self.config = config
        self.num_labels = config.num_labels
        self.pretrained_model = pretrained_model
        if hasattr(config, 'd_model'):
            self.pretrained_hidden = config.d_model
        elif hasattr(config, 'hidden_size'):
            self.pretrained_hidden = config.hidden_size
        self.sequence_summary = SequenceSummary(config)
        self.bertsum = bertsum
        self.device = device
        self.return_hidden = False
        self.return_hidden_pretrained = False
        if self.bertsum:
            # self.pooling_1 = GATpooling(self.pretrained_hidden)
            # self.fnn_1 = nn.Linear(self.pretrained_hidden, self.pretrained_hidden)
            self.pooling_2 = GATpooling(self.pretrained_hidden, self.device)
            self.fnn_2 = nn.Linear(self.pretrained_hidden, self.pretrained_hidden)
        self.pos_emb_layer = nn.Sequential(*[nn.Linear(self.pretrained_hidden, self.pretrained_hidden) for _ in range(num_pos_emb_layer)])
        # classifier head: interpolate layer widths from the encoder hidden size down to num_labels
        dim_list = np.linspace(self.pretrained_hidden, config.num_labels, num_classifier + 1, dtype=np.int32)
        # dim_list = np.linspace(768, config.num_labels, num_classifier+1, dtype=np.int32)
        self.classifiers = nn.ModuleList()
        for c in range(num_classifier):
            self.classifiers.append(nn.Linear(dim_list[c], dim_list[c + 1]))
    def forward(self, inputs):
        # inputs: dict with 'input_ids', 'token_type_ids', 'attention_mask', 'position',
        # plus 'sentence_batch' when bertsum=True and optionally 'labels'.
        hidden_states = None
        input_ids = inputs['input_ids']
        token_type_ids = inputs['token_type_ids']
        attention_mask = inputs['attention_mask']
        position = inputs['position']
        transformer_inputs = input_check({'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': attention_mask}, self.pretrained_model)
        pretrained_output = self.pretrained_model(**transformer_inputs)
        output = pretrained_output[0]
        if self.return_hidden_pretrained and self.return_hidden:
            hidden_states = pretrained_output[1]
        if self.bertsum:
            # aggregate token states into sentence states, then gate-pool and project
            output = scatter_add(output, inputs['sentence_batch'], dim=-2)
            # output = self.pooling_1(output, inputs['sentence_batch'])
            # output = self.fnn_1(output)
            output = self.pooling_2(output)
            output = output.squeeze()
            output = self.fnn_2(output)
        else:
            output = self.sequence_summary(output)
        # add the paragraph positional encoding vector
        pos_emb = pos_encoding(position, self.pretrained_hidden).to(self.device, dtype=torch.float)
        output = torch.add(output, pos_emb)
        output = self.pos_emb_layer(output)
        if self.return_hidden and not self.return_hidden_pretrained:
            hidden_states = output
        for layer in self.classifiers:
            output = layer(output)
        logits = output
        if 'labels' in inputs.keys():
            loss = self.classification_loss_f(inputs, logits)
        else:
            loss = None
        return loss, output, hidden_states
    def classification_loss_f(self, inputs, logits):
        labels = inputs['labels']
        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        return loss
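
# A minimal sketch (not part of the original module) of how classification_loss_f
# dispatches on config.problem_type: num_labels == 1 selects MSE regression, integer
# labels with num_labels > 1 select cross-entropy, and float multi-hot labels select
# BCEWithLogitsLoss. The tensors below are dummy values for illustration only.
def _loss_dispatch_sketch():
    logits = torch.randn(4, 3)                       # batch of 4, 3 classes
    int_labels = torch.randint(0, 3, (4,))           # -> single_label_classification
    multi_hot = torch.randint(0, 2, (4, 3)).float()  # -> multi_label_classification
    ce = CrossEntropyLoss()(logits.view(-1, 3), int_labels.view(-1))
    bce = BCEWithLogitsLoss()(logits, multi_hot)
    return ce, bce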

class GATpooling(nn.Module):
    """Gated pooling: a learned linear gate weights each vector and the weighted
    vectors are summed per segment with ``scatter_add``."""

    def __init__(self, hidden_size, device=None):
        super(GATpooling, self).__init__()
        self.gate_nn = nn.Linear(hidden_size, 1)
        self.device = device

    def forward(self, x, batch=None):
        if batch is None:
            # default: treat the whole sequence as a single segment
            batch = torch.zeros(x.shape[-2], dtype=torch.long).to(self.device)
        gate = self.gate_nn(x)
        gate = F.softmax(gate, dim=-1)
        out = scatter_add(gate * x, batch, dim=-2)
        return out
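
# A minimal usage sketch (illustration only, assuming torch_scatter is installed):
# pool six token vectors into two segment vectors with GATpooling. The segment
# assignment tensor plays the same role as inputs['sentence_batch'] above.
def _gat_pooling_sketch():
    pool = GATpooling(hidden_size=8, device='cpu')
    x = torch.randn(6, 8)                     # 6 tokens, hidden size 8
    batch = torch.tensor([0, 0, 0, 1, 1, 1])  # first 3 tokens -> segment 0, rest -> segment 1
    out = pool(x, batch)                      # shape (2, 8)
    return out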

class SequenceSummary(nn.Module):
    r"""
    Compute a single vector summary of a sequence hidden states.

    Args:
        config ([`PretrainedConfig`]):
            The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
            config class of your model for the default values it uses):

            - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:

                - `"last"` -- Take the last token hidden state (like XLNet)
                - `"first"` -- Take the first token hidden state (like Bert)
                - `"mean"` -- Take the mean of all tokens hidden states
                - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
                - `"attn"` -- Not implemented now, use multi-head attention

            - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
            - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
              (otherwise to `config.hidden_size`).
            - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
              another string or `None` will add no activation.
            - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
            - **summary_last_dropout** (`float`) -- Optional dropout probability after the projection and activation.
    """

    def __init__(self, config):
        super().__init__()
        self.summary_type = getattr(config, "summary_type", "mean")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.summary = Identity()
        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = nn.Linear(config.hidden_size, num_classes)

        activation_string = getattr(config, "summary_activation", None)
        self.activation: Callable = get_activation(activation_string) if activation_string else Identity()

        self.first_dropout = Identity()
        if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
            self.first_dropout = nn.Dropout(config.summary_first_dropout)

        self.last_dropout = Identity()
        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
            self.last_dropout = nn.Dropout(config.summary_last_dropout)

    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.

        Returns:
            `torch.FloatTensor`: The summary of the sequence hidden states.
        """
        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = hidden_states.mean(dim=1)
        elif self.summary_type == "cls_index":
            if cls_index is None:
                cls_index = torch.full_like(
                    hidden_states[..., :1, :],
                    hidden_states.shape[-2] - 1,
                    dtype=torch.long,
                )
            else:
                cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
            # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        output = self.first_dropout(output)
        output = self.summary(output)
        output = self.activation(output)
        output = self.last_dropout(output)

        return output
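
# A minimal usage sketch (illustration only, assuming a BERT-style config with
# hidden_size 768): run SequenceSummary with mean pooling, a projection and a
# tanh activation over dummy hidden states. The SimpleNamespace stands in for a
# real transformers PretrainedConfig.
def _sequence_summary_sketch():
    from types import SimpleNamespace
    config = SimpleNamespace(
        summary_type="mean",
        summary_use_proj=True,
        summary_proj_to_labels=False,
        summary_activation="tanh",
        summary_first_dropout=0.1,
        summary_last_dropout=0.0,
        hidden_size=768,
        num_labels=2,
    )
    summary = SequenceSummary(config)
    hidden_states = torch.randn(2, 16, 768)  # (batch_size, seq_len, hidden_size)
    return summary(hidden_states)            # shape (2, 768)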