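"""Gradio app for DLRNA-BERTa drug-target interaction prediction.

Loads an InteractionModelATTNForRegression checkpoint, predicts binding
affinity (pKd) for an RNA sequence / drug SMILES pair, and renders
cross-attention and per-token contribution visualizations.
"""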
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel
from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
from configuration_dlmberta import InteractionModelATTNConfig
from chemberta import ChembertaTokenizer
import os
import logging

# Import visualization functions
from analysis import plot_crossattention_weights, plot_presum
from PIL import Image, ImageDraw, ImageFont

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def create_placeholder_image(width=600, height=400, text="No visualization available", bg_color=(0, 0, 0, 0)):
    """
    Create a transparent placeholder image with text
    
    Args:
        width (int): Image width
        height (int): Image height  
        text (str): Text to display
        bg_color (tuple): Background color (R, G, B, A) - (0,0,0,0) for transparent
        
    Returns:
        PIL.Image: Transparent placeholder image
    """
    # Create image with transparent background
    img = Image.new('RGBA', (width, height), bg_color)
    draw = ImageDraw.Draw(img)
    
    # Try to use a TrueType font; fall back to PIL's built-in default
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except OSError:
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None
    
    # Get text size and position for centering
    if font:
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
    else:
        # Rough estimation if no font available
        text_width = len(text) * 8
        text_height = 16
    
    x = (width - text_width) // 2
    y = (height - text_height) // 2
    
    # Draw text in gray
    draw.text((x, y), text, fill=(128, 128, 128, 255), font=font)
    
    return img
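
# Note: the returned RGBA image can be passed directly to gr.Image(type="pil"),
# which is how the placeholder values below are supplied.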

class DrugTargetInteractionApp:
    def __init__(self):
        self.model = None
        self.target_tokenizer = None
        self.drug_tokenizer = None
        self.scaler = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        
    def load_model(self, model_path="./"):
        """Load the pre-trained model and tokenizers"""
        try:
            # Load configuration
            config = InteractionModelATTNConfig.from_pretrained(model_path)
            
            # Load drug encoder (ChemBERTa)
            drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
            drug_encoder_config.pooler = None
            drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)
            
            # Load target encoder
            target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")
            
            # Load scaler if exists
            scaler_path = os.path.join(model_path, "scaler.config")
            scaler = None
            if os.path.exists(scaler_path):
                scaler = StdScaler()
                scaler.load(model_path)
            
            self.model = InteractionModelATTNForRegression.from_pretrained(
                        model_path,
                        config=config,
                        target_encoder=target_encoder,
                        drug_encoder=drug_encoder, 
                        scaler=scaler
                    )

            self.model.to(self.device)
            self.model.eval()
            
            # Load tokenizers
            self.target_tokenizer = AutoTokenizer.from_pretrained(
                os.path.join(model_path, "target_tokenizer")
            )
            
            # Load drug tokenizer (ChemBERTa)
            vocab_file = os.path.join(model_path, "drug_tokenizer", "vocab.json")
            self.drug_tokenizer = ChembertaTokenizer(vocab_file)
            
            logger.info("Model and tokenizers loaded successfully!")
            return True
            
        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            return False
    
    def predict_interaction(self, target_sequence, drug_smiles, max_length=512):
        """Predict drug-target interaction"""
        if self.model is None:
            return "Error: Model not loaded. Please load a model first."
        
        try:
            # Tokenize inputs
            target_inputs = self.target_tokenizer(
                target_sequence,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            ).to(self.device)

            drug_inputs = self.drug_tokenizer(
                drug_smiles,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            ).to(self.device)

            # Plain forward pass with interpretation mode disabled
            self.model.INTERPR_DISABLE_MODE()
            with torch.no_grad():
                prediction = self.model(target_inputs, drug_inputs)
                
                # Unscale if scaler exists
                if self.model.scaler is not None:
                    prediction = self.model.unscale(prediction)
                
                prediction_value = prediction.cpu().numpy()[0][0]
            
            return f"Predicted Binding Affinity: {prediction_value:.4f}"
            
        except Exception as e:
            logger.error(f"Prediction error: {str(e)}")
            return f"Error during prediction: {str(e)}"

    def visualize_interaction(self, target_sequence, drug_smiles):
        """
        Generate visualization images for drug-target interaction
        
        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str): Drug SMILES notation
            
        Returns:
            tuple: (cross_attention_image, raw_contribution_image, normalized_contribution_image, status_message)
        """
        if self.model is None:
            return None, None, None, "Error: Model not loaded. Please load a model first."
        
        try:
            # Tokenize inputs
            target_inputs = self.target_tokenizer(
                target_sequence,
                padding="max_length", 
                truncation=True, 
                max_length=512,
                return_tensors="pt"
            ).to(self.device)
            
            drug_inputs = self.drug_tokenizer(
                drug_smiles,
                padding="max_length", 
                truncation=True, 
                max_length=512,
                return_tensors="pt"
            ).to(self.device)
            
            # Enable interpretation mode so the forward pass stores
            # cross-attention weights and presum values on the model
            self.model.INTERPR_ENABLE_MODE()

            # Make prediction and extract visualization data
            with torch.no_grad():
                prediction = self.model(target_inputs, drug_inputs)
                
                # Unscale if scaler exists
                if self.model.scaler is not None:
                    prediction = self.model.unscale(prediction)
                
                prediction_value = prediction.cpu().numpy()[0][0]
                
                # Extract data needed for visualizations
                presum_values = self.model.model.presum_layer  # Shape: (1, seq_len)
                cross_attention_weights = self.model.model.crossattention_weights  # Shape: (batch, heads, seq_len, seq_len)
                
                # Get model parameters for scaling
                w = self.model.model.w.squeeze(1)
                b = self.model.model.b
                scaler = self.model.model.scaler
                
            logger.info(f"Target inputs shape: {target_inputs['input_ids'].shape}")
            logger.info(f"Drug inputs shape: {drug_inputs['input_ids'].shape}")
            
            # Generate visualizations
            try:
                # 1. Cross-attention heatmap
                cross_attention_img = None
                logger.info(f"Cross-attention weights type: {type(cross_attention_weights)}")
                if cross_attention_weights is not None:
                    logger.info(f"Cross-attention weights shape: {cross_attention_weights.shape if hasattr(cross_attention_weights, 'shape') else 'No shape attr'}")
                    
                    try:
                        # First attention head of the first batch element
                        cross_attn_matrix = cross_attention_weights[0, 0]
                        logger.info(f"Extracted cross-attention matrix shape: {cross_attn_matrix.shape}")
                        logger.info(f"Target attention mask shape: {target_inputs['attention_mask'].shape}")
                        logger.info(f"Drug attention mask shape: {drug_inputs['attention_mask'].shape}")

                        cross_attention_img = plot_crossattention_weights(
                            target_inputs["attention_mask"][0],
                            drug_inputs["attention_mask"][0],
                            target_inputs,
                            drug_inputs,
                            cross_attn_matrix,
                            self.target_tokenizer,
                            self.drug_tokenizer
                        )
                    except (IndexError, TypeError, AttributeError) as e:
                        logger.warning(f"Error extracting cross-attention matrix: {str(e)}")
                else:
                    logger.warning("Cross-attention weights are None")
                    
            except Exception as e:
                logger.error(f"Cross-attention visualization error: {str(e)}")
                cross_attention_img = None
            
            try:
                # 2. Normalized contribution visualization (generated whenever
                #    presum values are available)
                normalized_img = None
                if presum_values is not None:
                    normalized_img = plot_presum(
                        target_inputs,
                        presum_values.detach(),  # detach from the autograd graph before plotting
                        scaler,
                        w.detach(),
                        b.detach(),
                        self.target_tokenizer,
                        raw_affinities=False
                    )
                else:
                    logger.warning("Cannot generate normalized visualization: presum values are None")
            except Exception as e:
                logger.error(f"Normalized contribution visualization error: {str(e)}")
                normalized_img = None
            
            try:
                # 3. Raw contribution visualization (only when predicted pKd > 0)
                raw_img = None
                if prediction_value > 0 and presum_values is not None:
                    raw_img = plot_presum(
                        target_inputs,
                        presum_values.detach(),  # detach from the autograd graph before plotting
                        scaler,
                        w.detach(),
                        b.detach(),
                        self.target_tokenizer,
                        raw_affinities=True
                    )
                elif presum_values is None:
                    logger.warning("Cannot generate raw visualization: presum values are None")
                else:
                    logger.info("Skipping raw affinities visualization as pKd <= 0")
                        
            except Exception as e:
                logger.error(f"Raw contribution visualization error: {str(e)}")
                raw_img = None
            
            # Disable interpretation mode after use
            self.model.INTERPR_DISABLE_MODE()
            
            # Create placeholder images if generation failed
            if cross_attention_img is None:
                cross_attention_img = create_placeholder_image(
                    text="Cross-Attention Heatmap\nFailed to generate"
                )
            if normalized_img is None:
                normalized_img = create_placeholder_image(
                    text="Normalized Contribution\nFailed to generate"
                )
            if raw_img is None and prediction_value > 0:
                raw_img = create_placeholder_image(
                    text="Raw Contribution\nFailed to generate"
                )
            elif raw_img is None:
                raw_img = create_placeholder_image(
                    text="Raw Contribution\nSkipped (pKd ≤ 0)"
                )
                
            status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
            if prediction_value <= 0:
                status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
            if cross_attention_weights is None:
                status_msg += " (Cross-attention visualization failed: weights not available)"
                
            return cross_attention_img, raw_img, normalized_img, status_msg
            
        except Exception as e:
            logger.error(f"Visualization error: {str(e)}")
            # Make sure to disable interpretation mode even if there's an error
            try:
                self.model.INTERPR_DISABLE_MODE()
            except Exception:
                pass
            return None, None, None, f"Error during visualization: {str(e)}"
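
# Headless usage (a minimal sketch; assumes a checkpoint directory "./" with
# config.json, an optional scaler.config, and the target_tokenizer/ and
# drug_tokenizer/ folders that load_model() expects; the example inputs come
# from the Examples section of the UI):
#
#   app = DrugTargetInteractionApp()
#   if app.load_model("./"):
#       print(app.predict_interaction("AUGCUAGCUAGUACGUAUAUCUGCACUGC",
#                                     "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"))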


# Initialize the app
app = DrugTargetInteractionApp()

def predict_wrapper(target_seq, drug_smiles):
    """Wrapper function for Gradio interface"""
    if not target_seq.strip() or not drug_smiles.strip():
        return "Please provide both target sequence and drug SMILES."
    
    return app.predict_interaction(target_seq, drug_smiles)

def visualize_wrapper(target_seq, drug_smiles):
    """Wrapper function for visualization"""
    if not target_seq.strip() or not drug_smiles.strip():
        return None, None, None, "Please provide both target sequence and drug SMILES."

    return app.visualize_interaction(target_seq, drug_smiles)

def load_model_wrapper(model_path):
    """Wrapper function to load model"""
    if app.load_model(model_path):
        return "Model loaded successfully!"
    else:
        return "Failed to load model. Check the path and files."

# Create Gradio interface
with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 30px;">
        <h1 style="color: #2E86AB; font-size: 2.5em; margin-bottom: 10px;">
            🧬 Drug-Target Interaction Predictor
        </h1>
        <p style="font-size: 1.2em; color: #666;">
            Predict binding affinity between drugs and target RNA sequences using deep learning
        </p>
    </div>
    """)
    
    # Create state variables to share images between tabs
    viz_state1 = gr.State()
    viz_state2 = gr.State()
    viz_state3 = gr.State()
    
    with gr.Tab("🔮 Prediction & Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                target_input = gr.Textbox(
                    label="Target RNA Sequence",
                    placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
                    lines=4,
                    max_lines=6
                )
                
                drug_input = gr.Textbox(
                    label="Drug SMILES",
                    placeholder="Enter SMILES notation (e.g., CC(C)CC1=CC=C(C=C1)C(C)C(=O)O)",
                    lines=2
                )
                
                with gr.Row():
                    predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
                    visualize_btn = gr.Button("📊 Generate Visualizations", variant="secondary", size="lg")
            
            with gr.Column(scale=1):
                prediction_output = gr.Textbox(
                    label="Prediction Result",
                    interactive=False,
                    lines=4
                )
        
        # Example inputs
        gr.HTML("<h3 style='margin-top: 20px; color: #2E86AB;'>📚 Example Inputs:</h3>")
        
        examples = gr.Examples(
            examples=[
                [
                    "AUGCUAGCUAGUACGUAUAUCUGCACUGC",
                    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
                ],
                [
                    "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
                    "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
                ]
            ],
            inputs=[target_input, drug_input],
            outputs=prediction_output,
            fn=predict_wrapper,
            cache_examples=False
        )
        
        # Button click events
        predict_btn.click(
            fn=predict_wrapper,
            inputs=[target_input, drug_input],
            outputs=prediction_output
        )
        
        def visualize_and_update(target_seq, drug_smiles):
            """Generate visualizations and update both status and state"""
            img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles)
            # Combine prediction result with visualization status
            combined_status = status + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
            return img1, img2, img3, combined_status
        
        visualize_btn.click(
            fn=visualize_and_update,
            inputs=[target_input, drug_input],
            outputs=[viz_state1, viz_state2, viz_state3, prediction_output]
        )
    
    with gr.Tab("📊 Visualizations"):
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
            <h2 style="color: #2E86AB;">🔬 Interaction Analysis & Visualizations</h2>
            <p style="font-size: 1.1em; color: #666;">
                Generated visualizations will appear here after clicking "Generate Visualizations" in the Prediction tab
            </p>
        </div>
        """)
        
        # Visualization outputs - Large and vertically aligned
        viz_image1 = gr.Image(
            label="Cross-Attention Heatmap",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)")
        )
        
        viz_image2 = gr.Image(
            label="Raw pKd Contribution Visualization",
            type="pil", 
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Raw pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )
        
        viz_image3 = gr.Image(
            label="Normalized pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Normalized pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )
        
        # Update visualization images when state changes
        viz_state1.change(
            fn=lambda x: x,
            inputs=viz_state1,
            outputs=viz_image1
        )
        
        viz_state2.change(
            fn=lambda x: x,
            inputs=viz_state2,
            outputs=viz_image2
        )
        
        viz_state3.change(
            fn=lambda x: x,
            inputs=viz_state3,
            outputs=viz_image3
        )
    
    with gr.Tab("⚙️ Model Settings"):
        gr.HTML("<h3 style='color: #2E86AB;'>Model Configuration</h3>")
        
        model_path_input = gr.Textbox(
            label="Model Path",
            value="./",
            placeholder="Path to model directory"
        )
        
        load_model_btn = gr.Button("📥 Load Model", variant="secondary")
        model_status = gr.Textbox(
            label="Status",
            interactive=False,
            value="No model loaded"
        )
        
        load_model_btn.click(
            fn=load_model_wrapper,
            inputs=model_path_input,
            outputs=model_status
        )
    
    with gr.Tab("📊 Dataset"):
        gr.Markdown("""
        ## Training and Test Datasets
        
        ### Fine-tuning Dataset (Training)
        
        The model was trained on a dataset comprising **1,439 RNA–drug interaction pairs**, including:
        - **759 unique compounds** (SMILES representations)
        - **294 unique RNA sequences**
        - Dissociation constants (pKd values) for binding affinity prediction
        
        **RNA Sequence Distribution by Type:**
        
        | RNA Sequence Type | Number of Interactions |
        |-------------------|------------------------|
        | Aptamers          | 520                    |
        | Ribosomal         | 295                    |
        | Viral RNAs        | 281                    |
        | miRNAs            | 146                    |
        | Riboswitches      | 100                    |
        | Repeats           | 97                     |
        | **Total**         | **1,439**              |
        
        ### External Evaluation Dataset (Test)
        
        Model validation was performed using external ROBIN classification datasets containing **5,534 RNA–drug pairs**:
        - **2,991 positive interactions**
        - **2,538 negative interactions**
        
        **Test Dataset Composition:**
        - **1,617 aptamer pairs** (5 unique RNA sequences)
        - **1,828 viral RNA pairs** (6 unique RNA sequences)  
        - **1,459 riboswitch pairs** (5 unique RNA sequences)
        - **630 miRNA pairs** (3 unique RNA sequences)
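
        (These four subsets account for all pairs: 1,617 + 1,828 + 1,459 + 630 = 5,534.)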
        
        ### Dataset Downloads

        - [Training Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv?download=true)
        - [Test Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/test_data.csv?download=true)

        ### Citation
        
        Original datasets were published by **Krishnan et al.** and are available on the RSAPred website in PDF format.

        *Reference:*

        ```bibtex
        @article{krishnan2024reliable,
          title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
          author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
          journal={Briefings in Bioinformatics},
          volume={25},
          number={2},
          pages={bbae002},
          year={2024},
          publisher={Oxford University Press}
        }
        ```
        """)
    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About this application
        
        This application implements DLRNA-BERTa, a Dual-Language RoBERTa Transformer model for predicting drug–RNA target interactions. The model architecture includes the following components (a simplified sketch follows the list):
        
        - **Target encoder**: Processes RNA sequences using RNA-BERTa
        - **Drug encoder**: Processes molecular SMILES notation using ChemBERTa
        - **Cross-attention mechanism**: Captures interactions between drugs and targets
        - **Regression head**: Predicts binding affinity scores (pKd values)
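
        As an illustrative sketch (toy dimensions; not the exact implementation in `modeling_dlmberta`), the cross-attention step lets each target token attend over the drug tokens before pooling into a single affinity score:

        ```python
        import torch
        import torch.nn as nn

        hidden = 64  # toy size; the real encoders output RoBERTa-sized hidden states
        cross_attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=4, batch_first=True)
        head = nn.Linear(hidden, 1)

        target_h = torch.randn(1, 30, hidden)  # stand-in for RNA-BERTa output
        drug_h = torch.randn(1, 20, hidden)    # stand-in for ChemBERTa output

        # Target tokens (queries) attend over drug tokens (keys/values)
        attended, weights = cross_attn(target_h, drug_h, drug_h)
        pkd = head(attended.mean(dim=1))  # pooled regression to a pKd-like score
        ```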
        
        ### Input requirements:
        - **Target sequence**: RNA sequence of the target (nucleotide sequences: A, U, G, C)
        - **Drug SMILES**: Simplified Molecular Input Line Entry System notation
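
        For example (these match the built-in examples on the Prediction tab):

        ```
        Target RNA:  AUGCUAGCUAGUACGUAUAUCUGCACUGC
        Drug SMILES: CC(C)CC1=CC=C(C=C1)C(C)C(=O)O
        ```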
        
        ### Model features:
        - Cross-attention for drug-target interaction modeling
        - Dropout for regularization
        - Layer normalization for stable training
        - Interpretability mode for contribution and attention visualization
        
        ### Usage tips:
        1. Load a trained model using the Model Settings tab (optional)
        2. Enter an RNA sequence and drug SMILES in the Prediction & Analysis tab
        3. Click "Predict Interaction" for binding affinity prediction only
        4. Click "Generate Visualizations" to create detailed interaction analysis - results will appear in the Visualizations tab
        
        For best results, ensure your inputs are properly formatted; both tokenizers truncate at 512 tokens.
        
        ### Visualization features:
        - **Cross-attention heatmap**: Shows cross-attention weights between drug and target tokens
        - **Unnormalized pKd contribution**: Shows unnormalized signed contributions from each target token (only when pKd > 0)
        - **Normalized pKd contribution**: Shows normalized non-negative contributions from each target token
        
        ### Performance metrics:
        - Training on diverse drug-target interaction datasets
        - Evaluated using RMSE, Pearson correlation, and Concordance Index
        - Optimized for both predictive accuracy and interpretability

        ### GitHub repository:
        - The full model GitHub repository can be found here: https://github.com/IlPakoZ/dlrnaberta-dti-prediction
        
        ### Contribution:
        - Special thanks to Umut Onur Özcan for help in developing this space. :)

        ### Contact:
        - Ziaurrehman Tanoli ([email protected]),
          Principal Investigator, Institute for Molecular Medicine Finland,
          HiLIFE, University of Helsinki, Finland.
            
        """)

# Launch the app
if __name__ == "__main__":
    # Try to load model on startup
    if os.path.exists("./config.json"):
        app.load_model("./")
    
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )