"""Gradio front end for DLRNA-BERTa drug-RNA binding-affinity prediction.

The module builds a ``gr.Blocks`` UI (``demo``) around a single
``DrugTargetInteractionApp`` instance (``app``) that loads the model and
tokenizers, predicts a binding affinity, and renders interpretation plots.
"""

import json
import logging
import os
from pathlib import Path

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel

# Import visualization functions
from analysis import plot_crossattention_weights, plot_presum
from chemberta import ChembertaTokenizer
from configuration_dlmberta import InteractionModelATTNConfig
from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Static UI text
# ---------------------------------------------------------------------------
# NOTE(review): the HTML snippets below contain text only; any markup they may
# once have carried is not present, so the content is kept verbatim.
_HEADER_HTML = """

🧬 Drug-Target Interaction Predictor

Predict binding affinity between drugs and target RNA sequences using deep learning

"""

_EXAMPLES_HTML = "\n\n📚 Example Inputs:\n\n"

_VISUALIZATIONS_HTML = """

🔬 Interaction Analysis & Visualizations

Generated visualizations will appear here after clicking "Generate Visualizations" in the Prediction tab

"""

_MODEL_SETTINGS_HTML = "\n\nModel Configuration\n\n"

# Markdown needs headings, list items, table rows and code fences to start on
# their own lines, otherwise they render as one literal paragraph.
_DATASET_MARKDOWN = """
## Training and Test Datasets

### Fine-tuning Dataset (Training)

The model was trained on a dataset comprising **1,439 RNA–drug interaction pairs**, including:

- **759 unique compounds** (SMILES representations)
- **294 unique RNA sequences**
- Dissociation constants (pKd values) for binding affinity prediction

**RNA Sequence Distribution by Type:**

| RNA Sequence Type | Number of Interactions |
|-------------------|------------------------|
| Aptamers | 520 |
| Ribosomal | 295 |
| Viral RNAs | 281 |
| miRNAs | 146 |
| Riboswitches | 100 |
| Repeats | 97 |
| **Total** | **1,439** |

### External Evaluation Dataset (Test)

Model validation was performed using external ROBIN classification datasets containing **5,534 RNA–drug pairs**:

- **2,991 positive interactions**
- **2,538 negative interactions**

**Test Dataset Composition:**

- **1,617 aptamer pairs** (5 unique RNA sequences)
- **1,828 viral RNA pairs** (6 unique RNA sequences)
- **1,459 riboswitch pairs** (5 unique RNA sequences)
- **630 miRNA pairs** (3 unique RNA sequences)

### Dataset Downloads

- [Training Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/training_data.csv?download=true)
- [Test Dataset Download](https://huggingface.co/spaces/IlPakoZ/DLRNA-BERTa/resolve/main/datasets/test_data.csv?download=true)

### Citation

Original datasets published by: **Krishnan et al.** - Available on the RSAPred website in PDF format.

*Reference:*

```bibtex
@article{krishnan2024reliable,
  title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
  author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
  journal={Briefings in Bioinformatics},
  volume={25},
  number={2},
  pages={bbae002},
  year={2024},
  publisher={Oxford University Press}
}
```
"""

_ABOUT_MARKDOWN = """
## About this application

This application implements DLRNA-BERTa, a Dual Language RoBERTa Transformer model for predicting drug to RNA target interactions. The model architecture includes:

- **Target encoder**: Processes RNA sequences using RNA-BERTa
- **Drug encoder**: Processes molecular SMILES notation using ChemBERTa
- **Cross-attention mechanism**: Captures interactions between drugs and targets
- **Regression head**: Predicts binding affinity scores (pKd values)

### Input requirements:

- **Target sequence**: RNA sequence of the target (nucleotide sequences: A, U, G, C)
- **Drug SMILES**: Simplified Molecular Input Line Entry System notation

### Model features:

- Cross-attention for drug-target interaction modeling
- Dropout for regularization
- Layer normalization for stable training
- Interpretability mode for contribution and attention visualization

### Usage tips:

1. Load a trained model using the Model Settings tab (optional)
2. Enter an RNA sequence and drug SMILES in the Prediction & Analysis tab
3. Click "Predict Interaction" for binding affinity prediction only
4. Click "Generate Visualizations" to create detailed interaction analysis - results will appear in the Visualizations tab

For best results, ensure your input sequences are properly formatted and within reasonable length limits (max 512 tokens).

### Visualization features:

- **Cross-attention heatmap**: Shows cross-attention weights between drug and target tokens
- **Unnormalized pKd contribution**: Shows unnormalized signed contributions from each target token (only when pKd > 0)
- **Normalized pKd contribution**: Shows normalized non-negative contributions from each target token

### Performance metrics:

- Training on diverse drug-target interaction datasets
- Evaluated using RMSE, Pearson correlation, and Concordance Index
- Optimized for both predictive accuracy and interpretability

### GitHub repository:

- The full model GitHub repository can be found here: https://github.com/IlPakoZ/dlrnaberta-dti-prediction

### Contribution:

- Special thanks to Umut Onur Özcan for help in developing this space:)

### Contact:

- Ziaurrehman Tanoli (ziaurrehman.tanoli@helsinki.fi) Principal investigator at Institute for Molecular Medicine Finland HiLIFE, University of Helsinki, Finland.
"""


