Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel | |
| from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler | |
| from configuration_dlmberta import InteractionModelATTNConfig | |
| from chemberta import ChembertaTokenizer | |
| import json | |
| import os | |
| from pathlib import Path | |
| import logging | |
| # Import visualization functions | |
| from analysis import plot_crossattention_weights, plot_presum | |
| from PIL import Image, ImageDraw, ImageFont | |
# Root logging configured at INFO level (a no-op if the root logger already
# has handlers); records from this module go through the module logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
def create_placeholder_image(width=600, height=400, text="No visualization available", bg_color=(0, 0, 0, 0)):
    """
    Create an RGBA placeholder image with grey text centred on it.

    Multi-line ``text`` (lines separated by a newline) is centred as a block
    and each line is centred within that block.

    Args:
        width (int): Image width in pixels
        height (int): Image height in pixels
        text (str): Text to display; may contain newlines
        bg_color (tuple): Background color (R, G, B, A) - (0, 0, 0, 0) for transparent

    Returns:
        PIL.Image.Image: The placeholder image
    """
    img = Image.new('RGBA', (width, height), bg_color)
    draw = ImageDraw.Draw(img)

    # Prefer Arial; fall back to Pillow's built-in font, and finally to None
    # (text size is then estimated below). truetype() raises OSError when the
    # font file cannot be found and ImportError when FreeType is unavailable.
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        try:
            font = ImageFont.load_default()
        except Exception:  # best effort: still draw text without a font object
            font = None

    if font is not None:
        left, top, right, bottom = draw.textbbox((0, 0), text, font=font, align="center")
        text_width = right - left
        text_height = bottom - top
    else:
        # Rough estimate: ~8 px per character, 16 px per line.
        lines = text.split("\n")
        left = top = 0
        text_width = max(len(line) for line in lines) * 8
        text_height = len(lines) * 16

    # Subtract the bbox origin so glyph bearings don't shift the centring.
    x = (width - text_width) // 2 - left
    y = (height - text_height) // 2 - top
    draw.text((x, y), text, fill=(128, 128, 128, 255), font=font, align="center")
    return img
