import numpy as np
import matplotlib.pyplot as plt
import os
import logging
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from PIL import Image
import io
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_tokenized,
crossattention_weights, target_tokenizer, drug_tokenizer):
"""
Plots the cross-attention weights for a given drug-target pair, only considering unmasked tokens.
Parameters:
target_mask (np.ndarray): Boolean mask for target tokens.
drug_mask (np.ndarray): Boolean mask for drug tokens.
target_tokenized (dict): Tokenized target sequence.
drug_tokenized (dict): Tokenized drug sequence.
crossattention_weights (np.ndarray): The cross-attention weights.
target_tokenizer: Target tokenizer instance.
drug_tokenizer: Drug tokenizer instance.
Returns:
PIL.Image: The generated attention heatmap image.
"""
logger.info("Starting plot_crossattention_weights")
    # Move masks to CPU if they are torch tensors; the boolean indexing
    # below works on CPU tensors directly
    if hasattr(target_mask, 'cpu'):
        target_mask = target_mask.cpu()
    if hasattr(drug_mask, 'cpu'):
        drug_mask = drug_mask.cpu()
logger.info(f"Target mask shape: {target_mask.shape}, Drug mask shape: {drug_mask.shape}")
    # Decode token strings for the unmasked positions
    target_ids = target_tokenized["input_ids"][0][target_mask]
    target_token_str = target_tokenizer.convert_ids_to_tokens(target_ids)
    drug_ids = drug_tokenized["input_ids"][0][drug_mask]
    drug_token_str = drug_tokenizer.convert_ids_to_tokens(drug_ids)
logger.info(f"Drug tokens: {drug_token_str}")
# Extract subset of attention weights
if hasattr(crossattention_weights, 'cpu'):
crossattention_weights = crossattention_weights.cpu()
subset = crossattention_weights[target_mask][:, drug_mask]
    # Convert to numpy for matplotlib (handles both torch tensors and arrays)
    subset_np = subset.detach().numpy() if hasattr(subset, 'detach') else np.asarray(subset)
logger.info(f"Subset shape: {subset_np.shape}")
height, width = subset_np.shape
fig, ax = plt.subplots(
figsize=(width * 0.2 + 2, height * 0.2 + 3),
dpi=300
)
im = ax.imshow(subset_np, cmap='hot', interpolation='nearest')
plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.05, shrink=0.8)
plt.title("Cross-Attention Weights")
plt.xlabel("Drug Tokens")
plt.ylabel("Target Tokens")
    # Stack each drug token's characters vertically so the x labels stay narrow
    vertical_labels = ['\n'.join(label) for label in drug_token_str]
plt.xticks(ticks=np.arange(width), labels=vertical_labels)
plt.yticks(ticks=np.arange(height), labels=target_token_str)
# Add text annotations
max_val = subset_np.max()
logger.info(f"Max crossattention weight: {max_val}")
for i in range(height):
for j in range(width):
val = subset_np[i, j]
if val > max_val / 2:
# Extract just the digits after the decimal (no leading '0.')
text = f"{val % 1:.2f}"[2:]
plt.text(j, i, text,
ha='center', va='center',
color="black",
fontsize=6)
    # Render the figure into an in-memory PNG and return it as a PIL image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
logger.info("Finished plot_crossattention_weights successfully")
return img
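

# A minimal, self-contained smoke test for plot_crossattention_weights (a
# sketch, not part of the DLRNA-BERTa pipeline). The stub tokenizer and the
# random attention weights below are hypothetical placeholders standing in
# for the real tokenizers and model outputs; they only exercise the plotting
# path end to end.
def _smoke_test_plot_crossattention_weights():
    import torch

    class _StubTokenizer:
        def convert_ids_to_tokens(self, ids):
            return [f"tok{int(i)}" for i in ids]

    t_len, d_len = 6, 4
    target_tokenized = {"input_ids": torch.arange(t_len).unsqueeze(0)}
    drug_tokenized = {"input_ids": torch.arange(d_len).unsqueeze(0)}
    target_mask = torch.ones(t_len, dtype=torch.bool)
    drug_mask = torch.ones(d_len, dtype=torch.bool)
    weights = torch.rand(t_len, d_len)  # stand-in for model cross-attention
    img = plot_crossattention_weights(
        target_mask, drug_mask, target_tokenized, drug_tokenized,
        weights, _StubTokenizer(), _StubTokenizer())
    img.save("crossattention_demo.png")
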
def plot_presum(tokenized_input, affinities, scaler, w, b, target_tokenizer,
raw_affinities=False):
"""
Generates an annotated 1D heatmap of token-level contribution scores.
Args:
tokenized_input (dict): Output of a tokenizer with keys:
- 'input_ids' (torch.Tensor): token ID sequences, shape (1, seq_len)
- 'attention_mask' (torch.Tensor): mask indicating padding tokens
        affinities (torch.Tensor): Per-token affinity contributions from the model's final summation layer, shape (1, seq_len)
scaler (object): Fitted scaler with `mean_` and `std_` attributes for inverse-transform.
w (float): Weight applied to the summed affinities before bias.
b (float): Bias added to the summed affinities.
target_tokenizer: Target tokenizer instance.
        raw_affinities (bool): If True, plot raw (signed) contributions on a blue-white-red scale.
            If False, enforce non-negative contributions and use a white-red scale.
            Default: False.
Returns:
PIL.Image: The generated contribution visualization image.
Raises:
        ValueError: If the transformed affinities sum to a negative value while `raw_affinities=False`.
"""
colors = [
(1.0, 0.95, 0.95),
(1.0, 0.5, 0.5),
(0.8, 0.0, 0.0)
]
custom_reds = LinearSegmentedColormap.from_list("CustomReds", colors)
# Convert tensors to numpy if needed
if hasattr(affinities, 'cpu'):
affinities = affinities.cpu().numpy()
if hasattr(w, 'cpu'):
w = w.cpu().numpy()
if hasattr(b, 'cpu'):
b = b.cpu().numpy()
    # Apply the model's final affine layer per token, spreading the bias b
    # evenly over the sequence so the token contributions still sum to the
    # overall prediction: sum_i(w * a_i + b/n) = w * sum_i(a_i) + b
    affinities = w * (affinities[0]) + b / len(affinities[0])
    # Undo the target scaling the same way, distributing scaler.mean_ per token
    affinities = (affinities * scaler.std_) + scaler.mean_ / len(affinities)
    if affinities.sum() < 0 and not raw_affinities:
raise ValueError("Cannot use non-raw affinities with negative binding affinity prediction")
# Get token strings
tokens_input = tokenized_input["input_ids"][0]
if hasattr(tokens_input, 'cpu'):
tokens_input = tokens_input.cpu().numpy()
token_str = target_tokenizer.convert_ids_to_tokens(tokens_input)
# Handle padding
pad_mask = tokenized_input["attention_mask"][0] == 0
if hasattr(pad_mask, 'cpu'):
pad_mask = pad_mask.cpu().numpy()
    # Fold contributions assigned to padding tokens back into the real tokens,
    # spread evenly, so the total prediction is preserved
    padding_affinities_sum = affinities[pad_mask].sum()
    non_padding_affinities = affinities[~pad_mask]
    processed_affinities = non_padding_affinities + padding_affinities_sum / len(non_padding_affinities)
    # Make affinities non-negative if requested: zero out negative entries and
    # redistribute their (negative) sum evenly over the positive entries. This
    # preserves the total; the loop repeats because the redistribution can push
    # small positive entries below zero.
    if not raw_affinities:
        negatives = processed_affinities[processed_affinities < 0]
        while len(negatives) > 0:
            positives = processed_affinities[processed_affinities > 0]
            processed_affinities[processed_affinities < 0] = 0
            processed_affinities[processed_affinities > 0] = positives + negatives.sum() / len(positives)
            negatives = processed_affinities[processed_affinities < 0]
    # Lay the tokens out on a grid with at most max_per_row per row; NaN cells
    # (the unused tail of the last row) are masked out by imshow
    max_per_row = 20
    n = len(processed_affinities)
    n_rows = int(np.ceil(n / max_per_row))
    grid = np.full((n_rows, max_per_row), np.nan)
    grid.flat[:n] = processed_affinities
    fig, ax = plt.subplots(
        figsize=(max_per_row, n_rows + 2),
        dpi=300
    )
ax.set_xticks([])
ax.set_yticks([])
    # The diverging 'bwr' scale (raw mode) gets limits symmetric about zero so
    # white sits at zero; the white-to-red scale just spans the data range
    sym = max(abs(np.nanmin(grid)), abs(np.nanmax(grid)))
    im = ax.imshow(
        grid,
        aspect='equal',
        cmap='bwr' if raw_affinities else custom_reds,
        vmin=-sym if raw_affinities else np.nanmin(grid),
        vmax=sym if raw_affinities else np.nanmax(grid),
    )
def wrap_text(text, width=8):
return '\n'.join(text[i:i+width] for i in range(0, len(text), width))
    # Annotate each cell with its score and wrapped token text; indexing
    # token_str by grid position assumes padding tokens sit at the end of
    # the sequence
    for idx, val in enumerate(processed_affinities):
        r, c = divmod(idx, max_per_row)
        wrapped_token = wrap_text(token_str[idx], width=8)
        ax.text(c, r, f"{val:.2f}\n{wrapped_token}",
                ha='center', va='center', fontsize=8)
divider = make_axes_locatable(ax)
cax = divider.append_axes('bottom', size=0.2, pad=0.3)
cbar = fig.colorbar(im, cax=cax, orientation='horizontal')
cbar.set_label("Contribution")
    # Render the figure into an in-memory PNG and return it as a PIL image
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)
return img
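

# A minimal, self-contained smoke test for plot_presum (a sketch, not part of
# the DLRNA-BERTa pipeline). The stub scaler/tokenizer, the random affinities,
# and the w/b values are hypothetical placeholders for the fitted scaler, the
# target tokenizer, and the model's final linear-layer parameters.
def _smoke_test_plot_presum():
    import torch

    class _StubScaler:
        mean_ = 5.0
        std_ = 1.0

    class _StubTokenizer:
        def convert_ids_to_tokens(self, ids):
            return [f"tok{int(i)}" for i in ids]

    seq_len = 10
    tokenized_input = {
        "input_ids": torch.arange(seq_len).unsqueeze(0),
        "attention_mask": torch.tensor([[1] * 8 + [0] * 2]),  # two trailing pads
    }
    affinities = torch.rand(1, seq_len)  # stand-in for per-token contributions
    img = plot_presum(tokenized_input, affinities, _StubScaler(),
                      w=1.0, b=0.5, target_tokenizer=_StubTokenizer())
    img.save("presum_demo.png")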