Spaces:
Sleeping
Sleeping
# Standard library
import io
import logging
import os

# Third-party
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from mpl_toolkits.axes_grid1 import make_axes_locatable
from PIL import Image

# Module-level logger; basicConfig is a no-op if the root logger was already configured.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_tokenized,
                                crossattention_weights, target_tokenizer, drug_tokenizer):
    """
    Plot the cross-attention weights for a drug-target pair, restricted to unmasked tokens.

    Parameters:
        target_mask (np.ndarray or torch.Tensor): Boolean mask selecting target tokens.
        drug_mask (np.ndarray or torch.Tensor): Boolean mask selecting drug tokens.
        target_tokenized (dict): Tokenized target sequence with an 'input_ids' entry.
        drug_tokenized (dict): Tokenized drug sequence with an 'input_ids' entry.
        crossattention_weights (torch.Tensor or np.ndarray): Attention weight matrix;
            rows index target tokens, columns index drug tokens.
        target_tokenizer: Target tokenizer (must provide ``convert_ids_to_tokens``).
        drug_tokenizer: Drug tokenizer (must provide ``convert_ids_to_tokens``).

    Returns:
        PIL.Image: The generated attention heatmap image.
    """
    logger.info("Starting plot_crossattention_weights")

    # Bring device tensors to host memory so boolean indexing works like numpy.
    if hasattr(target_mask, 'cpu'):
        target_mask = target_mask.cpu()
    if hasattr(drug_mask, 'cpu'):
        drug_mask = drug_mask.cpu()
    # Lazy %-args: the message is only built if INFO logging is enabled.
    logger.info("Target mask shape: %s, Drug mask shape: %s",
                target_mask.shape, drug_mask.shape)

    # Token strings for the unmasked positions only.
    target_token_str = target_tokenizer.convert_ids_to_tokens(
        target_tokenized["input_ids"][0][target_mask])
    drug_token_str = drug_tokenizer.convert_ids_to_tokens(
        drug_tokenized["input_ids"][0][drug_mask])
    logger.info("Drug tokens: %s", drug_token_str)

    # Restrict the attention matrix to unmasked rows (target) and columns (drug).
    if hasattr(crossattention_weights, 'cpu'):
        crossattention_weights = crossattention_weights.cpu()
    subset = crossattention_weights[target_mask][:, drug_mask]
    # Bug fix: only torch tensors have .detach(); a plain numpy input crashed here before.
    if hasattr(subset, 'detach'):
        subset_np = subset.detach().numpy()
    else:
        subset_np = np.asarray(subset)
    logger.info("Subset shape: %s", subset_np.shape)

    height, width = subset_np.shape
    fig, ax = plt.subplots(
        figsize=(width * 0.2 + 2, height * 0.2 + 3),
        dpi=300
    )
    im = ax.imshow(subset_np, cmap='hot', interpolation='nearest')
    fig.colorbar(im, ax=ax, orientation='vertical', fraction=0.05, shrink=0.8)
    ax.set_title("Cross-Attention Weights")
    ax.set_xlabel("Drug Tokens")
    ax.set_ylabel("Target Tokens")

    # Stack each drug token's characters vertically so x labels stay narrow.
    vertical_labels = ['\n'.join(label) for label in drug_token_str]
    ax.set_xticks(np.arange(width), labels=vertical_labels)
    ax.set_yticks(np.arange(height), labels=target_token_str)

    # Annotate only the strongest cells (above half of the maximum weight).
    max_val = subset_np.max()
    logger.info("Max crossattention weight: %s", max_val)
    for i in range(height):
        for j in range(width):
            val = subset_np[i, j]
            if val > max_val / 2:
                # Show just the two digits after the decimal point (drop the '0.').
                ax.text(j, i, f"{val % 1:.2f}"[2:],
                        ha='center', va='center',
                        color="black",
                        fontsize=6)

    # Render to an in-memory PNG and hand back a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
    buf.seek(0)
    img = Image.open(buf)
    # Bug fix: close *this* figure explicitly; bare plt.close() closes whatever
    # figure happens to be current and leaks this one if the caller has others open.
    plt.close(fig)
    logger.info("Finished plot_crossattention_weights successfully")
    return img
