--- title: "DLRNA-BERTa9700 Interface" sdk: "gradio" sdk_version: 5.44.0 app_file: app.py pinned: false --- # Drug-target interaction prediction model ## Model description This model predicts drug-target interactions using a novel cross-attention architecture that combines RNA sequence understanding with molecular representation learning. The model processes RNA target sequences and drug SMILES representations to predict binding affinity scores (pKd values). ## Architecture The model consists of several key components: 1. **Target encoder**: RNA-BERTa model that processes RNA sequences (nucleotides A, U, G, C) 2. **Drug encoder**: ChemBERTa-77M-MTR model [1] that processes molecular SMILES representations 3. **Cross-attention layer**: Single-head attention mechanism (1 head) that models interactions between drug and target representations 4. **Regression head**: Predicts binding affinity scores with learnable scaling and bias parameters ### Technical specifications - **Model size**: Combines RNA-BERTa (target encoder) + ChemBERTa-77M-MTR (drug encoder) - **Cross-attention**: Single-head attention with 384-dimensional embeddings - **Maximum sequence length**: 512 tokens for both target and drug inputs - **Output**: Continuous binding affinity prediction (pKd values) - **Dropout**: Configurable attention dropout and hidden dropout for regularization - **Layer normalization**: Applied for training stability ## Performance metrics Evaluated on external ROBIN test datasets [2] across different RNA classes: | Dataset | Precision | Specificity | Recall | AUROC | F1 Score | |---------|-----------|-------------|---------|-------|----------| | Aptamers | 0.648 | 0.002 | 1.000 | 0.571 | 0.787 | | Riboswitch | 0.519 | 0.035 | 0.972 | 0.577 | 0.677 | | Viral RNA | 0.562 | 0.095 | 0.943 | 0.579 | 0.704 | | miRNA | 0.373 | 0.028 | 0.991 | 0.596 | 0.542 | ## Usage ### Using the Gradio interface ```python import gradio as gr from updated_app import demo # Launch the interactive interface demo.launch() ``` ### Programmatic usage ```python from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler from configuration_dlmberta import InteractionModelATTNConfig from transformers import AutoModel, RobertaModel, AutoConfig from chemberta import ChembertaTokenizer # Load model components config = InteractionModelATTNConfig.from_pretrained("path/to/model") # Load encoders target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700") drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR") drug_encoder_config.pooler = None drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False) # Load scaler (if available) scaler = StdScaler() scaler.load("path/to/model") # Initialize model model = InteractionModelATTNForRegression.from_pretrained( "path/to/model", config=config, target_encoder=target_encoder, drug_encoder=drug_encoder, scaler=scaler ) # Make predictions target_sequence = "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU" drug_smiles = "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2" # Tokenize inputs target_inputs = target_tokenizer(target_sequence, padding="max_length", truncation=True, max_length=512, return_tensors="pt") drug_inputs = drug_tokenizer(drug_smiles, padding="max_length", truncation=True, max_length=512, return_tensors="pt") # Predict with torch.no_grad(): prediction = model(target_inputs, drug_inputs) if model.scaler: prediction = model.unscale(prediction) ``` ## Model inputs - **Target sequence**: RNA sequence using nucleotides A, U, G, C 
## Model inputs

- **Target sequence**: RNA sequence using nucleotides A, U, G, C (string)
- **Drug SMILES**: Simplified Molecular Input Line Entry System notation (string)

## Model outputs

- **Binding affinity**: Predicted pKd binding affinity score (float)
- **Attention weights**: Cross-attention weights for interpretability analysis (when enabled)

## Interpretability features

The model includes several interpretability capabilities:

- **Cross-attention visualization**: Heatmaps showing interaction patterns between drug and target tokens
- **Token-level contributions**: Visualization of individual token contributions to the final prediction
- **Unnormalized vs. normalized contributions**: Both scaled and unscaled contribution analysis
- **Interpretation mode**: Special mode for extracting attention weights and intermediate values

### Enabling interpretation mode

```python
# Enable interpretation mode (evaluation only)
model.INTERPR_ENABLE_MODE()

# Make a prediction that records interpretation data
prediction = model(target_inputs, drug_inputs)

# Access interpretation data
cross_attention_weights = model.model.crossattention_weights
presum_contributions = model.model.presum_layer
attention_scores = model.model.scores

# Disable interpretation mode
model.INTERPR_DISABLE_MODE()
```

## Training details

### Data processing

- **Scaling**: Uses `StdScaler` for target value normalization (sketched under Usage above)
- **Tokenization**: Separate tokenizers for RNA sequences and SMILES strings
- **Padding**: Max-length padding to 512 tokens
- **Masking**: Attention masks to handle variable-length sequences

### Architecture details

- **Embedding dimension**: 384 for the cross-attention layer
- **Target encoder output**: 512 dimensions, mapped to 384
- **Drug encoder output**: 384 dimensions (used directly)
- **Attention mechanism**: Single-head cross-attention with scaled dot-product scoring (see the illustrative sketch under Usage)
- **Learnable parameters**: Weighted sum with a learnable scaling vector and bias
- **Padding handling**: Learnable padding value for masked positions

## Limitations

- Performance varies significantly across RNA classes (miRNA shows the lowest precision)
- May not generalize well to RNA sequences or chemical scaffolds not represented in the training data
- Computational requirements scale with sequence length (maximum 512 tokens)
- A single attention head may limit the capacity to capture diverse interaction patterns
- The SMILES representation may not capture all relevant molecular properties

## Files in this repository

- `modeling_dlmberta.py`: Main model implementation with the cross-attention architecture
- `configuration_dlmberta.py`: Model configuration class
- `chemberta.py`: Custom tokenizer for chemical SMILES processing
- `updated_app.py`: Gradio application interface with visualization capabilities
- `analysis.py`: Visualization functions for interpretability
- `requirements.txt`: Python dependencies
- `config.json`: Model configuration file

## License

This model is released under the MIT License.

### Citations

[1]
```bibtex
@article{ahmad2022chemberta,
  title={ChemBERTa-2: Towards chemical foundation models},
  author={Ahmad, Walid and Simon, Elana and Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath},
  journal={arXiv preprint arXiv:2209.01712},
  year={2022}
}
```

[2]
```bibtex
@article{krishnan2024reliable,
  title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
  author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
  journal={Briefings in Bioinformatics},
  volume={25},
  number={2},
  pages={bbae002},
  year={2024},
  publisher={Oxford University Press}
}
```