IlPakoZ committed · verified
Commit 78d786d · Parent(s): 05db40e

Upload 2 files

Files changed (2):
  1. analysis.py +211 -0
  2. app.py +159 -39
analysis.py ADDED
@@ -0,0 +1,211 @@
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import os
+ from matplotlib.colors import LinearSegmentedColormap
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
+ from PIL import Image
+ import io
+
+ def plot_crossattention_weights(target_mask, drug_mask, target_tokenized, drug_tokenized,
+                                 crossattention_weights, target_tokenizer, drug_tokenizer):
+     """
+     Plots the cross-attention weights for a given drug-target pair, considering only unmasked tokens.
+
+     Parameters:
+         target_mask (np.ndarray): Boolean mask for target tokens.
+         drug_mask (np.ndarray): Boolean mask for drug tokens.
+         target_tokenized (dict): Tokenized target sequence.
+         drug_tokenized (dict): Tokenized drug sequence.
+         crossattention_weights (np.ndarray): The cross-attention weights.
+         target_tokenizer: Target tokenizer instance.
+         drug_tokenizer: Drug tokenizer instance.
+
+     Returns:
+         PIL.Image: The generated attention heatmap image.
+     """
+     # Convert masks to numpy arrays if they're tensors
+     if hasattr(target_mask, 'cpu'):
+         target_mask = target_mask.cpu().numpy()
+     if hasattr(drug_mask, 'cpu'):
+         drug_mask = drug_mask.cpu().numpy()
+
+     # Convert to boolean masks
+     target_mask = target_mask.astype(bool)
+     drug_mask = drug_mask.astype(bool)
+
+     # Get tokens for unmasked positions
+     tokens_input = target_tokenized["input_ids"][0][target_mask]
+     if hasattr(tokens_input, 'cpu'):
+         tokens_input = tokens_input.cpu().numpy()
+     target_token_str = target_tokenizer.convert_ids_to_tokens(tokens_input)
+
+     tokens_input = drug_tokenized["input_ids"][0][drug_mask]
+     if hasattr(tokens_input, 'cpu'):
+         tokens_input = tokens_input.cpu().numpy()
+     drug_token_str = drug_tokenizer.convert_ids_to_tokens(tokens_input)
+
+     # Extract the subset of attention weights for unmasked positions
+     if hasattr(crossattention_weights, 'cpu'):
+         crossattention_weights = crossattention_weights.cpu().numpy()
+
+     subset = crossattention_weights[target_mask][:, drug_mask]
+     height, width = subset.shape
+
+     fig, ax = plt.subplots(
+         figsize=(width * 0.2 + 2, height * 0.2 + 3),
+         dpi=300
+     )
+     im = ax.imshow(subset, cmap='hot', interpolation='nearest')
+
+     plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.05, shrink=0.8)
+
+     plt.title("Cross-Attention Weights")
+     plt.xlabel("Drug Tokens")
+     plt.ylabel("Target Tokens")
+
+     # Create vertical labels for drug tokens
+     vertical_labels = ['\n'.join(label) for label in drug_token_str]
+     plt.xticks(ticks=np.arange(width), labels=vertical_labels)
+     plt.yticks(ticks=np.arange(height), labels=target_token_str)
+
+     # Annotate only cells whose weight exceeds half the maximum
+     max_val = subset.max()
+     for i in range(height):
+         for j in range(width):
+             val = subset[i, j]
+             if val > max_val / 2:
+                 # Extract just the digits after the decimal (no leading '0.')
+                 text = f"{val % 1:.2f}"[2:]
+                 plt.text(j, i, text,
+                          ha='center', va='center',
+                          color="black",
+                          fontsize=6)
+
+     # Convert to PIL Image
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
+     buf.seek(0)
+     img = Image.open(buf)
+     plt.close()
+
+     return img
+
+
+ def plot_presum(tokenized_input, affinities, scaler, w, b, target_tokenizer,
+                 raw_affinities=False):
+     """
+     Generates an annotated 1D heatmap of token-level contribution scores.
+
+     Args:
+         tokenized_input (dict): Output of a tokenizer with keys:
+             - 'input_ids' (torch.Tensor): token ID sequences, shape (1, seq_len)
+             - 'attention_mask' (torch.Tensor): mask indicating padding tokens
+         affinities (torch.Tensor): Final-layer summation affinity contributions from the model, shape (1, seq_len)
+         scaler (object): Fitted scaler with `mean_` and `std_` attributes for the inverse transform.
+         w (float): Weight applied to the summed affinities before bias.
+         b (float): Bias added to the summed affinities.
+         target_tokenizer: Target tokenizer instance.
+         raw_affinities (bool): If True, plot raw (signed) contributions on a blue-white-red scale.
+             If False, enforce non-negative contributions and use a white-red scale.
+             Default: False
+
+     Returns:
+         PIL.Image: The generated contribution visualization image.
+
+     Raises:
+         ValueError: If `sum(transformed_affinities) < 0` when `raw_affinities=False`.
+     """
+     colors = [
+         (1.0, 0.95, 0.95),
+         (1.0, 0.5, 0.5),
+         (0.8, 0.0, 0.0)
+     ]
+
+     custom_reds = LinearSegmentedColormap.from_list("CustomReds", colors)
+
+     # Convert tensors to numpy if needed
+     if hasattr(affinities, 'cpu'):
+         affinities = affinities.cpu().numpy()
+     if hasattr(w, 'cpu'):
+         w = w.cpu().numpy()
+     if hasattr(b, 'cpu'):
+         b = b.cpu().numpy()
+
+     # Apply the affine transform and undo the scaler, spreading the
+     # bias and mean evenly across tokens
+     affinities = w * (affinities[0]) + b / len(affinities[0])
+     affinities = (affinities * scaler.std_) + scaler.mean_ / len(affinities)
+
+     if sum(affinities) < 0 and not raw_affinities:
+         raise ValueError("Cannot use non-raw affinities with negative binding affinity prediction")
+
+     # Get token strings
+     tokens_input = tokenized_input["input_ids"][0]
+     if hasattr(tokens_input, 'cpu'):
+         tokens_input = tokens_input.cpu().numpy()
+     token_str = target_tokenizer.convert_ids_to_tokens(tokens_input)
+
+     # Handle padding: fold padding contributions evenly into real tokens
+     pad_mask = tokenized_input["attention_mask"][0] == 0
+     if hasattr(pad_mask, 'cpu'):
+         pad_mask = pad_mask.cpu().numpy()
+
+     padding_affinities_sum = affinities[pad_mask].sum()
+     non_padding_affinities = affinities[~pad_mask]
+     processed_affinities = non_padding_affinities + padding_affinities_sum / len(non_padding_affinities)
+
+     # Make affinities non-negative if requested: zero out negatives and
+     # redistribute their total over the positives until none remain
+     if not raw_affinities:
+         all_negative_non_paddings = processed_affinities[processed_affinities < 0]
+
+         while len(all_negative_non_paddings) > 0:
+             all_positive_non_paddings = processed_affinities[processed_affinities > 0]
+
+             processed_affinities[processed_affinities < 0] = 0
+             processed_affinities[processed_affinities > 0] = all_positive_non_paddings + all_negative_non_paddings.sum() / len(all_positive_non_paddings)
+             all_negative_non_paddings = processed_affinities[processed_affinities < 0]
+
+     # Create visualization: wrap contributions into rows of 20 cells
+     max_per_row = 20
+     n = len(processed_affinities)
+     n_rows = int(np.ceil(n / max_per_row))
+     grid = np.full((n_rows, max_per_row), np.nan)
+     grid.flat[:n] = processed_affinities
+
+     fig, ax = plt.subplots(
+         figsize=(max_per_row * 1, n_rows * 1 + 2),
+         dpi=300
+     )
+
+     ax.set_xticks([])
+     ax.set_yticks([])
+
+     im = ax.imshow(
+         grid,
+         aspect='equal',
+         cmap='bwr' if raw_affinities else custom_reds,
+         vmin=np.nanmin(grid) if not raw_affinities else -max(abs(np.nanmin(grid)), abs(np.nanmax(grid))),
+         vmax=np.nanmax(grid) if not raw_affinities else max(abs(np.nanmin(grid)), abs(np.nanmax(grid))),
+     )
+
+     def wrap_text(text, width=8):
+         return '\n'.join(text[i:i+width] for i in range(0, len(text), width))
+
+     for idx, val in enumerate(processed_affinities):
+         r, c = divmod(idx, max_per_row)
+         wrapped_token = wrap_text(token_str[idx], width=8)
+         ax.text(c, r, f"{val:.2f}\n{wrapped_token}",
+                 ha='center', va='center', fontsize=8)
+
+     divider = make_axes_locatable(ax)
+     cax = divider.append_axes('bottom', size=0.2, pad=0.3)
+     cbar = fig.colorbar(im, cax=cax, orientation='horizontal')
+     cbar.set_label("Contribution")
+
+     # Convert to PIL Image
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', bbox_inches='tight', dpi=300)
+     buf.seek(0)
+     img = Image.open(buf)
+     plt.close()
+
+     return img
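For reviewers who want to smoke-test the new module in isolation, a minimal sketch follows. It is not part of the commit: the toy tokenizer and toy scaler are hypothetical stand-ins exposing only the attributes analysis.py actually touches (convert_ids_to_tokens, mean_, std_), and the attention weights and contributions are random placeholders.

# Hypothetical smoke test for analysis.py (illustrative only, not in this commit).
import numpy as np
from analysis import plot_crossattention_weights, plot_presum

class ToyTokenizer:
    # Stand-in for a HuggingFace tokenizer; only the method analysis.py calls.
    def convert_ids_to_tokens(self, ids):
        return [f"tok{int(i)}" for i in ids]

class ToyScaler:
    # Stand-in for the fitted scaler; plot_presum reads mean_ and std_.
    mean_, std_ = 5.0, 1.0

rng = np.random.default_rng(0)
target_inputs = {
    "input_ids": np.arange(6).reshape(1, 6),
    "attention_mask": np.array([[1, 1, 1, 1, 0, 0]]),  # last two positions are padding
}
drug_inputs = {
    "input_ids": np.arange(4).reshape(1, 4),
    "attention_mask": np.ones((1, 4), dtype=int),
}

# Heatmap restricted to the 4 unmasked target tokens x 4 drug tokens
heatmap = plot_crossattention_weights(
    target_inputs["attention_mask"][0], drug_inputs["attention_mask"][0],
    target_inputs, drug_inputs,
    rng.random((6, 4)),                  # random stand-in attention weights
    ToyTokenizer(), ToyTokenizer(),
)
heatmap.save("crossattention.png")

# Contribution map; padding contributions get folded into the real tokens
contribs = plot_presum(
    target_inputs,
    rng.random((1, 6)),                  # stand-in presum contributions
    ToyScaler(), 1.0, 0.0,               # w, b
    ToyTokenizer(),
)
contribs.save("contributions.png")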
app.py CHANGED
@@ -10,6 +10,9 @@ import os
 from pathlib import Path
 import logging
 
+# Import visualization functions
+from analysis import plot_crossattention_weights, plot_presum
+
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -93,9 +96,6 @@ class DrugTargetInteractionApp:
             return_tensors="pt"
         ).to(self.device)
 
-        logger.info(f"Drug inputs: {drug_inputs}")
-        logger.info(f"Target inputs: {target_inputs}")
-
        # Make prediction
        with torch.no_grad():
            prediction = self.model(target_inputs, drug_inputs)
@@ -111,58 +111,118 @@ class DrugTargetInteractionApp:
         except Exception as e:
             logger.error(f"Prediction error: {str(e)}")
             return f"Error during prediction: {str(e)}"
 
-    def get_attention_visualization(self, target_sequence, drug_smiles, max_length=512):
-        """Get attention weights for visualization"""
+    def visualize_interaction(self, target_sequence, drug_smiles):
+        """
+        Generate visualization images for a drug-target interaction.
+
+        Args:
+            target_sequence (str): RNA sequence
+            drug_smiles (str): Drug SMILES notation
+
+        Returns:
+            tuple: (cross_attention_image, raw_contribution_image,
+                    normalized_contribution_image, status_message)
+        """
         if self.model is None:
-            return None, "Model not loaded"
+            return None, None, None, "Error: Model not loaded. Please load a model first."
 
         try:
-            # Enable interpretation mode
-            self.model.INTERPR_ENABLE_MODE()
-
             # Tokenize inputs
             target_inputs = self.target_tokenizer(
                 target_sequence,
-                padding=True,
-                truncation=True,
-                max_length=max_length,
+                padding="max_length",
+                truncation=True,
+                max_length=512,
                 return_tensors="pt"
             ).to(self.device)
 
             drug_inputs = self.drug_tokenizer(
                 drug_smiles,
-                padding=True,
-                truncation=True,
-                max_length=max_length,
+                padding="max_length",
+                truncation=True,
+                max_length=512,
                 return_tensors="pt"
             ).to(self.device)
 
-            # Make prediction to get attention weights
+            # Make prediction and extract visualization data
             with torch.no_grad():
-                _ = self.model(target_inputs, drug_inputs)
+                prediction = self.model(target_inputs, drug_inputs)
 
-            # Get attention weights
-            attention_weights = self.model.model.crossattention_weights
-            if attention_weights is not None:
-                # Convert to numpy for visualization
-                attention_weights = attention_weights.cpu().numpy()
-
-                # Get tokens for visualization
-                target_tokens = self.target_tokenizer.convert_ids_to_tokens(
-                    target_inputs["input_ids"][0], skip_special_tokens=True
-                )
-                drug_tokens = self.drug_tokenizer.convert_ids_to_tokens(
-                    drug_inputs["input_ids"][0], skip_special_tokens=True
-                )
-
-                return attention_weights, target_tokens, drug_tokens, "Attention visualization ready"
-            else:
-                return None, None, None, "No attention weights available"
+            # Unscale if a scaler exists
+            if self.model.scaler is not None:
+                prediction = self.model.unscale(prediction)
+
+            prediction_value = prediction.cpu().numpy()[0][0]
+
+            # Extract data needed for visualizations
+            presum_values = self.model.model.presum_layer  # Shape: (1, seq_len)
+            cross_attention_weights = self.model.model.crossattention_weights[0][0]  # Shape: (seq_len, seq_len)
+
+            # Get model parameters for scaling
+            w = self.model.model.w.squeeze(1)
+            b = self.model.model.b
+            scaler = self.model.model.scaler
+
+            # Generate visualizations
+            try:
+                # 1. Cross-attention heatmap
+                cross_attention_img = plot_crossattention_weights(
+                    target_inputs["attention_mask"][0],
+                    drug_inputs["attention_mask"][0],
+                    target_inputs,
+                    drug_inputs,
+                    cross_attention_weights,
+                    self.target_tokenizer,
+                    self.drug_tokenizer
+                )
+            except Exception as e:
+                logger.error(f"Cross-attention visualization error: {str(e)}")
+                cross_attention_img = None
+
+            try:
+                # 2. Normalized contribution visualization (always generated)
+                normalized_img = plot_presum(
+                    target_inputs,
+                    presum_values,
+                    scaler,
+                    w,
+                    b,
+                    self.target_tokenizer,
+                    raw_affinities=False
+                )
+            except Exception as e:
+                logger.error(f"Normalized contribution visualization error: {str(e)}")
+                normalized_img = None
+
+            try:
+                # 3. Raw contribution visualization (only if pKd > 0)
+                raw_img = None
+                if prediction_value > 0:
+                    raw_img = plot_presum(
+                        target_inputs,
+                        presum_values,
+                        scaler,
+                        w,
+                        b,
+                        self.target_tokenizer,
+                        raw_affinities=True
+                    )
+                else:
+                    logger.info("Skipping raw affinities visualization as pKd <= 0")
+            except Exception as e:
+                logger.error(f"Raw contribution visualization error: {str(e)}")
+                raw_img = None
+
+            status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
+            if prediction_value <= 0:
+                status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
+
+            return cross_attention_img, raw_img, normalized_img, status_msg
 
         except Exception as e:
-            logger.error(f"Attention visualization error: {str(e)}")
-            return None, None, None, f"Error: {str(e)}"
+            logger.error(f"Visualization error: {str(e)}")
+            return None, None, None, f"Error during visualization: {str(e)}"
+
 
 # Initialize the app
 app = DrugTargetInteractionApp()
@@ -174,6 +234,13 @@ def predict_wrapper(target_seq, drug_smiles):
 
     return app.predict_interaction(target_seq, drug_smiles)
 
+def visualize_wrapper(target_seq, drug_smiles):
+    """Wrapper function for visualization"""
+    if not target_seq.strip() or not drug_smiles.strip():
+        return None, None, None, "Please provide both target sequence and drug SMILES."
+
+    return app.visualize_interaction(target_seq, drug_smiles)
+
 def load_model_wrapper(model_path):
     """Wrapper function to load model"""
     if app.load_model(model_path):
@@ -210,7 +277,10 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
                     lines=2
                 )
 
-                predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
+                # Buttons side by side
+                with gr.Row():
+                    predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
+                    visualize_btn = gr.Button("📊 Visualize Interaction", variant="secondary", size="lg")
 
             with gr.Column(scale=1):
                 prediction_output = gr.Textbox(
@@ -219,6 +289,43 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
                     lines=3
                 )
 
+            # Visualization outputs section
+            gr.HTML("<h3 style='margin-top: 30px; color: #2E86AB;'>📈 Interaction Visualizations</h3>")
+
+            with gr.Row():
+                with gr.Column():
+                    viz_image1 = gr.Image(
+                        label="Cross-Attention Heatmap",
+                        type="pil",
+                        interactive=False,
+                        container=True,
+                        height=300
+                    )
+
+                with gr.Column():
+                    viz_image2 = gr.Image(
+                        label="Raw pKd Contribution Visualization",
+                        type="pil",
+                        interactive=False,
+                        container=True,
+                        height=300
+                    )
+
+                with gr.Column():
+                    viz_image3 = gr.Image(
+                        label="Normalized pKd Contribution Visualization",
+                        type="pil",
+                        interactive=False,
+                        container=True,
+                        height=300
+                    )
+
+            viz_status = gr.Textbox(
+                label="Visualization Status",
+                interactive=False,
+                lines=2
+            )
+
         # Example inputs
         gr.HTML("<h3 style='margin-top: 20px; color: #2E86AB;'>📚 Example Inputs:</h3>")
 
@@ -239,11 +346,18 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
             cache_examples=False
         )
 
+        # Button click events
         predict_btn.click(
             fn=predict_wrapper,
             inputs=[target_input, drug_input],
             outputs=prediction_output
         )
+
+        visualize_btn.click(
+            fn=visualize_wrapper,
+            inputs=[target_input, drug_input],
+            outputs=[viz_image1, viz_image2, viz_image3, viz_status]
+        )
 
     with gr.Tab("⚙️ Model Settings"):
         gr.HTML("<h3 style='color: #2E86AB;'>Model Configuration</h3>")
@@ -286,14 +400,20 @@ with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()
     - Cross-attention for drug-target interaction modeling
     - Dropout for regularization
     - Layer normalization for stable training
-    - Interpretability mode for attention visualization
+    - Interpretability mode for contribution and attention visualization
 
     ### Usage Tips:
     1. Load your trained model using the Model Settings tab
     2. Enter an RNA sequence and drug SMILES
     3. Click "Predict Interaction" to get binding affinity prediction
+    4. Click "Visualize Interaction" to see detailed interaction analysis
 
     For best results, ensure your input sequences are properly formatted and within reasonable length limits.
+
+    ### Visualization Features:
+    - **Cross-Attention Heatmap**: Shows cross-attention between drug and target tokens
+    - **Raw pKd Contribution**: Shows raw signed contributions (only when pKd > 0)
+    - **Normalized pKd Contribution**: Shows normalized non-negative contributions
     """)
 
 # Launch the app
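Taken together, the new button routes visualize_wrapper through visualize_interaction to the three image panes and the status box. A hypothetical headless check of that path, assuming the launch call below is wrapped in an `if __name__ == "__main__":` guard (the diff cuts off before it) and that a trained checkpoint exists at the placeholder path:

# Hypothetical smoke test for the new visualization path (not in this commit).
from app import app, visualize_wrapper

# Placeholder checkpoint path; substitute a real trained model.
if app.load_model("path/to/checkpoint"):
    ca_img, raw_img, norm_img, status = visualize_wrapper(
        "AUGGCUACGUAGCUAGC",             # toy RNA sequence
        "CC(=O)OC1=CC=CC=C1C(=O)O",      # aspirin SMILES
    )
    print(status)
    for name, img in [("cross_attention", ca_img),
                      ("raw_contribution", raw_img),
                      ("normalized_contribution", norm_img)]:
        if img is not None:              # raw_img is None when pKd <= 0
            img.save(f"{name}.png")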