IlPakoZ committed on
Commit
a899ac7
·
verified ·
1 Parent(s): c0031a7

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +643 -0
app.py ADDED
@@ -0,0 +1,643 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel
5
+ from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
6
+ from configuration_dlmberta import InteractionModelATTNConfig
7
+ from chemberta import ChembertaTokenizer
8
+ import json
9
+ import os
10
+ from pathlib import Path
11
+ import logging
12
+
13
+ # Import visualization functions
14
+ from analysis import plot_crossattention_weights, plot_presum
15
+ from PIL import Image, ImageDraw, ImageFont
16
+
17
+ # Configure logging
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
def create_placeholder_image(width=600, height=400, text="No visualization available", bg_color=(0, 0, 0, 0)):
    """
    Create a placeholder image with a gray caption centered on it.

    Args:
        width (int): Image width in pixels
        height (int): Image height in pixels
        text (str): Caption to draw (callers in this file pass two-line captions)
        bg_color (tuple): Background color (R, G, B, A) - (0, 0, 0, 0) for transparent

    Returns:
        PIL.Image: RGBA placeholder image
    """
    # Create image with the requested (by default fully transparent) background
    img = Image.new('RGBA', (width, height), bg_color)
    draw = ImageDraw.Draw(img)

    # Prefer Arial and fall back to Pillow's built-in font. Only font-loading
    # failures are caught, so KeyboardInterrupt/SystemExit and genuine bugs
    # are no longer swallowed the way the previous bare ``except:`` did.
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        try:
            font = ImageFont.load_default()
        except (OSError, ImportError):
            font = None

    # Measure the caption so it can be centered
    if font is not None:
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
    else:
        # Rough estimation if no font could be loaded
        text_width = len(text) * 8
        text_height = 16

    x = (width - text_width) // 2
    y = (height - text_height) // 2

    # Draw text in opaque gray
    draw.text((x, y), text, fill=(128, 128, 128, 255), font=font)

    return img