def create_placeholder_image(width=600, height=400, text="No visualization available",
                             bg_color=(0, 0, 0, 0)):
    """
    Create a placeholder image with centered gray text.

    Args:
        width (int): Image width
        height (int): Image height
        text (str): Text to display (may contain newlines)
        bg_color (tuple): Background color (R, G, B, A) - (0,0,0,0) for transparent

    Returns:
        PIL.Image: RGBA placeholder image
    """
    img = Image.new('RGBA', (width, height), bg_color)
    draw = ImageDraw.Draw(img)

    # Prefer Arial; fall back to Pillow's built-in font, then to no explicit font.
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except (OSError, ImportError):
        try:
            font = ImageFont.load_default()
        except Exception:
            font = None

    if font is not None:
        bbox = draw.textbbox((0, 0), text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
    else:
        # Rough estimate (8x16 px per character), measured per line so that
        # multi-line text is not treated as one long row.
        rows = text.split("\n")
        text_width = max(len(row) for row in rows) * 8
        text_height = 16 * len(rows)

    x = (width - text_width) // 2
    y = (height - text_height) // 2

    # Draw text in gray
    draw.text((x, y), text, fill=(128, 128, 128, 255), font=font)
    return img


class DrugTargetInteractionApp:
    """Holds the loaded model/tokenizers and exposes prediction and visualization."""

    def __init__(self):
        self.model = None
        self.target_tokenizer = None
        self.drug_tokenizer = None
        self.scaler = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_path="./"):
        """Load the pre-trained model and tokenizers.

        Everything is loaded into locals first and only published on ``self``
        once every step succeeded, so a failed load never leaves the app with
        a new model paired with stale (or missing) tokenizers.

        Args:
            model_path (str): Directory containing the model config/weights,
                optional ``scaler.config``, ``target_tokenizer/`` and
                ``drug_tokenizer/vocab.json``.

        Returns:
            bool: True on success, False if any step raised.
        """
        try:
            # Load configuration
            config = InteractionModelATTNConfig.from_pretrained(model_path)

            # Load drug encoder (ChemBERTa)
            drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
            drug_encoder_config.pooler = None
            drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

            # Load target encoder
            target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")

            # Load scaler if exists
            scaler = None
            if os.path.exists(os.path.join(model_path, "scaler.config")):
                scaler = StdScaler()
                scaler.load(model_path)

            model = InteractionModelATTNForRegression.from_pretrained(
                model_path,
                config=config,
                target_encoder=target_encoder,
                drug_encoder=drug_encoder,
                scaler=scaler
            )
            model.to(self.device)
            model.eval()

            # Load tokenizers
            target_tokenizer = AutoTokenizer.from_pretrained(
                os.path.join(model_path, "target_tokenizer")
            )
            # Load drug tokenizer (ChemBERTa)
            vocab_file = os.path.join(model_path, "drug_tokenizer", "vocab.json")
            drug_tokenizer = ChembertaTokenizer(vocab_file)
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            return False

        self.model = model
        self.target_tokenizer = target_tokenizer
        self.drug_tokenizer = drug_tokenizer
        self.scaler = scaler
        logger.info("Model and tokenizers loaded successfully!")
        return True

    def _tokenize(self, target_sequence, drug_smiles, max_length):
        """Tokenize both inputs, padded/truncated to ``max_length``, on ``self.device``."""
        tokenizer_kwargs = dict(
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
        target_inputs = self.target_tokenizer(target_sequence, **tokenizer_kwargs).to(self.device)
        drug_inputs = self.drug_tokenizer(drug_smiles, **tokenizer_kwargs).to(self.device)
        return target_inputs, drug_inputs

    def _forward(self, target_inputs, drug_inputs):
        """Run the model without gradients and return the (unscaled) scalar prediction."""
        with torch.no_grad():
            prediction = self.model(target_inputs, drug_inputs)

        # Unscale if scaler exists
        if self.model.scaler is not None:
            prediction = self.model.unscale(prediction)
        return prediction.cpu().numpy()[0][0]

    def _disable_interpretation_mode(self):
        """Best-effort switch back to normal mode; never raises."""
        try:
            self.model.INTERPR_DISABLE_MODE()
        except Exception as e:
            logger.warning("Could not disable interpretation mode: %s", e)

    def predict_interaction(self, target_sequence, drug_smiles, max_length=512):
        """Predict drug-target interaction.

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str): Drug SMILES notation
            max_length (int): Padding/truncation length passed to both tokenizers

        Returns:
            str: Formatted prediction, or an error message.
        """
        if self.model is None:
            return "Error: Model not loaded. Please load a model first."

        try:
            target_inputs, drug_inputs = self._tokenize(target_sequence, drug_smiles, max_length)

            # Make prediction
            self.model.INTERPR_DISABLE_MODE()
            prediction_value = self._forward(target_inputs, drug_inputs)
            return f"Predicted Binding Affinity: {prediction_value:.4f}"
        except Exception as e:
            logger.exception("Prediction error: %s", e)
            return f"Error during prediction: {str(e)}"

    def _render_cross_attention(self, target_inputs, drug_inputs, cross_attention_weights):
        """Return the cross-attention heatmap, or None if it cannot be produced."""
        logger.info("Cross-attention weights type: %s", type(cross_attention_weights))
        if cross_attention_weights is None:
            logger.warning("Cross-attention weights are None")
            return None

        try:
            logger.info(
                "Cross-attention weights shape: %s",
                getattr(cross_attention_weights, "shape", "No shape attr"),
            )
            # First batch element, first head.
            cross_attn_matrix = cross_attention_weights[0, 0]
            if cross_attn_matrix is None:
                logger.warning("Could not extract valid cross-attention matrix")
                return None

            logger.info("Extracted cross-attention matrix shape: %s", cross_attn_matrix.shape)
            logger.info("Target attention mask shape: %s", target_inputs['attention_mask'].shape)
            logger.info("Drug attention mask shape: %s", drug_inputs['attention_mask'].shape)
            return plot_crossattention_weights(
                target_inputs["attention_mask"][0],
                drug_inputs["attention_mask"][0],
                target_inputs,
                drug_inputs,
                cross_attn_matrix,
                self.target_tokenizer,
                self.drug_tokenizer
            )
        except Exception as e:
            logger.exception("Cross-attention visualization error: %s", e)
            return None

    def _render_contribution(self, target_inputs, presum_values, scaler, w, b, *, raw_affinities):
        """Return a per-token contribution plot, or None if it cannot be produced."""
        label = "Raw" if raw_affinities else "Normalized"
        if presum_values is None:
            logger.warning(
                "Cannot generate %s contribution visualization: presum values are None",
                label.lower(),
            )
            return None

        try:
            return plot_presum(
                target_inputs,
                presum_values.detach(),
                scaler,
                w.detach(),
                b.detach(),
                self.target_tokenizer,
                raw_affinities=raw_affinities
            )
        except Exception as e:
            logger.exception("%s contribution visualization error: %s", label, e)
            return None

    def visualize_interaction(self, target_sequence, drug_smiles, max_length=512):
        """
        Generate visualization images for drug-target interaction

        Args:
            target_sequence (str): RNA sequence
            drug_smiles (str): Drug SMILES notation
            max_length (int): Padding/truncation length passed to both tokenizers

        Returns:
            tuple: (cross_attention_image, raw_contribution_image,
                    normalized_contribution_image, status_message).
                   On a fatal error the three images are None.
        """
        if self.model is None:
            return None, None, None, "Error: Model not loaded. Please load a model first."

        try:
            target_inputs, drug_inputs = self._tokenize(target_sequence, drug_smiles, max_length)

            # Enable interpretation mode so the forward pass records the
            # intermediate values read below.
            self.model.INTERPR_ENABLE_MODE()
            prediction_value = self._forward(target_inputs, drug_inputs)

            # Extract data needed for visualizations
            presum_values = self.model.model.presum_layer
            cross_attention_weights = self.model.model.crossattention_weights

            # Get model parameters for scaling
            w = self.model.model.w.squeeze(1)
            b = self.model.model.b
            scaler = self.model.model.scaler

            logger.info("Target inputs shape: %s", target_inputs['input_ids'].shape)
            logger.info("Drug inputs shape: %s", drug_inputs['input_ids'].shape)

            # 1. Cross-attention heatmap
            cross_attention_img = self._render_cross_attention(
                target_inputs, drug_inputs, cross_attention_weights
            )

            # 2. Normalized contribution visualization (always attempted)
            normalized_img = self._render_contribution(
                target_inputs, presum_values, scaler, w, b, raw_affinities=False
            )

            # 3. Raw contribution visualization (only if pKd > 0)
            raw_img = None
            if prediction_value > 0:
                raw_img = self._render_contribution(
                    target_inputs, presum_values, scaler, w, b, raw_affinities=True
                )
            else:
                logger.info("Skipping raw contribution visualization as pKd <= 0")

            # Create placeholder images if generation failed
            if cross_attention_img is None:
                cross_attention_img = create_placeholder_image(
                    text="Cross-Attention Heatmap\nFailed to generate"
                )
            if normalized_img is None:
                normalized_img = create_placeholder_image(
                    text="Normalized Contribution\nFailed to generate"
                )
            if raw_img is None:
                if prediction_value > 0:
                    raw_text = "Raw Contribution\nFailed to generate"
                else:
                    raw_text = "Raw Contribution\nSkipped (pKd ≤ 0)"
                raw_img = create_placeholder_image(text=raw_text)

            status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}"
            if prediction_value <= 0:
                status_msg += " (Raw contribution visualization skipped due to non-positive pKd)"
            if cross_attention_weights is None:
                status_msg += " (Cross-attention visualization failed: weights not available)"

            return cross_attention_img, raw_img, normalized_img, status_msg
        except Exception as e:
            logger.exception("Visualization error: %s", e)
            return None, None, None, f"Error during visualization: {str(e)}"
        finally:
            # Always leave the model in normal mode, on success and on error.
            self._disable_interpretation_mode()


