IlPakoZ commited on
Commit
62a2083
·
verified ·
1 Parent(s): 62cfa20

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -165
app.py CHANGED
@@ -158,182 +158,182 @@ class DrugTargetInteractionApp:
158
  logger.error(f"Prediction error: {str(e)}")
159
  return f"Error during prediction: {str(e)}"
160
 
161
- def visualize_interaction(self, target_sequence, drug_smiles):
162
- """
163
- Generate visualization images for drug-target interaction
164
-
165
- Args:
166
- target_sequence (str): RNA sequence
167
- drug_smiles (str): Drug SMILES notation
168
 
169
- Returns:
170
- tuple: (cross_attention_image, raw_contribution_image, normalized_contribution_image, status_message)
171
- """
172
- if self.model is None:
173
- return None, None, None, "Error: Model not loaded. Please load a model first."
174
-
175
- try:
176
- # Tokenize inputs
177
- target_inputs = self.target_tokenizer(
178
- target_sequence,
179
- padding="max_length",
180
- truncation=True,
181
- max_length=512,
182
- return_tensors="pt"
183
- ).to(self.device)
184
-
185
- drug_inputs = self.drug_tokenizer(
186
- drug_smiles,
187
- padding="max_length",
188
- truncation=True,
189
- max_length=512,
190
- return_tensors="pt"
191
- ).to(self.device)
192
-
193
- # Enable interpretation mode
194
- self.model.INTERPR_ENABLE_MODE()
195
-
196
- # Make prediction and extract visualization data
197
- with torch.no_grad():
198
- prediction = self.model(target_inputs, drug_inputs)
199
-
200
- # Unscale if scaler exists
201
- if self.model.scaler is not None:
202
- prediction = self.model.unscale(prediction)
203
 
204
- prediction_value = prediction.cpu().numpy()[0][0]
205
-
206
- # Extract data needed for visualizations
207
- presum_values = self.model.model.presum_layer # Shape: (1, seq_len)
208
- cross_attention_weights = self.model.model.crossattention_weights # Shape: (batch, heads, seq_len, seq_len)
209
-
210
- # Get model parameters for scaling
211
- w = self.model.model.w.squeeze(1)
212
- b = self.model.model.b
213
- scaler = self.model.model.scaler
214
-
215
- logger.info(f"Target inputs shape: {target_inputs['input_ids'].shape}")
216
- logger.info(f"Drug inputs shape: {drug_inputs['input_ids'].shape}")
217
 
218
- # Generate visualizations
219
  try:
220
- # 1. Cross-attention heatmap
221
- cross_attention_img = None
222
- logger.info(f"Cross-attention weights type: {type(cross_attention_weights)}")
223
- if cross_attention_weights is not None:
224
- logger.info(f"Cross-attention weights shape: {cross_attention_weights.shape if hasattr(cross_attention_weights, 'shape') else 'No shape attr'}")
225
-
226
- try:
227
- cross_attn_matrix = cross_attention_weights[0, 0]
228
-
229
- if cross_attn_matrix is not None:
230
- logger.info(f"Extracted cross-attention matrix shape: {cross_attn_matrix.shape}")
231
- logger.info(f"Target attention mask shape: {target_inputs['attention_mask'].shape}")
232
- logger.info(f"Drug attention mask shape: {drug_inputs['attention_mask'].shape}")
233
-
 
 
 
 
 
234
 
235
- cross_attention_img = plot_crossattention_weights(
236
- target_inputs["attention_mask"][0],
237
- drug_inputs["attention_mask"][0],
238
- target_inputs,
239
- drug_inputs,
240
- cross_attn_matrix,
241
- self.target_tokenizer,
242
- self.drug_tokenizer
243
- )
244
- else:
245
- logger.warning("Could not extract valid cross-attention matrix")
246
-
247
- except (IndexError, TypeError, AttributeError) as e:
248
- logger.warning(f"Error extracting cross-attention matrix: {str(e)}")
249
- cross_attn_matrix = None
250
- else:
251
- logger.warning("Cross-attention weights are None")
252
 
253
- except Exception as e:
254
- logger.error(f"Cross-attention visualization error: {str(e)}")
255
- cross_attention_img = None
256
-
257
- try:
258
- # 2. Normalized contribution visualization (always generate)
259
- normalized_img = None
260
- if presum_values is not None:
261
- normalized_img = plot_presum(
262
- target_inputs,
263
- presum_values.detach(), # Detach the tensor
264
- scaler,
265
- w.detach(), # Detach the tensor
266
- b.detach(), # Detach the tensor
267
- self.target_tokenizer,
268
- raw_affinities=False
269
- )
270
- else:
271
- logger.warning("Presum values are None")
272
 