class DrugTargetInteractionApp:
    """Holds the DLRNA-BERTa model, its tokenizers and the inference/visualization helpers used by the UI."""

    def __init__(self):
        """Create an app with nothing loaded; call ``load_model`` before predicting."""
        self.model = None
        self.target_tokenizer = None
        self.drug_tokenizer = None
        # StdScaler loaded together with the model (None when the model has none)
        self.scaler = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_path="./"):
        """Load the pre-trained model and tokenizers.

        All components are built into local variables and only attached to
        ``self`` once everything loaded, so a failed (re)load never leaves a
        model without matching tokenizers (a previously loaded set is kept).

        Args:
            model_path (str): Directory with the model config/weights, an
                optional ``scaler.config`` and the ``target_tokenizer`` and
                ``drug_tokenizer`` sub-directories.

        Returns:
            bool: True on success, False if anything failed (error is logged
            with its traceback).
        """
        try:
            config = InteractionModelATTNConfig.from_pretrained(model_path)

            # Drug encoder (ChemBERTa): architecture built from its config,
            # without a pooling layer.
            drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
            drug_encoder_config.pooler = None
            drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

            # Target encoder (RNA-BERTa)
            target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")

            # Optional output scaler
            scaler = None
            if os.path.exists(os.path.join(model_path, "scaler.config")):
                scaler = StdScaler()
                scaler.load(model_path)

            model = InteractionModelATTNForRegression.from_pretrained(
                model_path,
                config=config,
                target_encoder=target_encoder,
                drug_encoder=drug_encoder,
                scaler=scaler,
            )
            model.to(self.device)
            model.eval()

            target_tokenizer = AutoTokenizer.from_pretrained(
                os.path.join(model_path, "target_tokenizer")
            )
            drug_tokenizer = ChembertaTokenizer(
                os.path.join(model_path, "drug_tokenizer", "vocab.json")
            )
        except Exception as e:
            logger.exception(f"Error loading model: {str(e)}")
            return False

        self.model = model
        self.scaler = scaler
        self.target_tokenizer = target_tokenizer
        self.drug_tokenizer = drug_tokenizer
        logger.info("Model and tokenizers loaded successfully!")
        return True

    def _tokenize(self, target_sequence, drug_smiles, max_length=512):
        """Tokenize one RNA sequence and one SMILES string, padded/truncated to ``max_length`` and moved to ``self.device``."""
        target_inputs = self.target_tokenizer(
            target_sequence,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)
        drug_inputs = self.drug_tokenizer(
            drug_smiles,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(self.device)
        return target_inputs, drug_inputs

    def _forward(self, target_inputs, drug_inputs):
        """Run the model without gradients, unscale if a scaler is attached, and return element [0][0] as a NumPy scalar."""
        with torch.no_grad():
            prediction = self.model(target_inputs, drug_inputs)
        if self.model.scaler is not None:
            prediction = self.model.unscale(prediction)
        return prediction.cpu().numpy()[0][0]

    def predict_interaction(self, target_sequence, drug_smiles, max_length=512):
        """Predict drug-target binding affinity.

        Args:
            target_sequence (str): RNA sequence of the target.
            drug_smiles (str): Drug SMILES string.
            max_length (int): Padding/truncation length used for both tokenizers.

        Returns:
            str: Formatted prediction, or an error message.
        """
        if self.model is None:
            return "Error: Model not loaded. Please load a model first."
        try:
            target_inputs, drug_inputs = self._tokenize(target_sequence, drug_smiles, max_length)
            # Plain prediction: make sure interpretation mode is off.
            self.model.INTERPR_DISABLE_MODE()
            prediction_value = self._forward(target_inputs, drug_inputs)
            return f"Predicted Binding Affinity: {prediction_value:.4f}"
        except Exception as e:
            logger.error(f"Prediction error: {str(e)}")
            return f"Error during prediction: {str(e)}"

    def _render_cross_attention(self, cross_attention_weights, target_inputs, drug_inputs):
        """Plot the [0, 0] slice of the cross-attention weights.

        With the (batch, heads, seq_len, seq_len) layout noted in
        ``visualize_interaction`` this is the first batch item, first head.

        Returns:
            The image from ``plot_crossattention_weights``, or None when the
            weights are missing, cannot be indexed, or plotting fails.
        """
        logger.info(f"Cross-attention weights type: {type(cross_attention_weights)}")
        if cross_attention_weights is None:
            logger.warning("Cross-attention weights are None")
            return None
        logger.info(f"Cross-attention weights shape: {cross_attention_weights.shape if hasattr(cross_attention_weights, 'shape') else 'No shape attr'}")
        try:
            cross_attn_matrix = cross_attention_weights[0, 0]
        except (IndexError, TypeError, AttributeError) as e:
            logger.warning(f"Error extracting cross-attention matrix: {str(e)}")
            return None
        if cross_attn_matrix is None:
            logger.warning("Could not extract valid cross-attention matrix")
            return None
        try:
            logger.info(f"Extracted cross-attention matrix shape: {cross_attn_matrix.shape}")
            logger.info(f"Target attention mask shape: {target_inputs['attention_mask'].shape}")
            logger.info(f"Drug attention mask shape: {drug_inputs['attention_mask'].shape}")
            return plot_crossattention_weights(
                target_inputs["attention_mask"][0],
                drug_inputs["attention_mask"][0],
                target_inputs,
                drug_inputs,
                cross_attn_matrix,
                self.target_tokenizer,
                self.drug_tokenizer,
            )
        except Exception as e:
            logger.error(f"Cross-attention visualization error: {str(e)}")
            return None

    def _render_contribution(self, target_inputs, presum_values, scaler, w, b, raw_affinities):
        """Plot per-target-token contributions with ``plot_presum``; return the image, or None if plotting fails."""
        try:
            return plot_presum(
                target_inputs,
                presum_values.detach(),
                scaler,
                w.detach(),
                b.detach(),
                self.target_tokenizer,
                raw_affinities=raw_affinities,
            )
        except Exception as e:
            kind = "Raw" if raw_affinities else "Normalized"
            logger.error(f"{kind} contribution visualization error: {str(e)}")
            return None

    def visualize_interaction(self, target_sequence, drug_smiles, max_length=512):
        """
        Generate visualization images for drug-target interaction

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str): Drug SMILES notation
            max_length (int): Padding/truncation length used for both tokenizers

        Returns:
            tuple: (cross_attention_image, raw_contribution_image,
            normalized_contribution_image, status_message). On success each
            image slot holds an image (a placeholder when generation failed or
            was skipped); on error all three images are None.
        """
        if self.model is None:
            return None, None, None, "Error: Model not loaded. Please load a model first."
        try:
            target_inputs, drug_inputs = self._tokenize(target_sequence, drug_smiles, max_length)

            # Interpretation mode is enabled before the forward pass; the
            # attributes read below from self.model.model are presumably only
            # populated in this mode (assumption from usage - TODO confirm).
            # The finally-clause always switches it back off.
            self.model.INTERPR_ENABLE_MODE()
            prediction_value = self._forward(target_inputs, drug_inputs)

            inner = self.model.model
            presum_values = inner.presum_layer  # Shape: (1, seq_len)
            cross_attention_weights = inner.crossattention_weights  # Shape: (batch, heads, seq_len, seq_len)
            # Model parameters for scaling the contributions
            w = inner.w.squeeze(1)
            b = inner.b
            scaler = inner.scaler
            logger.info(f"Target inputs shape: {target_inputs['input_ids'].shape}")
            logger.info(f"Drug inputs shape: {drug_inputs['input_ids'].shape}")

            # 1. Cross-attention heatmap
            cross_attention_img = self._render_cross_attention(cross_attention_weights, target_inputs, drug_inputs)

            # 2./3. Contribution plots: normalized for any pKd, raw only when pKd > 0
            normalized_img = None
            raw_img = None
            if presum_values is None:
                logger.warning("Cannot generate contribution visualizations: presum values are None")
            else:
                normalized_img = self._render_contribution(
                    target_inputs, presum_values, scaler, w, b, raw_affinities=False
                )
                if prediction_value > 0:
                    raw_img = self._render_contribution(
                        target_inputs, presum_values, scaler, w, b, raw_affinities=True
                    )
                else:
                    logger.info("Skipping raw contribution visualization as pKd <= 0")

            # Placeholders for anything that failed or was skipped
            if cross_attention_img is None:
                cross_attention_img = create_placeholder_image(
                    text="Cross-Attention Heatmap\nFailed to generate"
                )
            if normalized_img is None:
                normalized_img = create_placeholder_image(
                    text="Normalized Contribution\nFailed to generate"
                )
            if raw_img is None:
                if prediction_value > 0:
                    raw_img = create_placeholder_image(
                        text="Raw Contribution\nFailed to generate"
                    )
                else:
                    raw_img = create_placeholder_image(
                        text="Raw Contribution\nSkipped (pKd ≤ 0)"
                    )

            status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
            if prediction_value <= 0:
                status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
            if cross_attention_weights is None:
                status_msg += " (Cross-attention visualization failed: weights not available)"
            return cross_attention_img, raw_img, normalized_img, status_msg
        except Exception as e:
            logger.error(f"Visualization error: {str(e)}")
            return None, None, None, f"Error during visualization: {str(e)}"
        finally:
            # Always leave the shared model in plain prediction mode.
            try:
                self.model.INTERPR_DISABLE_MODE()
            except Exception as e:
                logger.warning(f"Could not disable interpretation mode: {str(e)}")
# Single module-level app instance shared by the Gradio callback wrappers;
# its model and tokenizers stay None until load_model() succeeds.
app = DrugTargetInteractionApp()
def predict_wrapper(target_seq, drug_smiles):
    """Gradio callback for the "Predict Interaction" button.

    Args:
        target_seq (str | None): RNA sequence entered by the user.
        drug_smiles (str | None): Drug SMILES entered by the user.

    Returns:
        str: The prediction text, or a prompt when either input is missing/blank.
    """
    # Treat None like an empty string so a missing value yields the prompt
    # instead of an AttributeError on .strip().
    if not (target_seq or "").strip() or not (drug_smiles or "").strip():
        return "Please provide both target sequence and drug SMILES."
    return app.predict_interaction(target_seq, drug_smiles)
def visualize_wrapper(target_seq, drug_smiles):
    """Gradio callback that produces the three visualization images plus a status message.

    Args:
        target_seq (str | None): RNA sequence entered by the user.
        drug_smiles (str | None): Drug SMILES entered by the user.

    Returns:
        tuple: (cross_attention_image, raw_contribution_image,
        normalized_contribution_image, status_message); the images are None
        when either input is missing/blank.
    """
    # Treat None like an empty string so a missing value yields the prompt
    # instead of an AttributeError on .strip().
    if not (target_seq or "").strip() or not (drug_smiles or "").strip():
        return None, None, None, "Please provide both target sequence and drug SMILES."
    return app.visualize_interaction(target_seq, drug_smiles)
def load_model_wrapper(model_path):
    """Gradio callback for the "Load Model" button.

    Args:
        model_path (str | None): Model directory typed by the user; a missing
            or blank value falls back to the default "./".

    Returns:
        str: Human-readable load status.
    """
    if model_path is None or not model_path.strip():
        model_path = "./"
    if app.load_model(model_path):
        return "Model loaded successfully!"
    return "Failed to load model. Check the path and files."
# Create Gradio interface
with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
    gr.HTML("""
    <div style="text-align: center; margin-bottom: 30px;">
        <h1 style="color: #2E86AB; font-size: 2.5em; margin-bottom: 10px;">
            🧬 Drug-Target Interaction Predictor
        </h1>
        <p style="font-size: 1.2em; color: #666;">
            Predict binding affinity between drugs and target RNA sequences using deep learning
        </p>
    </div>
    """)

    # State holders that carry the generated images from the Prediction tab
    # to the Visualizations tab (whose images are defined further down).
    viz_state1 = gr.State()
    viz_state2 = gr.State()
    viz_state3 = gr.State()

    with gr.Tab("🔮 Prediction & Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                target_input = gr.Textbox(
                    label="Target RNA Sequence",
                    placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
                    lines=4,
                    max_lines=6
                )
                drug_input = gr.Textbox(
                    label="Drug SMILES",
                    placeholder="Enter SMILES notation (e.g., CC(C)CC1=CC=C(C=C1)C(C)C(=O)O)",
                    lines=2
                )
                with gr.Row():
                    predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
                    visualize_btn = gr.Button("📊 Generate Visualizations", variant="secondary", size="lg")
            with gr.Column(scale=1):
                prediction_output = gr.Textbox(
                    label="Prediction Result",
                    interactive=False,
                    lines=4
                )

        # Example inputs
        gr.HTML("<h3 style='margin-top: 20px; color: #2E86AB;'>📚 Example Inputs:</h3>")
        examples = gr.Examples(
            examples=[
                [
                    "AUGCUAGCUAGUACGUAUAUCUGCACUGC",
                    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
                ],
                [
                    "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
                    "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
                ]
            ],
            inputs=[target_input, drug_input],
            outputs=prediction_output,
            fn=predict_wrapper,
            cache_examples=False
        )

        # Button click events
        predict_btn.click(
            fn=predict_wrapper,
            inputs=[target_input, drug_input],
            outputs=prediction_output
        )

        def visualize_and_update(target_seq, drug_smiles):
            """Generate visualizations and update both status and state.

            The "navigate to the Visualizations tab" hint is only appended when
            images were actually produced; on validation or visualization
            errors the status message is shown unchanged.
            """
            img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles)
            if img1 is None and img2 is None and img3 is None:
                return img1, img2, img3, status
            # Combine prediction result with visualization status
            combined_status = status + "\n\nVisualization analysis complete. Please navigate to the Visualizations tab to view the generated images."
            return img1, img2, img3, combined_status

        visualize_btn.click(
            fn=visualize_and_update,
            inputs=[target_input, drug_input],
            outputs=[viz_state1, viz_state2, viz_state3, prediction_output]
        )

    with gr.Tab("📊 Visualizations"):
        gr.HTML("""
        <div style="text-align: center; margin-bottom: 20px;">
            <h2 style="color: #2E86AB;">🔬 Interaction Analysis & Visualizations</h2>
            <p style="font-size: 1.1em; color: #666;">
                Generated visualizations will appear here after clicking "Generate Visualizations" in the Prediction tab
            </p>
        </div>
        """)

        # Visualization outputs - Large and vertically aligned
        viz_image1 = gr.Image(
            label="Cross-Attention Heatmap",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)")
        )
        viz_image2 = gr.Image(
            label="Raw pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Raw pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )
        viz_image3 = gr.Image(
            label="Normalized pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Normalized pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )

        # Copy each state value into its image whenever the state changes
        viz_state1.change(
            fn=lambda x: x,
            inputs=viz_state1,
            outputs=viz_image1
        )
        viz_state2.change(
            fn=lambda x: x,
            inputs=viz_state2,
            outputs=viz_image2
        )
        viz_state3.change(
            fn=lambda x: x,
            inputs=viz_state3,
            outputs=viz_image3
        )

    with gr.Tab("⚙️ Model Settings"):
        gr.HTML("<h3 style='color: #2E86AB;'>Model Configuration</h3>")
        model_path_input = gr.Textbox(
            label="Model Path",
            value="./",
            placeholder="Path to model directory"
        )
        load_model_btn = gr.Button("📥 Load Model", variant="secondary")
        model_status = gr.Textbox(
            label="Status",
            interactive=False,
            value="No model loaded"
        )
        load_model_btn.click(
            fn=load_model_wrapper,
            inputs=model_path_input,
            outputs=model_status
        )

    with gr.Tab("📊 Dataset"):
        # NOTE(review): 2,991 positive + 2,538 negative = 5,529, not the stated
        # 5,534 (which does match the per-type composition sum) - confirm figures.
        gr.Markdown("""
## Training and Test Datasets

### Fine-tuning Dataset (Training)

The model was trained on a dataset comprising **1,439 RNA–drug interaction pairs**, including:

- **759 unique compounds** (SMILES representations)
- **294 unique RNA sequences**
- Dissociation constants (pKd values) for binding affinity prediction

**RNA Sequence Distribution by Type:**

| RNA Sequence Type | Number of Interactions |
|-------------------|------------------------|
| Aptamers | 520 |
| Ribosomal | 295 |
| Viral RNAs | 281 |
| miRNAs | 146 |
| Riboswitches | 100 |
| Repeats | 97 |
| **Total** | **1,439** |

### External Evaluation Dataset (Test)

Model validation was performed using external ROBIN classification datasets containing **5,534 RNA–drug pairs**:

- **2,991 positive interactions**
- **2,538 negative interactions**

**Test Dataset Composition:**

- **1,617 aptamer pairs** (5 unique RNA sequences)
- **1,828 viral RNA pairs** (6 unique RNA sequences)
- **1,459 riboswitch pairs** (5 unique RNA sequences)
- **630 miRNA pairs** (3 unique RNA sequences)

### Dataset Downloads

- [Training Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv?download=true)
- [Test Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/test_data.csv?download=true)

### Citation

Original datasets published by:

**Krishnan et al.** - Available on the RSAPred website in PDF format.

*Reference:*

```bibtex
@article{krishnan2024reliable,
  title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
  author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
  journal={Briefings in Bioinformatics},
  volume={25},
  number={2},
  pages={bbae002},
  year={2024},
  publisher={Oxford University Press}
}
```
""")

    with gr.Tab("ℹ️ About"):
        # NOTE(review): "[email protected]" below looks like an email-obfuscation
        # placeholder rather than a real address - restore the contact email.
        gr.Markdown("""
## About this application

This application implements DLRNA-BERTa, a Dual Language RoBERTa Transformer model for predicting drug to RNA target interactions. The model architecture includes:

- **Target encoder**: Processes RNA sequences using RNA-BERTa
- **Drug encoder**: Processes molecular SMILES notation using ChemBERTa
- **Cross-attention mechanism**: Captures interactions between drugs and targets
- **Regression head**: Predicts binding affinity scores (pKd values)

### Input requirements:

- **Target sequence**: RNA sequence of the target (nucleotide sequences: A, U, G, C)
- **Drug SMILES**: Simplified Molecular Input Line Entry System notation

### Model features:

- Cross-attention for drug-target interaction modeling
- Dropout for regularization
- Layer normalization for stable training
- Interpretability mode for contribution and attention visualization

### Usage tips:

1. Load a trained model using the Model Settings tab (optional)
2. Enter an RNA sequence and drug SMILES in the Prediction & Analysis tab
3. Click "Predict Interaction" for binding affinity prediction only
4. Click "Generate Visualizations" to create detailed interaction analysis - results will appear in the Visualizations tab

For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens).

### Visualization features:

- **Cross-attention heatmap**: Shows cross-attention weights between drug and target tokens
- **Unnormalized pKd contribution**: Shows unnormalized signed contributions from each target token (only when pKd > 0)
- **Normalized pKd contribution**: Shows normalized non-negative contributions from each target token

### Performance metrics:

- Training on diverse drug-target interaction datasets
- Evaluated using RMSE, Pearson correlation, and Concordance Index
- Optimized for both predictive accuracy and interpretability

### GitHub repository:

- The full model GitHub repository can be found here: https://github.com/IlPakoZ/dlrnaberta-dti-prediction

### Contribution:

- Special thanks to Umut Onur Özcan for help in developing this space:)

### Contact:

- Ziaurrehman Tanoli ([email protected])
  Principal investigator at Institute for Molecular Medicine Finland
  HiLIFE, University of Helsinki, Finland.
""")
# Script entry point: auto-load a model found in the working directory, then serve the UI.
if __name__ == "__main__":
    startup_model_dir = "./"
    # Only attempt an eager load when a model config is present; otherwise the
    # user can load one later from the Model Settings tab.
    if os.path.exists(os.path.join(startup_model_dir, "config.json")):
        app.load_model(startup_model_dir)

    launch_options = dict(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
    demo.launch(**launch_options)