# Initialize the app
app = DrugTargetInteractionApp()


def _is_blank(value):
    """Return True for None, empty, or whitespace-only text."""
    return not value or not value.strip()


def predict_wrapper(target_seq, drug_smiles):
    """Wrapper function for Gradio interface"""
    if _is_blank(target_seq) or _is_blank(drug_smiles):
        return "Please provide both target sequence and drug SMILES."
    return app.predict_interaction(target_seq, drug_smiles)


def visualize_wrapper(target_seq, drug_smiles):
    """Wrapper function for visualization"""
    if _is_blank(target_seq) or _is_blank(drug_smiles):
        return None, None, None, "Please provide both target sequence and drug SMILES."
    return app.visualize_interaction(target_seq, drug_smiles)


def load_model_wrapper(model_path):
    """Wrapper function to load model"""
    if app.load_model(model_path):
        return "Model loaded successfully!"
    return "Failed to load model. Check the path and files."


def visualize_and_update(target_seq, drug_smiles):
    """Generate visualizations and update both status and state.

    The "analysis complete" note is appended only when images were actually
    produced; on missing input or a failure the status is returned as-is so
    the user is not told to look for images that do not exist.
    """
    img1, img2, img3, status = visualize_wrapper(target_seq, drug_smiles)
    if img1 is not None:
        status = (
            status
            + "\n\nVisualization analysis complete. Please navigate to the "
            "Visualizations tab to view the generated images."
        )
    return img1, img2, img3, status