273
- except Exception as e:
274
- logger.error(f"Normalized contribution visualization error: {str(e)}")
275
- normalized_img = None
276
-
277
- try:
278
- # 3. Raw contribution visualization (only if pKd > 0)
279
- raw_img = None
280
- if prediction_value > 0 and presum_values is not None:
281
- raw_img = plot_presum(
282
- target_inputs,
283
- presum_values.detach(), # Detach the tensor
284
- scaler,
285
- w.detach(), # Detach the tensor
286
- b.detach(), # Detach the tensor
287
- self.target_tokenizer,
288
- raw_affinities=True
289
- )
290
- else:
291
- if prediction_value <= 0:
292
- logger.info("Skipping raw affinities visualization as pKd <= 0")
293
- if presum_values is None:
294
- logger.warning("Cannot generate raw visualization: presum values are None")
295
 
296
- except Exception as e:
297
- logger.error(f"Raw contribution visualization error: {str(e)}")
298
- raw_img = None
299
-
300
- # Disable interpretation mode after use
301
- self.model.INTERPR_DISABLE_MODE()
302
-
303
- # Create placeholder images if generation failed
304
- if cross_attention_img is None:
305
- cross_attention_img = create_placeholder_image(
306
- text="Cross-Attention Heatmap\nFailed to generate"
307
- )
308
- if normalized_img is None:
309
- normalized_img = create_placeholder_image(
310
- text="Normalized Contribution\nFailed to generate"
311
- )
312
- if raw_img is None and prediction_value > 0:
313
- raw_img = create_placeholder_image(
314
- text="Raw Contribution\nFailed to generate"
315
- )
316
- elif raw_img is None:
317
- raw_img = create_placeholder_image(
318
- text="Raw Contribution\nSkipped (pKd 0)"
319
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
 
321
- status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
322
- if prediction_value <= 0:
323
- status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
324
- if cross_attention_weights is None:
325
- status_msg += " (Cross-attention visualization failed: weights not available)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
- return cross_attention_img, raw_img, normalized_img, status_msg
328
-
329
- except Exception as e:
330
- logger.error(f"Visualization error: {str(e)}")
331
- # Make sure to disable interpretation mode even if there's an error
332
- try:
333
  self.model.INTERPR_DISABLE_MODE()
334
- except:
335
- pass
336
- return None, None, None, f"Error during visualization: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
 
339
  # Initialize the app
 
158
  logger.error(f"Prediction error: {str(e)}")
159
  return f"Error during prediction: {str(e)}"
160
 
161
+ def visualize_interaction(self, target_sequence, drug_smiles):
162
+ """
163
+ Generate visualization images for drug-target interaction
 
 
 
 
164
 
165
+ Args:
166
+ target_sequence (str): RNA sequence
167
+ drug_smiles (str): Drug SMILES notation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ Returns:
170
+ tuple: (cross_attention_image, raw_contribution_image, normalized_contribution_image, status_message)
171
+ """
172
+ if self.model is None:
173
+ return None, None, None, "Error: Model not loaded. Please load a model first."
 
 
 
 
 
 
 
 
174
 
 
175
  try:
176
+ # Tokenize inputs
177
+ target_inputs = self.target_tokenizer(
178
+ target_sequence,
179
+ padding="max_length",
180
+ truncation=True,
181
+ max_length=512,
182
+ return_tensors="pt"
183
+ ).to(self.device)
184
+
185
+ drug_inputs = self.drug_tokenizer(
186
+ drug_smiles,
187
+ padding="max_length",
188
+ truncation=True,
189
+ max_length=512,
190
+ return_tensors="pt"
191
+ ).to(self.device)
192
+
193
+ # Enable interpretation mode
194
+ self.model.INTERPR_ENABLE_MODE()
195
 
196
+ # Make prediction and extract visualization data
197
+ with torch.no_grad():
198
+ prediction = self.model(target_inputs, drug_inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ # Unscale if scaler exists
201
+ if self.model.scaler is not None:
202
+ prediction = self.model.unscale(prediction)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
+ prediction_value = prediction.cpu().numpy()[0][0]
205
+
206
+ # Extract data needed for visualizations
207
+ presum_values = self.model.model.presum_layer # Shape: (1, seq_len)
208
+ cross_attention_weights = self.model.model.crossattention_weights # Shape: (batch, heads, seq_len, seq_len)
209
+
210
+ # Get model parameters for scaling
211
+ w = self.model.model.w.squeeze(1)
212
+ b = self.model.model.b
213
+ scaler = self.model.model.scaler
214
+
215
+ logger.info(f"Target inputs shape: {target_inputs['input_ids'].shape}")
216
+ logger.info(f"Drug inputs shape: {drug_inputs['input_ids'].shape}")
217
+
218
+ # Generate visualizations
219
+ try:
220
+ # 1. Cross-attention heatmap
221
+ cross_attention_img = None
222
+ logger.info(f"Cross-attention weights type: {type(cross_attention_weights)}")
223
+ if cross_attention_weights is not None:
224
+ logger.info(f"Cross-attention weights shape: {cross_attention_weights.shape if hasattr(cross_attention_weights, 'shape') else 'No shape attr'}")
 
225
 
226
+ try:
227
+ cross_attn_matrix = cross_attention_weights[0, 0]
228
+
229
+ if cross_attn_matrix is not None:
230
+ logger.info(f"Extracted cross-attention matrix shape: {cross_attn_matrix.shape}")
231
+ logger.info(f"Target attention mask shape: {target_inputs['attention_mask'].shape}")
232
+ logger.info(f"Drug attention mask shape: {drug_inputs['attention_mask'].shape}")
233
+
234
+
235
+ cross_attention_img = plot_crossattention_weights(
236
+ target_inputs["attention_mask"][0],
237
+ drug_inputs["attention_mask"][0],
238
+ target_inputs,
239
+ drug_inputs,
240
+ cross_attn_matrix,
241
+ self.target_tokenizer,
242
+ self.drug_tokenizer
243
+ )
244
+ else:
245
+ logger.warning("Could not extract valid cross-attention matrix")
246
+
247
+ except (IndexError, TypeError, AttributeError) as e:
248
+ logger.warning(f"Error extracting cross-attention matrix: {str(e)}")
249
+ cross_attn_matrix = None
250
+ else:
251
+ logger.warning("Cross-attention weights are None")
252
+
253
+ except Exception as e:
254
+ logger.error(f"Cross-attention visualization error: {str(e)}")
255
+ cross_attention_img = None
256
+
257
+ try:
258
+ # 2. Normalized contribution visualization (always generate)
259
+ normalized_img = None
260
+ if presum_values is not None:
261
+ normalized_img = plot_presum(
262
+ target_inputs,
263
+ presum_values.detach(), # Detach the tensor
264
+ scaler,
265
+ w.detach(), # Detach the tensor
266
+ b.detach(), # Detach the tensor
267
+ self.target_tokenizer,
268
+ raw_affinities=False
269
+ )
270
+ else:
271
+ logger.warning("Presum values are None")
272
+
273
+ except Exception as e:
274
+ logger.error(f"Normalized contribution visualization error: {str(e)}")
275
+ normalized_img = None
276
 
277
+ try:
278
+ # 3. Raw contribution visualization (only if pKd > 0)
279
+ raw_img = None
280
+ if prediction_value > 0 and presum_values is not None:
281
+ raw_img = plot_presum(
282
+ target_inputs,
283
+ presum_values.detach(), # Detach the tensor
284
+ scaler,
285
+ w.detach(), # Detach the tensor
286
+ b.detach(), # Detach the tensor
287
+ self.target_tokenizer,
288
+ raw_affinities=True
289
+ )
290
+ else:
291
+ if prediction_value <= 0:
292
+ logger.info("Skipping raw affinities visualization as pKd <= 0")
293
+ if presum_values is None:
294
+ logger.warning("Cannot generate raw visualization: presum values are None")
295
+
296
+ except Exception as e:
297
+ logger.error(f"Raw contribution visualization error: {str(e)}")
298
+ raw_img = None
299
 
300
+ # Disable interpretation mode after use
 
 
 
 
 
301
  self.model.INTERPR_DISABLE_MODE()
302
+
303
+ # Create placeholder images if generation failed
304
+ if cross_attention_img is None:
305
+ cross_attention_img = create_placeholder_image(
306
+ text="Cross-Attention Heatmap\nFailed to generate"
307
+ )
308
+ if normalized_img is None:
309
+ normalized_img = create_placeholder_image(
310
+ text="Normalized Contribution\nFailed to generate"
311
+ )
312
+ if raw_img is None and prediction_value > 0:
313
+ raw_img = create_placeholder_image(
314
+ text="Raw Contribution\nFailed to generate"
315
+ )
316
+ elif raw_img is None:
317
+ raw_img = create_placeholder_image(
318
+ text="Raw Contribution\nSkipped (pKd ≤ 0)"
319
+ )
320
+
321
+ status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
322
+ if prediction_value <= 0:
323
+ status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
324
+ if cross_attention_weights is None:
325
+ status_msg += " (Cross-attention visualization failed: weights not available)"
326
+
327
+ return cross_attention_img, raw_img, normalized_img, status_msg
328
+
329
+ except Exception as e:
330
+ logger.error(f"Visualization error: {str(e)}")
331
+ # Make sure to disable interpretation mode even if there's an error
332
+ try:
333
+ self.model.INTERPR_DISABLE_MODE()
334
+ except:
335
+ pass
336
+ return None, None, None, f"Error during visualization: {str(e)}"
337
 
338
 
339
  # Initialize the app