import gradio as gr import torch import numpy as np from transformers import AutoModel, AutoTokenizer, AutoConfig, RobertaModel from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler from configuration_dlmberta import InteractionModelATTNConfig from chemberta import ChembertaTokenizer import json import os from pathlib import Path import logging # Import visualization functions from analysis import plot_crossattention_weights, plot_presum from PIL import Image, ImageDraw, ImageFont # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def create_placeholder_image(width=400, height=300, text="No visualization available", bg_color=(0, 0, 0, 0)): """ Create a transparent placeholder image with text Args: width (int): Image width height (int): Image height text (str): Text to display bg_color (tuple): Background color (R, G, B, A) - (0,0,0,0) for transparent Returns: PIL.Image: Transparent placeholder image """ # Create image with transparent background img = Image.new('RGBA', (width, height), bg_color) draw = ImageDraw.Draw(img) # Try to use a default font, fallback to default if not available try: font = ImageFont.truetype("arial.ttf", 16) except: try: font = ImageFont.load_default() except: font = None # Get text size and position for centering if font: bbox = draw.textbbox((0, 0), text, font=font) text_width = bbox[2] - bbox[0] text_height = bbox[3] - bbox[1] else: # Rough estimation if no font available text_width = len(text) * 8 text_height = 16 x = (width - text_width) // 2 y = (height - text_height) // 2 # Draw text in gray draw.text((x, y), text, fill=(128, 128, 128, 255), font=font) return img class DrugTargetInteractionApp: def __init__(self): self.model = None self.target_tokenizer = None self.drug_tokenizer = None self.scaler = None self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def load_model(self, model_path="./"): """Load the pre-trained model and tokenizers""" try: # Load configuration config = InteractionModelATTNConfig.from_pretrained(model_path) # Load drug encoder (ChemBERTa) drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR") drug_encoder_config.pooler = None drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False) # Load target encoder target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700") # Load scaler if exists scaler_path = os.path.join(model_path, "scaler.config") scaler = None if os.path.exists(scaler_path): scaler = StdScaler() scaler.load(model_path) self.model = InteractionModelATTNForRegression.from_pretrained( model_path, config=config, target_encoder=target_encoder, drug_encoder=drug_encoder, scaler=scaler ) self.model.to(self.device) self.model.eval() # Load tokenizers self.target_tokenizer = AutoTokenizer.from_pretrained( os.path.join(model_path, "target_tokenizer") ) # Load drug tokenizer (ChemBERTa) vocab_file = os.path.join(model_path, "drug_tokenizer", "vocab.json") self.drug_tokenizer = ChembertaTokenizer(vocab_file) logger.info("Model and tokenizers loaded successfully!") return True except Exception as e: logger.error(f"Error loading model: {str(e)}") return False def predict_interaction(self, target_sequence, drug_smiles, max_length=512): """Predict drug-target interaction""" if self.model is None: return "Error: Model not loaded. Please load a model first." try: # Tokenize inputs target_inputs = self.target_tokenizer( target_sequence, padding="max_length", truncation=True, max_length=512, return_tensors="pt" ).to(self.device) drug_inputs = self.drug_tokenizer( drug_smiles, padding="max_length", truncation=True, max_length=512, return_tensors="pt" ).to(self.device) # Make prediction self.model.INTERPR_DISABLE_MODE() with torch.no_grad(): prediction = self.model(target_inputs, drug_inputs) # Unscale if scaler exists if self.model.scaler is not None: prediction = self.model.unscale(prediction) prediction_value = prediction.cpu().numpy()[0][0] return f"Predicted Binding Affinity: {prediction_value:.4f}" except Exception as e: logger.error(f"Prediction error: {str(e)}") return f"Error during prediction: {str(e)}" def visualize_interaction(self, target_sequence, drug_smiles): """ Generate visualization images for drug-target interaction Args: target_sequence (str): RNA sequence drug_smiles (str): Drug SMILES notation Returns: tuple: (cross_attention_image, raw_contribution_image, normalized_contribution_image, status_message) """ if self.model is None: return None, None, None, "Error: Model not loaded. Please load a model first." try: # Tokenize inputs target_inputs = self.target_tokenizer( target_sequence, padding="max_length", truncation=True, max_length=512, return_tensors="pt" ).to(self.device) drug_inputs = self.drug_tokenizer( drug_smiles, padding="max_length", truncation=True, max_length=512, return_tensors="pt" ).to(self.device) # Enable interpretation mode self.model.INTERPR_ENABLE_MODE() # Make prediction and extract visualization data with torch.no_grad(): prediction = self.model(target_inputs, drug_inputs) # Unscale if scaler exists if self.model.scaler is not None: prediction = self.model.unscale(prediction) prediction_value = prediction.cpu().numpy()[0][0] # Extract data needed for visualizations presum_values = self.model.model.presum_layer # Shape: (1, seq_len) cross_attention_weights = self.model.model.crossattention_weights # Shape: (batch, heads, seq_len, seq_len) # Get model parameters for scaling w = self.model.model.w.squeeze(1) b = self.model.model.b scaler = self.model.model.scaler logger.info(f"Target inputs shape: {target_inputs['input_ids'].shape}") logger.info(f"Drug inputs shape: {drug_inputs['input_ids'].shape}") # Generate visualizations try: # 1. Cross-attention heatmap cross_attention_img = None logger.info(f"Cross-attention weights type: {type(cross_attention_weights)}") if cross_attention_weights is not None: logger.info(f"Cross-attention weights shape: {cross_attention_weights.shape if hasattr(cross_attention_weights, 'shape') else 'No shape attr'}") try: cross_attn_matrix = cross_attention_weights[0, 0] if cross_attn_matrix is not None: logger.info(f"Extracted cross-attention matrix shape: {cross_attn_matrix.shape}") logger.info(f"Target attention mask shape: {target_inputs['attention_mask'].shape}") logger.info(f"Drug attention mask shape: {drug_inputs['attention_mask'].shape}") cross_attention_img = plot_crossattention_weights( target_inputs["attention_mask"][0], drug_inputs["attention_mask"][0], target_inputs, drug_inputs, cross_attn_matrix, self.target_tokenizer, self.drug_tokenizer ) else: logger.warning("Could not extract valid cross-attention matrix") except (IndexError, TypeError, AttributeError) as e: logger.warning(f"Error extracting cross-attention matrix: {str(e)}") cross_attn_matrix = None else: logger.warning("Cross-attention weights are None") except Exception as e: logger.error(f"Cross-attention visualization error: {str(e)}") cross_attention_img = None try: # 2. Normalized contribution visualization (always generate) normalized_img = None if presum_values is not None: normalized_img = plot_presum( target_inputs, presum_values.detach(), # Detach the tensor scaler, w.detach(), # Detach the tensor b.detach(), # Detach the tensor self.target_tokenizer, raw_affinities=False ) else: logger.warning("Presum values are None") except Exception as e: logger.error(f"Normalized contribution visualization error: {str(e)}") normalized_img = None try: # 3. Raw contribution visualization (only if pKd > 0) raw_img = None if prediction_value > 0 and presum_values is not None: raw_img = plot_presum( target_inputs, presum_values.detach(), # Detach the tensor scaler, w.detach(), # Detach the tensor b.detach(), # Detach the tensor self.target_tokenizer, raw_affinities=True ) else: if prediction_value <= 0: logger.info("Skipping raw affinities visualization as pKd <= 0") if presum_values is None: logger.warning("Cannot generate raw visualization: presum values are None") except Exception as e: logger.error(f"Raw contribution visualization error: {str(e)}") raw_img = None # Disable interpretation mode after use self.model.INTERPR_DISABLE_MODE() # Create placeholder images if generation failed if cross_attention_img is None: cross_attention_img = create_placeholder_image( text="Cross-Attention Heatmap\nFailed to generate" ) if normalized_img is None: normalized_img = create_placeholder_image( text="Normalized Contribution\nFailed to generate" ) if raw_img is None and prediction_value > 0: raw_img = create_placeholder_image( text="Raw Contribution\nFailed to generate" ) elif raw_img is None: raw_img = create_placeholder_image( text="Raw Contribution\nSkipped (pKd ⤠0)" ) status_msg = f"Predicted Binding Affinity: {prediction_value:.4f}" if prediction_value <= 0: status_msg += " (Raw contribution visualization skipped due to non-positive pKd)" if cross_attention_weights is None: status_msg += " (Cross-attention visualization failed: weights not available)" return cross_attention_img, raw_img, normalized_img, status_msg except Exception as e: logger.error(f"Visualization error: {str(e)}") # Make sure to disable interpretation mode even if there's an error try: self.model.INTERPR_DISABLE_MODE() except: pass return None, None, None, f"Error during visualization: {str(e)}" # Initialize the app app = DrugTargetInteractionApp() def predict_wrapper(target_seq, drug_smiles): """Wrapper function for Gradio interface""" if not target_seq.strip() or not drug_smiles.strip(): return "Please provide both target sequence and drug SMILES." return app.predict_interaction(target_seq, drug_smiles) def visualize_wrapper(target_seq, drug_smiles): """Wrapper function for visualization""" if not target_seq.strip() or not drug_smiles.strip(): return None, None, None, "Please provide both target sequence and drug SMILES." return app.visualize_interaction(target_seq, drug_smiles) def load_model_wrapper(model_path): """Wrapper function to load model""" if app.load_model(model_path): return "Model loaded successfully!" else: return "Failed to load model. Check the path and files." # Create Gradio interface with gr.Blocks(title="Drug-Target Interaction Predictor", theme=gr.themes.Soft()) as demo: gr.HTML("""
Predict binding affinity between drugs and target RNA sequences using deep learning