import math

import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
from torch.nn.parameter import Parameter
from transformers import PretrainedConfig, PreTrainedModel

from configuration_dlmberta import InteractionModelATTNConfig
class StdScaler:
    """Minimal standard scaler for regression labels (zero mean, unit variance)."""

    def fit(self, X):
        self.mean_ = torch.mean(X).item()
        self.std_ = torch.std(X, correction=0).item()

    def fit_transform(self, X):
        self.fit(X)
        return (X - self.mean_) / self.std_

    def transform(self, X):
        return (X - self.mean_) / self.std_

    def inverse_transform(self, X):
        return (X * self.std_) + self.mean_

    def save(self, directory):
        with open(directory + "/scaler.config", "w") as f:
            f.write(str(self.mean_) + "\n")
            f.write(str(self.std_) + "\n")

    def load(self, directory):
        with open(directory + "/scaler.config", "r") as f:
            self.mean_ = float(f.readline())
            self.std_ = float(f.readline())
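

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal example of how StdScaler might be used to standardize regression
# labels; the tensor values below are made up for demonstration only.
def _example_stdscaler_usage():
    labels = torch.tensor([5.2, 6.8, 7.1, 4.9])
    scaler = StdScaler()
    scaled = scaler.fit_transform(labels)        # zero mean, unit variance
    restored = scaler.inverse_transform(scaled)  # recovers the original labels
    return scaled, restored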
class InteractionModelATTNForRegression(PreTrainedModel):
    """Hugging Face wrapper around InteractionModelATTN for binding-affinity regression."""

    config_class = InteractionModelATTNConfig

    def __init__(self, config, target_encoder, drug_encoder, scaler=None):
        super().__init__(config)
        self.model = InteractionModelATTN(
            target_encoder,
            drug_encoder,
            scaler,
            config.attention_dropout,
            config.hidden_dropout,
            config.num_heads,
        )
        self.scaler = scaler

    def INTERPR_ENABLE_MODE(self):
        self.model.INTERPR_ENABLE_MODE()

    def INTERPR_DISABLE_MODE(self):
        self.model.INTERPR_DISABLE_MODE()

    def INTERPR_OVERRIDE_ATTN(self, new_weights):
        self.model.INTERPR_OVERRIDE_ATTN(new_weights)

    def INTERPR_RESET_OVERRIDE_ATTN(self):
        self.model.INTERPR_RESET_OVERRIDE_ATTN()

    def forward(self, x1, x2):
        return self.model(x1, x2)

    def unscale(self, x):
        return self.model.unscale(x)
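

# --- Illustrative sketch (not part of the original module) ---
# One way the regression wrapper might be assembled. The checkpoint names are
# placeholders, and the InteractionModelATTNConfig keyword arguments are
# assumed from how the config is read above (attention_dropout, hidden_dropout,
# num_heads). The target encoder is expected to produce 512-dimensional hidden
# states and the drug encoder 384-dimensional ones (see InteractionModelATTN).
def _example_build_model():
    from transformers import AutoModel

    target_encoder = AutoModel.from_pretrained("<target-encoder-checkpoint>")
    drug_encoder = AutoModel.from_pretrained("<drug-encoder-checkpoint>")
    config = InteractionModelATTNConfig(
        attention_dropout=0.1, hidden_dropout=0.1, num_heads=1
    )
    return InteractionModelATTNForRegression(
        config, target_encoder, drug_encoder, scaler=StdScaler()
    )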
class CrossAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, attention_dropout=0.0, hidden_dropout=0.0, add_bias_kv=False, **factory_kwargs):
        """
        Initializes the CrossAttention layer.

        Args:
            embed_dim (int): Dimension of the input embeddings.
            num_heads (int): Number of attention heads.
            attention_dropout (float): Dropout probability for the attention weights.
            hidden_dropout (float): Dropout probability applied after the output projection.
        """
        super().__init__()
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")

        # Linear projections for query, key, and value.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(attention_dropout)
        xavier_uniform_(self.q_proj.weight)
        xavier_uniform_(self.k_proj.weight)
        xavier_uniform_(self.v_proj.weight)
        constant_(self.q_proj.bias, 0.)
        constant_(self.k_proj.bias, 0.)
        constant_(self.v_proj.bias, 0.)

        # Output projection.
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        constant_(self.out_proj.bias, 0.)
        self.drop_out = nn.Dropout(hidden_dropout)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, replace_weights=None):
        """
        Forward pass for cross attention.

        Args:
            query (Tensor): Query embeddings of shape (batch_size, query_len, embed_dim).
            key (Tensor): Key embeddings of shape (batch_size, key_len, embed_dim).
            value (Tensor): Value embeddings of shape (batch_size, key_len, embed_dim).
            key_padding_mask (Tensor, optional): Boolean mask of shape (batch_size, key_len);
                True marks padded key positions that should be ignored.
            attn_mask (Tensor, optional): Boolean mask of shape (batch_size, query_len, key_len);
                True marks attention weights that are zeroed after the softmax.
            replace_weights (Tensor, optional): Pre-softmax scores of shape
                (batch_size, num_heads, query_len, key_len) that override the computed
                scores (used by the interpretability mode).

        Returns:
            output (Tensor): The attended output of shape (batch_size, query_len, embed_dim).
            attn_weights (Tensor): The attention weights of shape (batch_size, num_heads, query_len, key_len).
        """
        batch_size, query_len, _ = query.size()
        _, key_len, _ = key.size()

        Q = self.q_proj(query)
        K = self.k_proj(key)
        V = self.v_proj(value)

        # Split the embedding dimension into heads:
        # (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim).
        Q = Q.view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute scaled dot-product attention scores.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (batch_size, num_heads, query_len, key_len)
        if key_padding_mask is not None:
            # Broadcast the padding mask over heads and queries and block padded keys.
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)  # (batch_size, 1, 1, key_len)
            scores = scores.masked_fill(key_padding_mask, float('-inf'))
        if replace_weights is not None:
            scores = replace_weights

        # Compute attention weights using softmax and expose the raw scores for interpretation.
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)  # (batch_size, num_heads, query_len, key_len)
        self.scores = scores

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)  # (batch_size, 1, query_len, key_len)
            attn_weights = attn_weights.masked_fill(attn_mask, 0)  # Zero out masked positions.
        attn_weights = self.attn_dropout(attn_weights)

        # Weighted sum of the values.
        attn_output = torch.matmul(attn_weights, V)  # (batch_size, num_heads, query_len, head_dim)

        # Recombine heads: (batch_size, num_heads, query_len, head_dim) -> (batch_size, query_len, embed_dim).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.embed_dim)

        # Final linear projection and dropout.
        output = self.out_proj(attn_output)
        output = self.drop_out(output)
        return output, attn_weights
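

# --- Illustrative sketch (not part of the original module) ---
# Exercises CrossAttention with random tensors to document the expected shapes;
# the batch size and sequence lengths are arbitrary.
def _example_cross_attention_shapes():
    attn = CrossAttention(embed_dim=384, num_heads=4)
    query = torch.randn(2, 10, 384)                      # (batch, query_len, embed_dim)
    key = torch.randn(2, 20, 384)                        # (batch, key_len, embed_dim)
    key_padding = torch.zeros(2, 20, dtype=torch.bool)   # True would mark padded keys
    out, weights = attn(query, key, key, key_padding_mask=key_padding)
    assert out.shape == (2, 10, 384)
    assert weights.shape == (2, 4, 10, 20)
    return out, weights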
class InteractionModelATTN(nn.Module):
    def __init__(self, target_encoder, drug_encoder, scaler, attention_dropout, hidden_dropout, num_heads=1, kernel_size=1):
        super().__init__()
        # Interpretability state (populated only when INTERPR_MODE is enabled).
        self.replace_weights = None
        self.crossattention_weights = None
        self.presum_layer = None
        self.INTERPR_MODE = False

        self.scaler = scaler
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.target_encoder = target_encoder
        self.drug_encoder = drug_encoder
        self.kernel_size = kernel_size

        # Project both encoder outputs into a shared 384-dimensional space
        # (the target encoder is expected to emit 512-dimensional hidden states, the drug encoder 384).
        self.lin_map_target = nn.Linear(512, 384)
        self.dropout_map_target = nn.Dropout(hidden_dropout)
        self.lin_map_drug = nn.Linear(384, 384)
        self.dropout_map_drug = nn.Dropout(hidden_dropout)

        # Target tokens attend over drug tokens.
        self.crossattention = CrossAttention(384, num_heads, attention_dropout, hidden_dropout)
        self.norm = nn.LayerNorm(384)

        # Summary head: one scalar contribution per target token, summed at the end.
        self.summary1 = nn.Linear(384, 384)
        self.summary2 = nn.Linear(384, 1)
        self.dropout_summary = nn.Dropout(hidden_dropout)
        self.layer_norm = nn.LayerNorm(384)
        self.gelu = nn.GELU()

        # Per-position weights for the final sum (assumes target sequences padded to length 512),
        # a bias term, and a learnable value for padded positions.
        self.w = Parameter(torch.empty(512, 1))
        self.b = Parameter(torch.zeros(1))
        self.pdng = Parameter(torch.tensor(0.0))  # learnable padding value (0-dimensional)
        xavier_uniform_(self.w)
    def forward(self, x1, x2):
        """
        Forward pass for the attention interaction model.

        Args:
            x1 (dict): Input tensors for the target encoder.
                Expected keys:
                    - 'input_ids' (torch.Tensor): Token IDs for the target input.
                    - 'attention_mask' (torch.Tensor): Attention mask for the target input.
            x2 (dict): Input tensors for the drug encoder.
                Expected keys:
                    - 'input_ids' (torch.Tensor): Token IDs for the drug input.
                    - 'attention_mask' (torch.Tensor): Attention mask for the drug input.

        Returns:
            torch.Tensor: A tensor representing the predicted binding affinity.
        """
        # Boolean masks avoid a dropout issue: https://github.com/pytorch/pytorch/issues/86120
        x1["attention_mask"] = x1["attention_mask"].bool()
        y1 = self.target_encoder(**x1).last_hidden_state  # Target representation
        query_mask = x1["attention_mask"].unsqueeze(-1).to(y1.dtype)
        y1 = y1 * query_mask

        x2["attention_mask"] = x2["attention_mask"].bool()
        y2 = self.drug_encoder(**x2).last_hidden_state  # Drug representation
        key_mask = x2["attention_mask"].unsqueeze(-1).to(y2.dtype)
        y2 = y2 * key_mask

        # Map both representations into the shared 384-dimensional space.
        y1 = self.lin_map_target(y1)
        y1 = self.gelu(y1)
        y1 = self.dropout_map_target(y1)
        y2 = self.lin_map_drug(y2)
        y2 = self.gelu(y2)
        y2 = self.dropout_map_drug(y2)

        key_padding_mask = (x2["attention_mask"] == 0)  # True at padded drug positions

        # If in interpretation mode, allow the replacement of cross-attention weights.
        replace_weights = None
        if self.INTERPR_MODE and self.replace_weights is not None:
            replace_weights = self.replace_weights

        out, attn_weights = self.crossattention(y1, y2, y2, key_padding_mask=key_padding_mask, attn_mask=None, replace_weights=replace_weights)

        # If in interpretation mode, expose cross-attention weights and scores.
        if self.INTERPR_MODE:
            self.crossattention_weights = attn_weights
            self.scores = self.crossattention.scores

        # Summary head: one scalar contribution per target token.
        out = self.summary1(out * query_mask)
        out = self.gelu(out)
        out = self.dropout_summary(out)
        out = self.summary2(out).squeeze(-1)  # [batch, seq_len]

        # If in interpretation mode, expose the per-token contributions before summation.
        if self.INTERPR_MODE:
            self.presum_layer = out

        weighted = out * self.w.squeeze(1)  # [batch, seq_len]
        padding_positions = ~x1["attention_mask"]  # True at padding
        # Assign the learnable padding value to all padded positions
        # (torch.where keeps self.pdng in the autograd graph, unlike masked_fill with .item()).
        weighted = torch.where(padding_positions, self.pdng, weighted)
        # Sum across the sequence and add the bias.
        result = weighted.sum(dim=1, keepdim=True) + self.b
        return result
    def train(self, mode=True):
        super().train(mode)
        self.target_encoder.train(mode)
        self.drug_encoder.train(mode)
        self.crossattention.train(mode)
        return self

    def eval(self):
        super().eval()
        self.target_encoder.eval()
        self.drug_encoder.eval()
        self.crossattention.eval()
        return self
    def INTERPR_ENABLE_MODE(self):
        """
        Enables the interpretability mode for the model.
        """
        if self.training:
            raise RuntimeError("Cannot enable interpretability mode while the model is training.")
        self.INTERPR_MODE = True

    def INTERPR_DISABLE_MODE(self):
        """
        Disables the interpretability mode for the model.
        """
        if self.training:
            raise RuntimeError("Cannot disable interpretability mode while the model is training.")
        self.INTERPR_MODE = False

    def INTERPR_OVERRIDE_ATTN(self, new_weights):
        self.replace_weights = new_weights

    def INTERPR_RESET_OVERRIDE_ATTN(self):
        self.replace_weights = None
    def unscale(self, x):
        """
        Unscales predictions using the fitted scaler. If no scaler is specified, returns the input unchanged.

        Args:
            x: the values to be unscaled.
        """
        with torch.no_grad():
            if self.scaler is None:
                return x
            return self.scaler.inverse_transform(x)
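

# --- Illustrative sketch (not part of the original module) ---
# Shows how the interpretability hooks might be used at inference time.
# `model` is an InteractionModelATTNForRegression; `x1` and `x2` are tokenized
# target/drug batches (dicts with "input_ids" and "attention_mask").
def _example_interpretability(model, x1, x2):
    model.eval()
    model.INTERPR_ENABLE_MODE()
    with torch.no_grad():
        prediction = model(x1, x2)
    attn_weights = model.model.crossattention_weights  # (batch, heads, target_len, drug_len)
    contributions = model.model.presum_layer            # per-target-token contributions
    model.INTERPR_DISABLE_MODE()
    return model.unscale(prediction), attn_weights, contributions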