import io
import logging
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from PIL import Image

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _to_numpy(value):
    """Return *value* as a NumPy array.

    Tensor-like objects (anything exposing ``detach`` / ``cpu`` / ``numpy``)
    are detached and moved to host memory first; NumPy arrays pass through
    unchanged and other array-likes go through ``np.asarray``.
    """
    if hasattr(value, "detach"):
        value = value.detach()
    if hasattr(value, "cpu"):
        value = value.cpu()
    if hasattr(value, "numpy"):
        value = value.numpy()
    return np.asarray(value)


def _figure_to_image(fig):
    """Render *fig* to an in-memory PNG and return it as a PIL image.

    ``Image.open`` is lazy; the returned image keeps a reference to the
    in-memory buffer, so the buffer is intentionally not closed here.
    """
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
    buf.seek(0)
    return Image.open(buf)


def _wrap_text(text, width=8):
    """Break *text* into newline-separated chunks of at most *width* characters."""
    return '\n'.join(text[i:i + width] for i in range(0, len(text), width))


def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_tokenized,
                                crossattention_weights, target_tokenizer, drug_tokenizer):
    """
    Plots the cross-attention weights for a given drug-target pair,
    only considering unmasked tokens.

    Masks, token ids and weights may each be NumPy arrays or tensor-like
    objects; they are converted to NumPy before indexing.

    Parameters:
        target_mask: Boolean mask selecting the target tokens to show.
        drug_mask: Boolean mask selecting the drug tokens to show.
        target_tokenized (dict): Tokenized target sequence; ``["input_ids"][0]`` is read.
        drug_tokenized (dict): Tokenized drug sequence; ``["input_ids"][0]`` is read.
        crossattention_weights: 2-D cross-attention weights indexed as
            ``[target position, drug position]``.
        target_tokenizer: Target tokenizer instance (``convert_ids_to_tokens`` is called).
        drug_tokenizer: Drug tokenizer instance (``convert_ids_to_tokens`` is called).

    Returns:
        PIL.Image: The generated attention heatmap image.

    Raises:
        ValueError: If the masks select no target tokens or no drug tokens.
    """
    logger.info("Starting plot_crossattention_weights")

    target_mask = _to_numpy(target_mask)
    drug_mask = _to_numpy(drug_mask)
    logger.info("Target mask shape: %s, Drug mask shape: %s",
                target_mask.shape, drug_mask.shape)

    # Get tokens for unmasked positions
    target_ids = _to_numpy(target_tokenized["input_ids"][0])[target_mask]
    target_token_str = target_tokenizer.convert_ids_to_tokens(target_ids.tolist())
    drug_ids = _to_numpy(drug_tokenized["input_ids"][0])[drug_mask]
    drug_token_str = drug_tokenizer.convert_ids_to_tokens(drug_ids.tolist())
    logger.info("Drug tokens: %s", drug_token_str)

    # Extract the (unmasked target) x (unmasked drug) sub-matrix
    weights = _to_numpy(crossattention_weights)
    subset_np = weights[target_mask][:, drug_mask]
    logger.info("Subset shape: %s", subset_np.shape)

    height, width = subset_np.shape
    if height == 0 or width == 0:
        # Fail before any figure is created; .max() below would otherwise
        # raise an opaque zero-size reduction error.
        raise ValueError("No unmasked tokens to plot: masks select an empty attention matrix")

    fig, ax = plt.subplots(figsize=(width * 0.2 + 2, height * 0.2 + 3), dpi=300)
    try:
        im = ax.imshow(subset_np, cmap='hot', interpolation='nearest')
        fig.colorbar(im, ax=ax, orientation='vertical', fraction=0.05, shrink=0.8)
        ax.set_title("Cross-Attention Weights")
        ax.set_xlabel("Drug Tokens")
        ax.set_ylabel("Target Tokens")

        # Stack the characters of each drug token so the x labels read vertically
        vertical_labels = ['\n'.join(label) for label in drug_token_str]
        ax.set_xticks(np.arange(width))
        ax.set_xticklabels(vertical_labels)
        ax.set_yticks(np.arange(height))
        ax.set_yticklabels(target_token_str)

        # Annotate only the cells above half of the maximum weight
        max_val = subset_np.max()
        logger.info("Max crossattention weight: %s", max_val)
        threshold = max_val / 2
        for i, j in zip(*np.nonzero(subset_np > threshold)):
            # Show just the two digits after the decimal point (no leading '0.').
            # NOTE: a value that rounds to 1.00 is therefore rendered as "00".
            text = f"{subset_np[i, j] % 1:.2f}"[2:]
            ax.text(j, i, text, ha='center', va='center', color="black", fontsize=6)

        img = _figure_to_image(fig)
    finally:
        # Close this specific figure even if plotting failed, so figures do not leak.
        plt.close(fig)

    logger.info("Finished plot_crossattention_weights successfully")
    return img