# Create Gradio interface
# NOTE(review): the nesting of rows/columns below is reconstructed from the
# order of the statements — confirm against the deployed layout.
with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo:
    gr.HTML(_HEADER_HTML)

    # Create state variables to share images between tabs
    viz_state1 = gr.State()
    viz_state2 = gr.State()
    viz_state3 = gr.State()

    with gr.Tab("🔮 Prediction & Analysis"):
        with gr.Row():
            with gr.Column(scale=1):
                target_input = gr.Textbox(
                    label="Target RNA Sequence",
                    placeholder="Enter RNA sequence (e.g., AUGCUAGCUAGUACGUA...)",
                    lines=4,
                    max_lines=6
                )
                drug_input = gr.Textbox(
                    label="Drug SMILES",
                    placeholder="Enter SMILES notation (e.g., CC(C)CC1=CC=C(C=C1)C(C)C(=O)O)",
                    lines=2
                )
                with gr.Row():
                    predict_btn = gr.Button("🚀 Predict Interaction", variant="primary", size="lg")
                    visualize_btn = gr.Button("📊 Generate Visualizations", variant="secondary", size="lg")

            with gr.Column(scale=1):
                prediction_output = gr.Textbox(
                    label="Prediction Result",
                    interactive=False,
                    lines=4
                )

        # Example inputs
        gr.HTML(_EXAMPLES_HTML)
        examples = gr.Examples(
            examples=[
                [
                    "AUGCUAGCUAGUACGUAUAUCUGCACUGC",
                    "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O"
                ],
                [
                    "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",
                    "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"
                ]
            ],
            inputs=[target_input, drug_input],
            outputs=prediction_output,
            fn=predict_wrapper,
            cache_examples=False
        )

        # Button click events
        predict_btn.click(
            fn=predict_wrapper,
            inputs=[target_input, drug_input],
            outputs=prediction_output
        )
        visualize_btn.click(
            fn=visualize_and_update,
            inputs=[target_input, drug_input],
            outputs=[viz_state1, viz_state2, viz_state3, prediction_output]
        )

    with gr.Tab("📊 Visualizations"):
        gr.HTML(_VISUALIZATIONS_HTML)

        # Visualization outputs - Large and vertically aligned
        viz_image1 = gr.Image(
            label="Cross-Attention Heatmap",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Cross-Attention Heatmap\n(Generate visualizations in the Prediction tab)")
        )
        viz_image2 = gr.Image(
            label="Raw pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Raw pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )
        viz_image3 = gr.Image(
            label="Normalized pKd Contribution Visualization",
            type="pil",
            interactive=False,
            container=True,
            height=500,
            value=create_placeholder_image(text="Normalized pKd Contribution\n(Generate visualizations in the Prediction tab)")
        )

        # Update visualization images when state changes
        viz_state1.change(
            fn=lambda x: x,
            inputs=viz_state1,
            outputs=viz_image1
        )
        viz_state2.change(
            fn=lambda x: x,
            inputs=viz_state2,
            outputs=viz_image2
        )
        viz_state3.change(
            fn=lambda x: x,
            inputs=viz_state3,
            outputs=viz_image3
        )

    with gr.Tab("⚙️ Model Settings"):
        gr.HTML(_MODEL_SETTINGS_HTML)
        model_path_input = gr.Textbox(
            label="Model Path",
            value="./",
            placeholder="Path to model directory"
        )
        load_model_btn = gr.Button("📥 Load Model", variant="secondary")
        model_status = gr.Textbox(
            label="Status",
            interactive=False,
            value="No model loaded"
        )
        load_model_btn.click(
            fn=load_model_wrapper,
            inputs=model_path_input,
            outputs=model_status
        )

    with gr.Tab("📊 Dataset"):
        gr.Markdown(_DATASET_MARKDOWN)

    with gr.Tab("ℹ️ About"):
        gr.Markdown(_ABOUT_MARKDOWN)


# Launch the app
if __name__ == "__main__":
    # Try to load model on startup
    if os.path.exists("./config.json"):
        app.load_model("./")

    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )