# DLRNA-BERTa / modeling_dlmberta.py

import math

import torch
import torch.nn as nn
from torch.nn.init import constant_, xavier_uniform_
from torch.nn.parameter import Parameter
from transformers import PreTrainedModel, PretrainedConfig

from configuration_dlmberta import InteractionModelATTNConfig


class StdScaler:
    """Standardizes labels to zero mean and unit (population) standard deviation."""

    def fit(self, X):
        self.mean_ = torch.mean(X).item()
        self.std_ = torch.std(X, correction=0).item()

    def fit_transform(self, X):
        self.fit(X)
        return self.transform(X)

    def transform(self, X):
        return (X - self.mean_) / self.std_

    def inverse_transform(self, X):
        return (X * self.std_) + self.mean_

    def save(self, directory):
        with open(directory + "/scaler.config", "w") as f:
            f.write(str(self.mean_) + "\n")
            f.write(str(self.std_) + "\n")

    def load(self, directory):
        with open(directory + "/scaler.config", "r") as f:
            self.mean_ = float(f.readline())
            self.std_ = float(f.readline())
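

# Usage sketch (illustrative, not part of the released pipeline): fit the scaler on a
# 1-D tensor of training affinities, transform them for training, and invert predictions
# afterwards. The directory path below is a placeholder.
#
#   scaler = StdScaler()
#   scaled_labels = scaler.fit_transform(torch.tensor([5.1, 6.7, 7.2]))
#   restored = scaler.inverse_transform(scaled_labels)   # approximately the original values
#   scaler.save("./checkpoint")                          # writes ./checkpoint/scaler.config
#   scaler.load("./checkpoint")                          # restores mean_ and std_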


class InteractionModelATTNForRegression(PreTrainedModel):
    """PreTrainedModel wrapper around InteractionModelATTN for affinity regression."""

    config_class = InteractionModelATTNConfig

    def __init__(self, config, target_encoder, drug_encoder, scaler=None):
        super().__init__(config)
        self.model = InteractionModelATTN(
            target_encoder,
            drug_encoder,
            scaler,
            config.attention_dropout,
            config.hidden_dropout,
            config.num_heads,
        )
        self.scaler = scaler

    def INTERPR_ENABLE_MODE(self):
        self.model.INTERPR_ENABLE_MODE()

    def INTERPR_DISABLE_MODE(self):
        self.model.INTERPR_DISABLE_MODE()

    def INTERPR_OVERRIDE_ATTN(self, new_weights):
        self.model.INTERPR_OVERRIDE_ATTN(new_weights)

    def INTERPR_RESET_OVERRIDE_ATTN(self):
        self.model.INTERPR_RESET_OVERRIDE_ATTN()

    def forward(self, x1, x2):
        return self.model(x1, x2)

    def unscale(self, x):
        return self.model.unscale(x)
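

# Construction sketch (assumptions flagged): the wrapper expects a target encoder with
# hidden size 512 and a drug encoder with hidden size 384 (see lin_map_* below). The
# checkpoint paths are placeholders, and the config keyword names are assumed to mirror
# the attributes read in __init__ above.
#
#   from transformers import AutoModel
#
#   config = InteractionModelATTNConfig(attention_dropout=0.1, hidden_dropout=0.1, num_heads=1)
#   target_encoder = AutoModel.from_pretrained("path/to/target-encoder")   # hidden size 512
#   drug_encoder = AutoModel.from_pretrained("path/to/drug-encoder")       # hidden size 384
#   model = InteractionModelATTNForRegression(config, target_encoder, drug_encoder, scaler=StdScaler())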


class CrossAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, attention_dropout=0.0, hidden_dropout=0.0, add_bias_kv=False, **factory_kwargs):
        """
        Initializes the CrossAttention layer.

        Args:
            embed_dim (int): Dimension of the input embeddings.
            num_heads (int): Number of attention heads.
            attention_dropout (float): Dropout probability for the attention weights.
            hidden_dropout (float): Dropout probability after the output projection.
        """
        super().__init__()
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scaling = self.head_dim ** -0.5
        if self.head_dim * num_heads != embed_dim:
            raise ValueError("embed_dim must be divisible by num_heads")

        # Linear projections for query, key, and value.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(attention_dropout)

        xavier_uniform_(self.q_proj.weight)
        xavier_uniform_(self.k_proj.weight)
        xavier_uniform_(self.v_proj.weight)
        constant_(self.q_proj.bias, 0.)
        constant_(self.k_proj.bias, 0.)
        constant_(self.v_proj.bias, 0.)

        # Output projection.
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        constant_(self.out_proj.bias, 0.)
        self.drop_out = nn.Dropout(hidden_dropout)

    def forward(self, query, key, value, key_padding_mask=None, attn_mask=None, replace_weights=None):
        """
        Forward pass for cross attention.

        Args:
            query (Tensor): Query embeddings of shape (batch_size, query_len, embed_dim).
            key (Tensor): Key embeddings of shape (batch_size, key_len, embed_dim).
            value (Tensor): Value embeddings of shape (batch_size, key_len, embed_dim).
            key_padding_mask (Tensor, optional): Boolean mask of shape (batch_size, key_len);
                True marks padded key positions that must not be attended to.
            attn_mask (Tensor, optional): Boolean mask of shape (batch_size, query_len, key_len);
                True marks attention weights that are zeroed out after the softmax.
            replace_weights (Tensor, optional): If given, replaces the raw attention scores
                (used by the interpretability mode).

        Returns:
            output (Tensor): The attended output of shape (batch_size, query_len, embed_dim).
            attn_weights (Tensor): The attention weights of shape (batch_size, num_heads, query_len, key_len).
        """
        batch_size, query_len, _ = query.size()
        _, key_len, _ = key.size()

        Q = self.q_proj(query)
        K = self.k_proj(key)
        V = self.v_proj(value)

        # Split into heads: (batch, len, embed_dim) -> (batch, num_heads, len, head_dim).
        Q = Q.view(batch_size, query_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, key_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute scaled dot-product attention scores.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_dim)  # (batch_size, num_heads, query_len, key_len)

        if key_padding_mask is not None:
            # Mask out padded key positions (True -> -inf before the softmax).
            key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(1)  # (batch_size, 1, 1, key_len) for broadcasting
            scores = scores.masked_fill(key_padding_mask, float('-inf'))

        # If requested (interpretability mode), override the raw attention scores.
        if replace_weights is not None:
            scores = replace_weights

        # Compute attention weights using softmax.
        attn_weights = torch.nn.functional.softmax(scores, dim=-1)  # (batch_size, num_heads, query_len, key_len)
        self.scores = scores

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1)  # (batch_size, 1, query_len, key_len)
            attn_weights = attn_weights.masked_fill(attn_mask, 0)  # Zero out masked positions after the softmax

        # Dropout on the attention weights.
        attn_weights = self.attn_dropout(attn_weights)

        # Weighted sum of the values.
        attn_output = torch.matmul(attn_weights, V)  # (batch_size, num_heads, query_len, head_dim)

        # Recombine heads: (batch_size, num_heads, query_len, head_dim) -> (batch_size, query_len, embed_dim).
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, query_len, self.embed_dim)

        # Final linear projection and dropout.
        output = self.out_proj(attn_output)
        output = self.drop_out(output)
        return output, attn_weights
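

# Shape-check sketch (illustrative only): run the layer on random tensors to confirm the
# expected output shapes. embed_dim=384 matches its use in InteractionModelATTN below;
# the other numbers are arbitrary.
#
#   attn = CrossAttention(embed_dim=384, num_heads=1).eval()
#   q = torch.randn(2, 10, 384)                      # target tokens as queries
#   kv = torch.randn(2, 7, 384)                      # drug tokens as keys/values
#   pad = torch.zeros(2, 7, dtype=torch.bool)        # True would mark padded drug tokens
#   out, weights = attn(q, kv, kv, key_padding_mask=pad)
#   assert out.shape == (2, 10, 384) and weights.shape == (2, 1, 10, 7)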


class InteractionModelATTN(nn.Module):
    def __init__(self, target_encoder, drug_encoder, scaler, attention_dropout, hidden_dropout, num_heads=1, kernel_size=1):
        super().__init__()
        self.replace_weights = None
        self.crossattention_weights = None
        self.presum_layer = None
        self.INTERPR_MODE = False

        self.scaler = scaler
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.target_encoder = target_encoder
        self.drug_encoder = drug_encoder
        self.kernel_size = kernel_size

        # Project both encoder outputs into a shared 384-dimensional space.
        self.lin_map_target = nn.Linear(512, 384)
        self.dropout_map_target = nn.Dropout(hidden_dropout)
        self.lin_map_drug = nn.Linear(384, 384)
        self.dropout_map_drug = nn.Dropout(hidden_dropout)

        self.crossattention = CrossAttention(384, num_heads, attention_dropout, hidden_dropout)
        self.norm = nn.LayerNorm(384)

        self.summary1 = nn.Linear(384, 384)
        self.summary2 = nn.Linear(384, 1)
        self.dropout_summary = nn.Dropout(hidden_dropout)
        self.layer_norm = nn.LayerNorm(384)
        self.gelu = nn.GELU()

        # Per-position weights and bias for the final summation over the target sequence.
        self.w = Parameter(torch.empty(512, 1))
        self.b = Parameter(torch.zeros(1))
        self.pdng = Parameter(torch.tensor(0.0))  # learnable padding value (0-dimensional)
        xavier_uniform_(self.w)

    def forward(self, x1, x2):
        """
        Forward pass for the attention interaction model.

        Args:
            x1 (dict): A dictionary containing input tensors for the target encoder.
                Expected keys:
                    - 'input_ids' (torch.Tensor): Token IDs for the target input.
                    - 'attention_mask' (torch.Tensor): Attention mask for the target input.
            x2 (dict): A dictionary containing input tensors for the drug encoder.
                Expected keys:
                    - 'input_ids' (torch.Tensor): Token IDs for the drug input.
                    - 'attention_mask' (torch.Tensor): Attention mask for the drug input.

        Returns:
            torch.Tensor: A tensor representing the predicted binding affinity.
        """
        x1["attention_mask"] = x1["attention_mask"].bool()  # Fix dropout model issue: https://github.com/pytorch/pytorch/issues/86120
        y1 = self.target_encoder(**x1).last_hidden_state  # The target
        query_mask = x1["attention_mask"].unsqueeze(-1).to(y1.dtype)
        y1 = y1 * query_mask

        x2["attention_mask"] = x2["attention_mask"].bool()  # Fix dropout model issue: https://github.com/pytorch/pytorch/issues/86120
        y2 = self.drug_encoder(**x2).last_hidden_state  # The drug
        key_mask = x2["attention_mask"].unsqueeze(-1).to(y2.dtype)
        y2 = y2 * key_mask

        # Map both sequences into the shared 384-dimensional space.
        y1 = self.lin_map_target(y1)
        y1 = self.gelu(y1)
        y1 = self.dropout_map_target(y1)

        y2 = self.lin_map_drug(y2)
        y2 = self.gelu(y2)
        y2 = self.dropout_map_drug(y2)

        key_padding_mask = (x2["attention_mask"] == 0)  # True at padded drug positions

        replace_weights = None
        # If in interpretation mode, allow the replacement of cross-attention weights.
        if self.INTERPR_MODE:
            if self.replace_weights is not None:
                replace_weights = self.replace_weights

        out, attn_weights = self.crossattention(y1, y2, y2, key_padding_mask=key_padding_mask, attn_mask=None, replace_weights=replace_weights)

        # If in interpretation mode, make cross-attention weights and scores accessible from the outside.
        if self.INTERPR_MODE:
            self.crossattention_weights = attn_weights
            self.scores = self.crossattention.scores

        out = self.summary1(out * query_mask)
        out = self.gelu(out)
        out = self.dropout_summary(out)
        out = self.summary2(out).squeeze(-1)

        # If in interpretation mode, make the final summation layer contributions accessible from the outside.
        if self.INTERPR_MODE:
            self.presum_layer = out

        weighted = out * self.w.squeeze(1)  # [batch, seq_len]
        padding_positions = ~x1["attention_mask"]  # True at padding
        # Assign the learnable pdng value to all padding positions.
        weighted = weighted.masked_fill(padding_positions, self.pdng.item())
        # Sum across the sequence and add the bias.
        result = weighted.sum(dim=1, keepdim=True) + self.b
        return result
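
    # Readout recap (informal): after summary2, out[i, j] is a scalar contribution for
    # target token j of sample i, and the prediction is
    #     result[i] = sum_j ( out[i, j] * w[j] ) + b,
    # with padded positions overwritten by the shared pdng value before the sum.
    # Since self.w has 512 rows, the target batch is expected to be padded to a fixed
    # length of 512 tokens.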

    def train(self, mode=True):
        super().train(mode)
        self.target_encoder.train(mode)
        self.drug_encoder.train(mode)
        self.crossattention.train(mode)
        return self

    def eval(self):
        super().eval()
        self.target_encoder.eval()
        self.drug_encoder.eval()
        self.crossattention.eval()
        return self

    def INTERPR_ENABLE_MODE(self):
        """
        Enables the interpretability mode for the model.
        """
        if self.training:
            raise RuntimeError("Cannot enable interpretability mode while the model is training.")
        self.INTERPR_MODE = True

    def INTERPR_DISABLE_MODE(self):
        """
        Disables the interpretability mode for the model.
        """
        if self.training:
            raise RuntimeError("Cannot disable interpretability mode while the model is training.")
        self.INTERPR_MODE = False

    def INTERPR_OVERRIDE_ATTN(self, new_weights):
        self.replace_weights = new_weights

    def INTERPR_RESET_OVERRIDE_ATTN(self):
        self.replace_weights = None

    def unscale(self, x):
        """
        Unscales predictions using the scaler. If no scaler is specified, returns the input unchanged.

        Parameters:
            x: the (scaled) values to be unscaled
        """
        with torch.no_grad():
            if self.scaler is None:
                return x
            return self.scaler.inverse_transform(x)
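

# End-to-end sketch (assumptions flagged): the tokenizer and sequence variable names are
# placeholders, and the target batch is assumed to be padded to 512 tokens; both inputs
# are dicts with 'input_ids' and 'attention_mask', as the forward pass expects.
#
#   x1 = target_tokenizer(target_sequences, padding="max_length", max_length=512,
#                         truncation=True, return_tensors="pt")
#   x2 = drug_tokenizer(drug_sequences, padding=True, truncation=True, return_tensors="pt")
#   model.eval()
#   with torch.no_grad():
#       prediction = model(dict(x1), dict(x2))        # scaled affinity, shape (batch, 1)
#       affinity = model.unscale(prediction)          # back to the original label scale
#
# Interpretability sketch: after enabling the mode, the cross-attention weights, raw
# scores, and per-token contributions are exposed on the underlying module.
#
#   model.INTERPR_ENABLE_MODE()
#   with torch.no_grad():
#       _ = model(dict(x1), dict(x2))
#   attn = model.model.crossattention_weights         # (batch, num_heads, 512, drug_len)
#   contributions = model.model.presum_layer          # (batch, 512)
#   model.INTERPR_DISABLE_MODE()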