def plot_presum(tokenized_input, affinities, scaler, w, b, target_tokenizer, raw_affinities=False):
    """
    Generates an annotated 1D heatmap of token-level contribution scores.

    The per-token values are computed as ``w * a_i + b / n`` and then
    ``x_i * scaler.std_ + scaler.mean_ / n``, i.e. the bias and the scaler
    mean are spread evenly over the ``n`` positions. Contributions that fall
    on padding positions are then redistributed evenly over the non-padding
    tokens, and each cell is labelled with the token at the same position.

    Args:
        tokenized_input (dict): Output of a tokenizer with keys:
            - 'input_ids': token ID sequences; ``[0]`` is read
            - 'attention_mask': mask whose zero entries mark padding; ``[0]`` is read
        affinities: Final layer summation affinity contributions from the
            model (tensor or array); ``affinities[0]`` is the per-token vector.
        scaler (object): Fitted scaler with `mean_` and `std_` attributes for inverse-transform.
        w (float): Weight applied to the summed affinities before bias.
        b (float): Bias added to the summed affinities.
        target_tokenizer: Target tokenizer instance.
        raw_affinities (bool): If True, plot raw (signed) contributions on a blue-white-red scale.
                               If False, enforce non-negative contributions and use a white-red scale.
                               Default: False

    Returns:
        PIL.Image: The generated contribution visualization image.

    Raises:
        ValueError: If `sum(transformed_affinities) < 0` when `raw_affinities=False`,
            or if the attention mask leaves no non-padding tokens.
    """
    colors = [
        (1.0, 0.95, 0.95),
        (1.0, 0.5, 0.5),
        (0.8, 0.0, 0.0),
    ]
    custom_reds = LinearSegmentedColormap.from_list("CustomReds", colors)

    # Convert tensors to numpy if needed (detaching so grad-tracking tensors work).
    affinities = _to_numpy(affinities)
    if hasattr(w, 'cpu'):
        w = _to_numpy(w)
    if hasattr(b, 'cpu'):
        b = _to_numpy(b)

    # Apply transformations
    affinities = w * (affinities[0]) + b / len(affinities[0])
    affinities = (affinities * scaler.std_) + scaler.mean_ / len(affinities)

    if affinities.sum() < 0 and not raw_affinities:
        raise ValueError("Cannot use non-raw affinities with negative binding affinity prediction")

    # Handle padding
    pad_mask = _to_numpy(tokenized_input["attention_mask"][0]) == 0
    keep_mask = ~pad_mask
    non_padding_affinities = affinities[keep_mask]
    if len(non_padding_affinities) == 0:
        raise ValueError("No non-padding tokens to plot")

    # Get token strings for exactly the positions that are plotted, so labels
    # stay aligned with values regardless of where the padding sits.
    tokens_input = _to_numpy(tokenized_input["input_ids"][0])
    token_str = target_tokenizer.convert_ids_to_tokens(tokens_input[keep_mask].tolist())

    padding_affinities_sum = affinities[pad_mask].sum()
    processed_affinities = (
        non_padding_affinities + padding_affinities_sum / len(non_padding_affinities)
    )

    # Make affinities non-negative if requested: clamp negatives to zero and
    # subtract their total evenly from the positive entries, repeating while
    # that subtraction pushes further entries below zero.
    if not raw_affinities:
        negatives = processed_affinities[processed_affinities < 0]
        while len(negatives) > 0:
            positives = processed_affinities[processed_affinities > 0]
            processed_affinities[processed_affinities < 0] = 0
            if len(positives) == 0:
                # Nothing left to absorb the deficit; avoids a divide-by-zero warning.
                break
            processed_affinities[processed_affinities > 0] = (
                positives + negatives.sum() / len(positives)
            )
            negatives = processed_affinities[processed_affinities < 0]

    # Lay the values out row-major in a grid; unused trailing cells stay NaN.
    max_per_row = 20
    n = len(processed_affinities)
    n_rows = int(np.ceil(n / max_per_row))
    grid = np.full((n_rows, max_per_row), np.nan)
    grid.flat[:n] = processed_affinities

    if raw_affinities:
        # Symmetric limits keep zero at the white midpoint of the diverging map.
        limit = max(abs(np.nanmin(grid)), abs(np.nanmax(grid)))
        cmap, vmin, vmax = 'bwr', -limit, limit
    else:
        cmap, vmin, vmax = custom_reds, np.nanmin(grid), np.nanmax(grid)

    fig, ax = plt.subplots(
        figsize=(max_per_row * 1, n_rows * 1 + 2),
        dpi=300,
    )
    try:
        ax.set_xticks([])
        ax.set_yticks([])

        im = ax.imshow(grid, aspect='equal', cmap=cmap, vmin=vmin, vmax=vmax)

        for idx, val in enumerate(processed_affinities):
            r, c = divmod(idx, max_per_row)
            wrapped_token = _wrap_text(token_str[idx], width=8)
            ax.text(c, r, f"{val:.2f}\n{wrapped_token}",
                    ha='center', va='center', fontsize=8)

        divider = make_axes_locatable(ax)
        cax = divider.append_axes('bottom', size=0.2, pad=0.3)
        cbar = fig.colorbar(im, cax=cax, orientation='horizontal')
        cbar.set_label("Contribution")

        img = _figure_to_image(fig)
    finally:
        # Close this specific figure even if plotting failed, so figures do not leak.
        plt.close(fig)

    return img