64
+
65
class DrugTargetInteractionApp:
    """Hold the drug-target interaction model, its tokenizers and inference helpers.

    A new instance is empty: ``load_model`` must succeed before
    ``predict_interaction`` or ``visualize_interaction`` can return results.
    """

    def __init__(self):
        self.model = None
        self.target_tokenizer = None
        self.drug_tokenizer = None
        self.scaler = None
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_path="./"):
        """Load the pre-trained model and tokenizers.

        Everything is built in local variables first and only published on
        ``self`` once every step has succeeded, so a failed (re)load never
        leaves the instance with a model but no matching tokenizers.

        Args:
            model_path (str): Directory containing the model files, an optional
                ``scaler.config``, and the ``target_tokenizer`` /
                ``drug_tokenizer`` sub-directories.

        Returns:
            bool: True when loading succeeded, False otherwise (the error is logged).
        """
        try:
            # Load configuration
            config = InteractionModelATTNConfig.from_pretrained(model_path)

            # Build the drug encoder (ChemBERTa) from its configuration only
            drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
            drug_encoder_config.pooler = None
            drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

            # Load target encoder
            target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")

            # Load scaler if a scaler file is present next to the model
            scaler = None
            if os.path.exists(os.path.join(model_path, "scaler.config")):
                scaler = StdScaler()
                scaler.load(model_path)

            model = InteractionModelATTNForRegression.from_pretrained(
                model_path,
                config=config,
                target_encoder=target_encoder,
                drug_encoder=drug_encoder,
                scaler=scaler
            )
            model.to(self.device)
            model.eval()

            # Load tokenizers
            target_tokenizer = AutoTokenizer.from_pretrained(
                os.path.join(model_path, "target_tokenizer")
            )
            vocab_file = os.path.join(model_path, "drug_tokenizer", "vocab.json")
            drug_tokenizer = ChembertaTokenizer(vocab_file)

        except Exception as e:
            # Top-level boundary for the UI: log with traceback and report failure.
            logger.exception("Error loading model: %s", e)
            return False

        self.model = model
        self.scaler = scaler
        self.target_tokenizer = target_tokenizer
        self.drug_tokenizer = drug_tokenizer

        logger.info("Model and tokenizers loaded successfully!")
        return True

    def _tokenize_pair(self, target_sequence, drug_smiles, max_length=512):
        """Tokenize both inputs (padded/truncated to ``max_length``) and move them to ``self.device``."""
        target_inputs = self.target_tokenizer(
            target_sequence,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(self.device)

        drug_inputs = self.drug_tokenizer(
            drug_smiles,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(self.device)

        return target_inputs, drug_inputs

    def predict_interaction(self, target_sequence, drug_smiles, max_length=512):
        """Predict drug-target interaction.

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str): Drug SMILES notation
            max_length (int): Length both inputs are padded/truncated to
                (previously accepted but ignored in favour of a hard-coded 512)

        Returns:
            str: The formatted prediction, or an error message.
        """
        if self.model is None:
            return "Error: Model not loaded. Please load a model first."

        try:
            target_inputs, drug_inputs = self._tokenize_pair(
                target_sequence, drug_smiles, max_length
            )

            # Plain prediction: make sure interpretation mode is off
            self.model.INTERPR_DISABLE_MODE()
            with torch.no_grad():
                prediction = self.model(target_inputs, drug_inputs)

            # Unscale if scaler exists
            if self.model.scaler is not None:
                prediction = self.model.unscale(prediction)

            prediction_value = prediction.cpu().numpy()[0][0]

            return f"Predicted Binding Affinity: {prediction_value:.4f}"

        except Exception as e:
            logger.exception("Prediction error: %s", e)
            return f"Error during prediction: {str(e)}"

    def _render_cross_attention(self, target_inputs, drug_inputs, cross_attention_weights):
        """Plot the cross-attention heatmap; return None (after logging) on any failure."""
        logger.info("Cross-attention weights type: %s", type(cross_attention_weights))
        if cross_attention_weights is None:
            logger.warning("Cross-attention weights are None")
            return None

        logger.info(
            "Cross-attention weights shape: %s",
            cross_attention_weights.shape if hasattr(cross_attention_weights, 'shape') else 'No shape attr'
        )

        try:
            # First sample, first head
            cross_attn_matrix = cross_attention_weights[0, 0]
        except (IndexError, TypeError, AttributeError) as e:
            logger.warning("Error extracting cross-attention matrix: %s", e)
            return None

        if cross_attn_matrix is None:
            logger.warning("Could not extract valid cross-attention matrix")
            return None

        try:
            logger.info("Extracted cross-attention matrix shape: %s", cross_attn_matrix.shape)
            logger.info("Target attention mask shape: %s", target_inputs['attention_mask'].shape)
            logger.info("Drug attention mask shape: %s", drug_inputs['attention_mask'].shape)

            return plot_crossattention_weights(
                target_inputs["attention_mask"][0],
                drug_inputs["attention_mask"][0],
                target_inputs,
                drug_inputs,
                cross_attn_matrix,
                self.target_tokenizer,
                self.drug_tokenizer
            )
        except Exception as e:
            # A failed plot must not abort the other visualizations.
            logger.exception("Cross-attention visualization error: %s", e)
            return None

    def _render_presum(self, target_inputs, presum_values, scaler, w, b, raw_affinities):
        """Plot per-token contributions (raw or normalized); return None (after logging) on failure."""
        label = "Raw" if raw_affinities else "Normalized"
        if presum_values is None:
            logger.warning(
                "Cannot generate %s contribution visualization: presum values are None",
                label.lower()
            )
            return None

        try:
            return plot_presum(
                target_inputs,
                presum_values.detach(),
                scaler,
                w.detach(),
                b.detach(),
                self.target_tokenizer,
                raw_affinities=raw_affinities
            )
        except Exception as e:
            # A failed plot must not abort the other visualizations.
            logger.exception("%s contribution visualization error: %s", label, e)
            return None

    def visualize_interaction(self, target_sequence, drug_smiles):
        """
        Generate visualization images for drug-target interaction

        The normalized contribution plot is always attempted; the raw
        contribution plot is only attempted when the predicted value is > 0.
        Any image that could not be produced is replaced by a placeholder.

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str): Drug SMILES notation

        Returns:
            tuple: (cross_attention_image, raw_contribution_image, normalized_contribution_image, status_message).
                The three images are None when the model is not loaded or the
                whole visualization failed.
        """
        if self.model is None:
            return None, None, None, "Error: Model not loaded. Please load a model first."

        try:
            target_inputs, drug_inputs = self._tokenize_pair(target_sequence, drug_smiles)

            # Interpretation mode is switched on only for this call; ``finally``
            # guarantees it is switched off again whatever happens in between.
            self.model.INTERPR_ENABLE_MODE()
            try:
                with torch.no_grad():
                    prediction = self.model(target_inputs, drug_inputs)

                # Unscale if scaler exists
                if self.model.scaler is not None:
                    prediction = self.model.unscale(prediction)

                prediction_value = prediction.cpu().numpy()[0][0]

                # Data captured by the model during the interpretable forward pass
                presum_values = self.model.model.presum_layer  # Shape: (1, seq_len)
                cross_attention_weights = self.model.model.crossattention_weights  # Shape: (batch, heads, seq_len, seq_len)

                # Model parameters needed by the contribution plots
                w = self.model.model.w.squeeze(1)
                b = self.model.model.b
                scaler = self.model.model.scaler

                logger.info("Target inputs shape: %s", target_inputs['input_ids'].shape)
                logger.info("Drug inputs shape: %s", drug_inputs['input_ids'].shape)

                # 1. Cross-attention heatmap
                cross_attention_img = self._render_cross_attention(
                    target_inputs, drug_inputs, cross_attention_weights
                )

                # 2. Normalized contribution visualization (always attempted)
                normalized_img = self._render_presum(
                    target_inputs, presum_values, scaler, w, b, raw_affinities=False
                )

                # 3. Raw contribution visualization (only if pKd > 0)
                raw_img = None
                if prediction_value > 0:
                    raw_img = self._render_presum(
                        target_inputs, presum_values, scaler, w, b, raw_affinities=True
                    )
                else:
                    logger.info("Skipping raw contribution visualization as pKd <= 0")
            finally:
                self.model.INTERPR_DISABLE_MODE()

            # Create placeholder images if generation failed or was skipped
            if cross_attention_img is None:
                cross_attention_img = create_placeholder_image(
                    text="Cross-Attention Heatmap\nFailed to generate"
                )
            if normalized_img is None:
                normalized_img = create_placeholder_image(
                    text="Normalized Contribution\nFailed to generate"
                )
            if raw_img is None and prediction_value > 0:
                raw_img = create_placeholder_image(
                    text="Raw Contribution\nFailed to generate"
                )
            elif raw_img is None:
                raw_img = create_placeholder_image(
                    text="Raw Contribution\nSkipped (pKd ≤ 0)"
                )

            status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
            if prediction_value <= 0:
                status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
            if cross_attention_weights is None:
                status_msg += " (Cross-attention visualization failed: weights not available)"

            return cross_attention_img, raw_img, normalized_img, status_msg

        except Exception as e:
            logger.exception("Visualization error: %s", e)
            return None, None, None, f"Error during visualization: {str(e)}"
337
+
338
+
339
# Single shared application instance used by the Gradio callback wrappers
# below. It starts without a model; one is loaded later via app.load_model().
app = DrugTargetInteractionApp()
341
+
342
def predict_wrapper(target_seq, drug_smiles):
    """Validate the Gradio inputs and run a prediction.

    Args:
        target_seq (str | None): RNA sequence typed by the user
        drug_smiles (str | None): Drug SMILES typed by the user

    Returns:
        str: The prediction text, or a prompt to fill in both fields.
    """
    # Treat a missing value like an empty one instead of crashing on None,
    # and drop surrounding whitespace/newlines picked up from the text boxes.
    target_seq = (target_seq or "").strip()
    drug_smiles = (drug_smiles or "").strip()

    if not target_seq or not drug_smiles:
        return "Please provide both target sequence and drug SMILES."

    # Forward the cleaned text so the tokenizers see exactly what was validated.
    return app.predict_interaction(target_seq, drug_smiles)
348
+
349
def visualize_wrapper(target_seq, drug_smiles):
    """Validate the Gradio inputs and generate the visualizations.

    Args:
        target_seq (str | None): RNA sequence typed by the user
        drug_smiles (str | None): Drug SMILES typed by the user

    Returns:
        tuple: (cross_attention_image, raw_contribution_image,
            normalized_contribution_image, status_message); the three images
            are None when an input is missing.
    """
    # Treat a missing value like an empty one instead of crashing on None,
    # and drop surrounding whitespace/newlines picked up from the text boxes.
    target_seq = (target_seq or "").strip()
    drug_smiles = (drug_smiles or "").strip()

    if not target_seq or not drug_smiles:
        return None, None, None, "Please provide both target sequence and drug SMILES."

    # Forward the cleaned text so the tokenizers see exactly what was validated.
    return app.visualize_interaction(target_seq, drug_smiles)
355
+
356
def load_model_wrapper(model_path):
    """Load the model found at ``model_path`` and report the outcome as status text."""
    loaded = app.load_model(model_path)
    return "Model loaded successfully!" if loaded else "Failed to load model. Check the path and files."
362
+
363
+ # Create Gradio interface
364
with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 30px;">
        <h1 style="color: #2E86AB; font-size: 2.5em; margin-bottom: 10px;">
            🧬 Drug-Target Interaction Predictor
        </h1>
        <p style="font-size: 1.2em; color: #666;">
            Predict binding affinity between drugs and target RNA sequences using deep learning
        </p>
    </div>
    """)

    # State holders used to hand the generated images over to the
    # Visualizations tab: the image components are created later than the
    # button that produces them, so the click handler writes into these
    # states and ``.change`` listeners forward the values to the images.
    viz_state1 = gr.State()
    viz_state2 = gr.State()
    viz_state3 = gr.State()

    with gr.Tab("🔮 Prediction & Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                target_input = gr.Textbox(
                    label="Target RNA Sequence",
                    placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
                    lines=4,
                    max_lines=6
                )

                drug_input = gr.Textbox(
                    label="Drug SMILES",
                    placeholder="Enter SMILES notation (e.g., CC(C)CC1=CC=C(C=C1)C(C)C(=O)O)",
                    lines=2
                )

                with gr.Row():
                    predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
                    visualize_btn = gr.Button("📊 Generate Visualizations", variant="secondary", size="lg")

            with gr.Column(scale=1):
                prediction_output = gr.Textbox(
                    label="Prediction Result",
                    interactive=False,
                    lines=4
                )

        # Example inputs
        gr.HTML("<h3 style='margin-top: 20px; color: #2E86AB;'>📚 Example Inputs:</h3>")

        examples = gr.Examples(
            examples=[
                [
                    "AUGCUAGCUAGUACGUAUAUCUGCACUGC",
                    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
                ],
                [
                    "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
                    "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
                ]
            ],
            inputs=[target_input, drug_input],
            outputs=prediction_output,
            fn=predict_wrapper,
            cache_examples=False
        )

        # Button click events
        predict_btn.click(
            fn=predict_wrapper,
            inputs=[target_input, drug_input],
            outputs=prediction_output
        )

        def visualize_and_update(target_seq, drug_smiles):
            """Generate visualizations and update both status and state.

            Returns the three images (stored in the state holders) and the
            status text. The "analysis complete" notice is appended only when
            images were actually produced; validation and error messages are
            passed through unchanged instead of being followed by a
            contradictory success notice.
            """
            img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles)

            # The wrapper returns no images at all when inputs are missing,
            # the model is not loaded, or visualization failed outright.
            if img1 is None and img2 is None and img3 is None:
                return img1, img2, img3, status

            # Combine prediction result with visualization status
            combined_status = status + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
            return img1, img2, img3, combined_status

        visualize_btn.click(
            fn=visualize_and_update,
            inputs=[target_input, drug_input],
            outputs=[viz_state1, viz_state2, viz_state3, prediction_output]
        )

    with gr.Tab("📊 Visualizations"):
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
            <h2 style="color: #2E86AB;">🔬 Interaction Analysis & Visualizations</h2>
            <p style="font-size: 1.1em; color: #666;">
                Generated visualizations will appear here after clicking "Generate Visualizations" in the Prediction tab
            </p>
        </div>
        """)

        # Visualization outputs - Large and vertically aligned
        viz_image1 = gr.Image(
            label="Cross-Attention Heatmap",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)")
        )

        viz_image2 = gr.Image(
            label="Raw pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Raw pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )

        viz_image3 = gr.Image(
            label="Normalized pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Normalized pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )

        # Update visualization images when state changes
        viz_state1.change(
            fn=lambda x: x,
            inputs=viz_state1,
            outputs=viz_image1
        )

        viz_state2.change(
            fn=lambda x: x,
            inputs=viz_state2,
            outputs=viz_image2
        )

        viz_state3.change(
            fn=lambda x: x,
            inputs=viz_state3,
            outputs=viz_image3
        )

    with gr.Tab("⚙️ Model Settings"):
        gr.HTML("<h3 style='color: #2E86AB;'>Model Configuration</h3>")

        model_path_input = gr.Textbox(
            label="Model Path",
            value="./",
            placeholder="Path to model directory"
        )

        load_model_btn = gr.Button("📥 Load Model", variant="secondary")
        model_status = gr.Textbox(
            label="Status",
            interactive=False,
            value="No model loaded"
        )

        load_model_btn.click(
            fn=load_model_wrapper,
            inputs=model_path_input,
            outputs=model_status
        )

    with gr.Tab("📊 Dataset"):
        gr.Markdown("""
        ## Training and Test Datasets

        ### Fine-tuning Dataset (Training)

        The model was trained on a dataset comprising **1,439 RNA–drug interaction pairs**, including:
        - **759 unique compounds** (SMILES representations)
        - **294 unique RNA sequences**
        - Dissociation constants (pKd values) for binding affinity prediction

        **RNA Sequence Distribution by Type:**

        | RNA Sequence Type | Number of Interactions |
        |-------------------|------------------------|
        | Aptamers | 520 |
        | Ribosomal | 295 |
        | Viral RNAs | 281 |
        | miRNAs | 146 |
        | Riboswitches | 100 |
        | Repeats | 97 |
        | **Total** | **1,439** |

        ### External Evaluation Dataset (Test)

        Model validation was performed using external ROBIN classification datasets containing **5,534 RNA–drug pairs**:
        - **2,991 positive interactions**
        - **2,538 negative interactions**

        **Test Dataset Composition:**
        - **1,617 aptamer pairs** (5 unique RNA sequences)
        - **1,828 viral RNA pairs** (6 unique RNA sequences)
        - **1,459 riboswitch pairs** (5 unique RNA sequences)
        - **630 miRNA pairs** (3 unique RNA sequences)

        ### Dataset Downloads

        - [Training Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv?download=true)
        - [Test Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/test_data.csv?download=true)

        ### Citation

        Original datasets published by:
        **Krishnan et al.** - Available on the RSAPred website in PDF format.

        *Reference:*
        ```bibtex
        @article{krishnan2024reliable,
          title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
          author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
          journal={Briefings in Bioinformatics},
          volume={25},
          number={2},
          pages={bbae002},
          year={2024},
          publisher={Oxford University Press}
        }
        ```
        """)

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About this application

        This application implements DLRNA-BERTa, a Dual Language RoBERTa Transformer model for predicting drug to RNA target interactions. The model architecture includes:

        - **Target encoder**: Processes RNA sequences using RNA-BERTa
        - **Drug encoder**: Processes molecular SMILES notation using ChemBERTa
        - **Cross-attention mechanism**: Captures interactions between drugs and targets
        - **Regression head**: Predicts binding affinity scores (pKd values)

        ### Input requirements:
        - **Target sequence**: RNA sequence of the target (nucleotide sequences: A, U, G, C)
        - **Drug SMILES**: Simplified Molecular Input Line Entry System notation

        ### Model features:
        - Cross-attention for drug-target interaction modeling
        - Dropout for regularization
        - Layer normalization for stable training
        - Interpretability mode for contribution and attention visualization

        ### Usage tips:
        1. Load a trained model using the Model Settings tab (optional)
        2. Enter an RNA sequence and drug SMILES in the Prediction & Analysis tab
        3. Click "Predict Interaction" for binding affinity prediction only
        4. Click "Generate Visualizations" to create detailed interaction analysis - results will appear in the Visualizations tab

        For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens).

        ### Visualization features:
        - **Cross-attention heatmap**: Shows cross-attention weights between drug and target tokens
        - **Unnormalized pKd contribution**: Shows unnormalized signed contributions from each target token (only when pKd > 0)
        - **Normalized pKd contribution**: Shows normalized non-negative contributions from each target token

        ### Performance metrics:
        - Training on diverse drug-target interaction datasets
        - Evaluated using RMSE, Pearson correlation, and Concordance Index
        - Optimized for both predictive accuracy and interpretability

        ### GitHub repository:
        - The full model GitHub repository can be found here: https://github.com/IlPakoZ/dlrnaberta-dti-prediction

        ### Contribution:
        - Special thanks to Umut Onur Özcan for help in developing this space:)
        """)
631
+
632
+ # Launch the app
633
if __name__ == "__main__":
    # Load the bundled model at startup when a config file sits next to the script.
    if os.path.exists("./config.json"):
        app.load_model("./")

    # Serve on all interfaces at port 7860, without a public share link,
    # and surface errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)