---
title: DLRNA-BERTa9700 Interface
sdk: gradio
sdk_version: 5.44.0
app_file: app.py
pinned: false
---
# Drug-target interaction prediction model

## Model description

This model predicts drug-target interactions using a cross-attention architecture that combines RNA sequence understanding with molecular representation learning. The model processes RNA target sequences and drug SMILES representations to predict binding affinity scores (pKd values).
## Architecture

The model consists of several key components:

1. **Target encoder**: RNA-BERTa model that processes RNA sequences (nucleotides A, U, G, C)
2. **Drug encoder**: ChemBERTa-77M-MTR model [1] that processes molecular SMILES representations
3. **Cross-attention layer**: Single-head attention mechanism that models interactions between drug and target representations
4. **Regression head**: Predicts binding affinity scores with learnable scaling and bias parameters
### Technical specifications

- **Model size**: Combines RNA-BERTa (target encoder) and ChemBERTa-77M-MTR (drug encoder)
- **Cross-attention**: Single-head attention with 384-dimensional embeddings
- **Maximum sequence length**: 512 tokens for both target and drug inputs
- **Output**: Continuous binding affinity prediction (pKd values)
- **Dropout**: Configurable attention dropout and hidden dropout for regularization
- **Layer normalization**: Applied for training stability
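For orientation, the sketch below shows one way such an interaction module could be assembled in PyTorch. The class name `CrossAttentionInteraction` and its internals are illustrative assumptions and do not reproduce the actual code in `modeling_dlmberta.py`.

```python
import torch
import torch.nn as nn

class CrossAttentionInteraction(nn.Module):
    """Illustrative single-head cross-attention + regression head (not the released implementation)."""

    def __init__(self, target_dim=512, drug_dim=384, embed_dim=384, dropout=0.1):
        super().__init__()
        # Map the 512-d target encoder output into the shared 384-d space
        self.proj_target = nn.Linear(target_dim, embed_dim)
        # Single-head cross-attention: drug tokens attend to target tokens
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads=1,
                                                dropout=dropout, batch_first=True)
        self.norm = nn.LayerNorm(embed_dim)
        # Regression head with a learnable scaling vector and bias
        self.scale = nn.Parameter(torch.ones(embed_dim))
        self.bias = nn.Parameter(torch.zeros(1))

    def forward(self, target_hidden, drug_hidden, target_mask=None):
        # target_hidden: (batch, target_len, 512); drug_hidden: (batch, drug_len, 384)
        target_emb = self.proj_target(target_hidden)
        attn_out, attn_weights = self.cross_attn(
            query=drug_hidden, key=target_emb, value=target_emb,
            key_padding_mask=(target_mask == 0) if target_mask is not None else None,
        )
        attn_out = self.norm(attn_out)
        # Pool over tokens, then take a learnable weighted sum over features
        pooled = attn_out.mean(dim=1)
        return (pooled * self.scale).sum(dim=-1) + self.bias
```

In the released model, the cross-attention weights and pre-sum contributions are additionally exposed through interpretation mode (see below).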
## Performance metrics

Evaluated on external ROBIN test datasets [2] across different RNA classes:

| Dataset    | Precision | Specificity | Recall | AUROC | F1 Score |
|------------|-----------|-------------|--------|-------|----------|
| Aptamers   | 0.648     | 0.002       | 1.000  | 0.571 | 0.787    |
| Riboswitch | 0.519     | 0.035       | 0.972  | 0.577 | 0.677    |
| Viral RNA  | 0.562     | 0.095       | 0.943  | 0.579 | 0.704    |
| miRNA      | 0.373     | 0.028       | 0.991  | 0.596 | 0.542    |
## Usage

### Using the Gradio interface

```python
import gradio as gr
from updated_app import demo

# Launch the interactive interface
demo.launch()
```
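If the Space is running, the same interface can also be queried remotely through `gradio_client`. The Space ID and the endpoint signature below are assumptions; check the running app's "Use via API" panel for the actual values.

```python
from gradio_client import Client

# Space ID and API signature are illustrative assumptions, not confirmed values
client = Client("IlPakoZ/DLRNA-BERTa9700-Interface")
result = client.predict(
    "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU",  # RNA target sequence
    "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2",                     # drug SMILES
    api_name="/predict",
)
print(result)
```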
### Programmatic usage

```python
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer, RobertaModel

from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
from configuration_dlmberta import InteractionModelATTNConfig
from chemberta import ChembertaTokenizer

# Load the model configuration
config = InteractionModelATTNConfig.from_pretrained("path/to/model")

# Load the encoders
target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")
drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
drug_encoder_config.pooler = None
drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

# Load the tokenizers (the drug tokenizer is the custom ChembertaTokenizer from
# chemberta.py; see that file for its exact constructor arguments)
target_tokenizer = AutoTokenizer.from_pretrained("IlPakoZ/RNA-BERTa9700")
drug_tokenizer = ChembertaTokenizer("path/to/model")

# Load the scaler (if available)
scaler = StdScaler()
scaler.load("path/to/model")

# Initialize the model
model = InteractionModelATTNForRegression.from_pretrained(
    "path/to/model",
    config=config,
    target_encoder=target_encoder,
    drug_encoder=drug_encoder,
    scaler=scaler
)
model.eval()

# Example inputs
target_sequence = "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU"
drug_smiles = "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"

# Tokenize inputs
target_inputs = target_tokenizer(target_sequence, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
drug_inputs = drug_tokenizer(drug_smiles, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

# Predict; unscale back to pKd if a scaler was loaded
with torch.no_grad():
    prediction = model(target_inputs, drug_inputs)
    if model.scaler:
        prediction = model.unscale(prediction)
```
## Model inputs

- **Target sequence**: RNA sequence using nucleotides A, U, G, C (string)
- **Drug SMILES**: Simplified Molecular Input Line Entry System notation (string)

## Model outputs

- **Binding affinity**: Predicted pKd binding affinity score (float)
- **Attention weights**: Cross-attention weights for interpretability analysis (when enabled)
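A lightweight sanity check on the inputs can catch malformed sequences before tokenization. The helper below is illustrative and not part of the repository.

```python
import re

def validate_inputs(target_sequence: str, drug_smiles: str) -> None:
    """Illustrative pre-tokenization checks; not part of the repository."""
    # RNA targets must use only the four ribonucleotides A, U, G, C
    if not re.fullmatch(r"[AUGC]+", target_sequence.upper()):
        raise ValueError("Target sequence may only contain the nucleotides A, U, G, C.")
    # SMILES strings must be non-empty; full chemical validation would need e.g. RDKit
    if not drug_smiles.strip():
        raise ValueError("Drug SMILES string is empty.")

validate_inputs("AUGCGAUCG", "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2")
```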
## Interpretability features

The model includes advanced interpretability capabilities:

- **Cross-attention visualization**: Heatmaps showing interaction patterns between drug and target tokens
- **Token-level contributions**: Visualization of individual token contributions to the final prediction
- **Unnormalized vs. normalized contributions**: Both scaled and unscaled contribution analysis
- **Interpretation mode**: Special mode for extracting attention weights and intermediate values

### Enabling interpretation mode
```python
# Enable interpretation mode (evaluation only)
model.INTERPR_ENABLE_MODE()

# Make a prediction while collecting interpretation data
prediction = model(target_inputs, drug_inputs)

# Access interpretation data
cross_attention_weights = model.model.crossattention_weights
presum_contributions = model.model.presum_layer
attention_scores = model.model.scores

# Disable interpretation mode
model.INTERPR_DISABLE_MODE()
```
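To turn the collected weights into the heatmaps mentioned above, a matplotlib sketch like the following can be used. The tensor shape handling (batch and sequence dimensions) is an assumption; `analysis.py` in this repository provides the actual plotting utilities.

```python
import matplotlib.pyplot as plt

# Assumes cross_attention_weights has shape (batch, drug_len, target_len);
# adjust the indexing to the actual tensors exposed by interpretation mode
weights = cross_attention_weights[0].detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(weights, aspect="auto", cmap="viridis")
ax.set_xlabel("Target (RNA) token position")
ax.set_ylabel("Drug (SMILES) token position")
ax.set_title("Drug-target cross-attention")
fig.colorbar(im, ax=ax, label="Attention weight")
plt.show()
```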
## Training details

### Data processing

- **Scaling**: Uses StdScaler for target value normalization
- **Tokenization**: Separate tokenizers for RNA sequences and SMILES strings
- **Padding**: Max-length padding to 512 tokens
- **Masking**: Attention masks to handle variable-length sequences
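For intuition, standard scaling of the pKd labels amounts to subtracting the training-set mean and dividing by the standard deviation. The snippet below is a minimal illustration with made-up values, not the `StdScaler` implementation from `modeling_dlmberta.py`.

```python
import torch

# Illustrative standard scaling of pKd labels (values are made up)
pkd = torch.tensor([4.2, 6.8, 5.1, 7.3])

mean, std = pkd.mean(), pkd.std()
scaled = (pkd - mean) / std      # used as the regression target during training
unscaled = scaled * std + mean   # applied to predictions to recover pKd values
```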
### Architecture details

- **Embedding dimension**: 384 for the cross-attention layer
- **Target encoder output**: 512 dimensions, mapped to 384
- **Drug encoder output**: 384 dimensions (used directly)
- **Attention mechanism**: Single-head cross-attention with scaled dot-product
- **Learnable parameters**: Weighted sum with a learnable scaling vector and bias
- **Padding handling**: Learnable padding value for masked positions
## Limitations

- Performance varies significantly across RNA classes (miRNA shows lower precision)
- May not generalize well to RNA sequences or chemical scaffolds not represented in the training data
- Computational requirements scale with sequence length (maximum 512 tokens)
- A single attention head may limit the capacity to capture diverse interaction patterns
- The SMILES representation may not capture all relevant molecular properties
## Files in this repository

- `modeling_dlmberta.py`: Main model implementation with the cross-attention architecture
- `configuration_dlmberta.py`: Model configuration class
- `chemberta.py`: Custom tokenizer for chemical SMILES processing
- `updated_app.py`: Gradio application interface with visualization capabilities
- `analysis.py`: Visualization functions for interpretability
- `requirements.txt`: Python dependencies
- `config.json`: Model configuration file
## License

This model is released under the MIT License.

## Citations
[1]

```bibtex
@article{ahmad2022chemberta,
  title={ChemBERTa-2: Towards chemical foundation models},
  author={Ahmad, Walid and Simon, Elana and Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath},
  journal={arXiv preprint arXiv:2209.01712},
  year={2022}
}
```

[2]

```bibtex
@article{krishnan2024reliable,
  title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
  author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
  journal={Briefings in Bioinformatics},
  volume={25},
  number={2},
  pages={bbae002},
  year={2024},
  publisher={Oxford University Press}
}
```