def plot_presum(tokenized_input, affinities, scaler, w, b, target_tokenizer,
                raw_affinities=False):
    """
    Generate an annotated 1D heatmap of token-level contribution scores.

    Args:
        tokenized_input (dict): Tokenizer output with keys:
            - 'input_ids' (torch.Tensor): token ID sequences, shape (1, seq_len)
            - 'attention_mask' (torch.Tensor): mask indicating padding tokens
        affinities (torch.Tensor): Final-layer summation affinity contributions
            from the model, shape (1, seq_len).
        scaler (object): Fitted scaler with `mean_` and `std_` attributes for
            inverse-transform.
        w (float): Weight applied to the summed affinities before bias.
        b (float): Bias added to the summed affinities.
        target_tokenizer: Target tokenizer instance.
        raw_affinities (bool): If True, plot raw (signed) contributions on a
            blue-white-red scale. If False, enforce non-negative contributions
            and use a white-red scale. Default: False.

    Returns:
        PIL.Image: The generated contribution visualization image.

    Raises:
        ValueError: If the summed transformed affinities are negative when
            `raw_affinities=False`.
    """
    colors = [
        (1.0, 0.95, 0.95),  # near-white for the smallest contributions
        (1.0, 0.5, 0.5),
        (0.8, 0.0, 0.0),    # dark red for the largest
    ]
    custom_reds = LinearSegmentedColormap.from_list("CustomReds", colors)

    # Bring device tensors to host numpy.
    if hasattr(affinities, 'cpu'):
        affinities = affinities.cpu().numpy()
    if hasattr(w, 'cpu'):
        w = w.cpu().numpy()
    if hasattr(b, 'cpu'):
        b = b.cpu().numpy()

    # Per-token affine transform. The bias (and, below, the scaler mean) is spread
    # evenly over the tokens on purpose, so the SUM of per-token contributions
    # equals the model's inverse-transformed scalar prediction.
    affinities = w * (affinities[0]) + b / len(affinities[0])
    affinities = (affinities * scaler.std_) + scaler.mean_ / len(affinities)

    # numpy .sum() instead of builtin sum(): one C-level pass, no boxed floats.
    if affinities.sum() < 0 and not raw_affinities:
        raise ValueError("Cannot use non-raw affinities with negative binding affinity prediction")

    # Token strings for the full (padded) sequence.
    tokens_input = tokenized_input["input_ids"][0]
    if hasattr(tokens_input, 'cpu'):
        tokens_input = tokens_input.cpu().numpy()
    token_str = target_tokenizer.convert_ids_to_tokens(tokens_input)

    # Fold the contribution assigned to padding positions back into the real
    # tokens so the total is preserved.
    pad_mask = tokenized_input["attention_mask"][0] == 0
    if hasattr(pad_mask, 'cpu'):
        pad_mask = pad_mask.cpu().numpy()
    padding_affinities_sum = affinities[pad_mask].sum()
    non_padding_affinities = affinities[~pad_mask]
    processed_affinities = non_padding_affinities + padding_affinities_sum / len(non_padding_affinities)

    if not raw_affinities:
        # Iteratively zero out negative entries, redistributing their total onto
        # the positive entries; the overall sum is preserved at every step.
        # Terminates because the guard above guarantees a non-negative total.
        negatives = processed_affinities[processed_affinities < 0]
        while len(negatives) > 0:
            positives = processed_affinities[processed_affinities > 0]
            processed_affinities[processed_affinities < 0] = 0
            # The '> 0' mask still addresses the original positive positions
            # because the just-zeroed entries are excluded by the strict '>'.
            processed_affinities[processed_affinities > 0] = positives + negatives.sum() / len(positives)
            negatives = processed_affinities[processed_affinities < 0]

    # Lay the sequence out as a grid of at most 20 tokens per row; NaN cells
    # (the unused tail of the last row) render as blank.
    max_per_row = 20
    n = len(processed_affinities)
    n_rows = int(np.ceil(n / max_per_row))
    grid = np.full((n_rows, max_per_row), np.nan)
    grid.flat[:n] = processed_affinities

    fig, ax = plt.subplots(
        figsize=(max_per_row * 1, n_rows * 1 + 2),
        dpi=300
    )
    ax.set_xticks([])
    ax.set_yticks([])
    if raw_affinities:
        # Symmetric limits so that zero maps to white on the blue-white-red scale.
        bound = max(abs(np.nanmin(grid)), abs(np.nanmax(grid)))
        vmin, vmax = -bound, bound
        cmap = 'bwr'
    else:
        vmin, vmax = np.nanmin(grid), np.nanmax(grid)
        cmap = custom_reds
    im = ax.imshow(grid, aspect='equal', cmap=cmap, vmin=vmin, vmax=vmax)

    def wrap_text(text, width=8):
        # Break long token strings into fixed-width lines so they fit in a cell.
        return '\n'.join(text[i:i + width] for i in range(0, len(text), width))

    # Annotate every cell with its value and (wrapped) token string.
    for idx, val in enumerate(processed_affinities):
        r, c = divmod(idx, max_per_row)
        wrapped_token = wrap_text(token_str[idx], width=8)
        ax.text(c, r, f"{val:.2f}\n{wrapped_token}",
                ha='center', va='center', fontsize=8)

    # Horizontal colorbar docked under the heatmap.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('bottom', size=0.2, pad=0.3)
    cbar = fig.colorbar(im, cax=cax, orientation='horizontal')
    cbar.set_label("Contribution")

    # Render to an in-memory PNG and hand back a PIL image.
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=300)
    buf.seek(0)
    img = Image.open(buf)
    # Bug fix: close *this* figure, not whatever figure is current, to avoid
    # leaking figures across repeated calls.
    plt.close(fig)
    return img