---
title: "DLRNA-BERTa9700 Interface"
sdk: "gradio"
sdk_version: 5.44.0
app_file: app.py
pinned: false
---
# Drug-target interaction prediction model
## Model description
This model predicts drug-target interactions using a novel cross-attention architecture that combines RNA sequence understanding with molecular representation learning. The model processes RNA target sequences and drug SMILES representations to predict binding affinity scores (pKd values).
## Architecture
The model consists of several key components:
1. **Target encoder**: RNA-BERTa model that processes RNA sequences (nucleotides A, U, G, C)
2. **Drug encoder**: ChemBERTa-77M-MTR model [1] that processes molecular SMILES representations
3. **Cross-attention layer**: Single-head attention mechanism that models interactions between drug and target representations
4. **Regression head**: Predicts binding affinity scores with learnable scaling and bias parameters
### Technical specifications
- **Encoders**: RNA-BERTa (target encoder) combined with ChemBERTa-77M-MTR (drug encoder)
- **Cross-attention**: Single-head attention with 384-dimensional embeddings (a minimal sketch follows this list)
- **Maximum sequence length**: 512 tokens for both target and drug inputs
- **Output**: Continuous binding affinity prediction (pKd values)
- **Dropout**: Configurable attention dropout and hidden dropout for regularization
- **Layer normalization**: Applied for training stability
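
For illustration, the sketch below shows a minimal single-head scaled dot-product cross-attention block consistent with the specifications above. The class name, the direction of attention (drug queries attending to target keys/values), and the 512-to-384 target projection are assumptions, not the repository implementation; the actual module lives in `modeling_dlmberta.py`.

```python
import torch
import torch.nn as nn

class SingleHeadCrossAttentionSketch(nn.Module):
    """Illustrative single-head cross-attention (drug attends to target)."""

    def __init__(self, target_dim=512, drug_dim=384, embed_dim=384, dropout=0.1):
        super().__init__()
        self.target_proj = nn.Linear(target_dim, embed_dim)  # map 512-dim target states to 384
        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, drug_states, target_states, target_mask=None):
        # drug_states: (B, Ld, 384), target_states: (B, Lt, 512)
        target_states = self.target_proj(target_states)
        q, k, v = self.q(drug_states), self.k(target_states), self.v(target_states)
        scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)        # (B, Ld, Lt)
        if target_mask is not None:
            scores = scores.masked_fill(target_mask[:, None, :] == 0, float("-inf"))
        weights = self.dropout(torch.softmax(scores, dim=-1))
        return self.norm(weights @ v), weights
```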
## Performance metrics
Evaluated on external ROBIN test datasets [2] across different RNA classes:
| Dataset | Precision | Specificity | Recall | AUROC | F1 Score |
|---------|-----------|-------------|---------|-------|----------|
| Aptamers | 0.648 | 0.002 | 1.000 | 0.571 | 0.787 |
| Riboswitch | 0.519 | 0.035 | 0.972 | 0.577 | 0.677 |
| Viral RNA | 0.562 | 0.095 | 0.943 | 0.579 | 0.704 |
| miRNA | 0.373 | 0.028 | 0.991 | 0.596 | 0.542 |
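
These are classification-style metrics, so the continuous pKd predictions are binarized against an activity threshold before scoring. The snippet below only illustrates how such metrics can be computed with scikit-learn; the labels, predictions, and threshold are placeholders, not the values or cutoff used in the ROBIN evaluation.

```python
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

y_true = np.array([1, 0, 1, 1, 0, 1])                 # placeholder binder/non-binder labels
pkd_pred = np.array([6.2, 5.1, 7.4, 5.9, 4.8, 6.8])   # placeholder predicted pKd values
threshold = 6.0                                       # hypothetical binarization threshold
y_pred = (pkd_pred >= threshold).astype(int)

tn = np.sum((y_true == 0) & (y_pred == 0))
fp = np.sum((y_true == 0) & (y_pred == 1))
print("Precision:  ", precision_score(y_true, y_pred))
print("Specificity:", tn / (tn + fp))
print("Recall:     ", recall_score(y_true, y_pred))
print("AUROC:      ", roc_auc_score(y_true, pkd_pred))
print("F1 score:   ", f1_score(y_true, y_pred))
```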
## Usage
### Using the Gradio interface
```python
import gradio as gr
from updated_app import demo
# Launch the interactive interface
demo.launch()
```
### Programmatic usage
```python
from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
from configuration_dlmberta import InteractionModelATTNConfig
from transformers import AutoModel, AutoTokenizer, RobertaModel, AutoConfig
from chemberta import ChembertaTokenizer
import torch

# Load model configuration
config = InteractionModelATTNConfig.from_pretrained("path/to/model")

# Load encoders
target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")
drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
drug_encoder_config.pooler = None
drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

# Load tokenizers (assumed sources; adjust the RNA tokenizer path and the
# ChembertaTokenizer vocabulary argument to your local setup)
target_tokenizer = AutoTokenizer.from_pretrained("IlPakoZ/RNA-BERTa9700")
drug_tokenizer = ChembertaTokenizer("path/to/model/vocab.json")

# Load scaler (if available)
scaler = StdScaler()
scaler.load("path/to/model")

# Initialize model
model = InteractionModelATTNForRegression.from_pretrained(
    "path/to/model",
    config=config,
    target_encoder=target_encoder,
    drug_encoder=drug_encoder,
    scaler=scaler
)
model.eval()

# Example inputs
target_sequence = "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU"
drug_smiles = "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"

# Tokenize inputs
target_inputs = target_tokenizer(target_sequence, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
drug_inputs = drug_tokenizer(drug_smiles, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

# Predict (unscale to recover pKd values if a scaler was loaded)
with torch.no_grad():
    prediction = model(target_inputs, drug_inputs)
    if model.scaler:
        prediction = model.unscale(prediction)
```
## Model inputs
- **Target sequence**: RNA sequence using nucleotides A, U, G, C (string)
- **Drug SMILES**: Simplified Molecular Input Line Entry System notation (string)
## Model outputs
- **Binding affinity**: Predicted pKd binding affinity score (float)
- **Attention weights**: Cross-attention weights for interpretability analysis (when enabled)
## Interpretability features
The model includes advanced interpretability capabilities:
- **Cross-attention visualization**: Heatmaps showing interaction patterns between drug and target tokens
- **Token-level contributions**: Visualization of individual token contributions to the final prediction
- **Unnormalized vs. normalized contributions**: Both scaled and unscaled contribution analysis
- **Interpretation mode**: Special mode for extracting attention weights and intermediate values
### Enabling interpretation mode
```python
# Enable interpretation mode (evaluation only)
model.INTERPR_ENABLE_MODE()
# Make prediction with interpretation data
prediction = model(target_inputs, drug_inputs)
# Access interpretation data
cross_attention_weights = model.model.crossattention_weights
presum_contributions = model.model.presum_layer
attention_scores = model.model.scores
# Disable interpretation mode
model.INTERPR_DISABLE_MODE()
```
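
The Gradio app (`updated_app.py`) renders these tensors through the plotting helpers in `analysis.py`. As a rough standalone illustration, an attention matrix of shape `(drug_len, target_len)` could be drawn as a heatmap like this (the tensor shape and token lists are assumptions):

```python
import matplotlib.pyplot as plt

def plot_cross_attention(weights, drug_tokens, target_tokens):
    """Plot a (drug_len, target_len) attention matrix as a heatmap."""
    weights = weights.detach().cpu().numpy()
    fig, ax = plt.subplots(figsize=(10, 4))
    im = ax.imshow(weights, aspect="auto", cmap="viridis")
    ax.set_xticks(range(len(target_tokens)), target_tokens, rotation=90, fontsize=6)
    ax.set_yticks(range(len(drug_tokens)), drug_tokens, fontsize=6)
    ax.set_xlabel("Target (RNA) tokens")
    ax.set_ylabel("Drug (SMILES) tokens")
    fig.colorbar(im, ax=ax, label="Attention weight")
    fig.tight_layout()
    return fig
```

Padding positions should be trimmed before plotting so the heatmap covers only real tokens.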
## Training details
### Data processing
- **Scaling**: Uses StdScaler for target value normalization (see the sketch after this list)
- **Tokenization**: Separate tokenizers for RNA sequences and SMILES strings
- **Padding**: Max length padding to 512 tokens
- **Masking**: Attention masks to handle variable-length sequences
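
Here, standard scaling means the usual z-score transform of the training pKd labels. A minimal sketch of the assumed convention (the actual `StdScaler` is implemented in `modeling_dlmberta.py`):

```python
import numpy as np

pkd_train = np.array([4.2, 5.7, 6.3, 7.1, 8.0])   # placeholder training labels

mean, std = pkd_train.mean(), pkd_train.std()
scaled = (pkd_train - mean) / std                 # normalized targets used during training
unscaled = scaled * std + mean                    # inverse transform, as in model.unscale()
```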
### Architecture details
- **Embedding dimension**: 384 for cross-attention layer
- **Target encoder output**: 512 dimensions, mapped to 384
- **Drug encoder output**: 384 dimensions (direct use)
- **Attention mechanism**: Single-head cross-attention with scaled dot-product
- **Learnable parameters**: Weighted sum with learnable scaling vector and bias (sketched after this list)
- **Padding handling**: Learnable padding value for masked positions
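
For illustration, the sketch below shows one way a regression head with these properties could be wired: per-position contributions from the cross-attention output, a learnable padding value for masked positions, and a weighted sum with a learnable scaling vector and bias. The names and exact wiring are assumptions; see `modeling_dlmberta.py` for the actual implementation.

```python
import torch
import torch.nn as nn

class WeightedSumRegressionHeadSketch(nn.Module):
    """Illustrative regression head: per-position scores -> weighted sum -> pKd."""

    def __init__(self, embed_dim=384, max_len=512):
        super().__init__()
        self.score = nn.Linear(embed_dim, 1)             # per-token contribution score
        self.scale = nn.Parameter(torch.ones(max_len))   # learnable scaling vector
        self.bias = nn.Parameter(torch.zeros(1))         # learnable bias
        self.pad_value = nn.Parameter(torch.zeros(1))    # learnable value for masked positions

    def forward(self, attn_output, mask):
        # attn_output: (B, L, 384), mask: (B, L) with 1 for real tokens
        scores = self.score(attn_output).squeeze(-1)               # (B, L)
        scores = torch.where(mask.bool(), scores, self.pad_value)  # fill padding positions
        L = scores.size(1)
        return (scores * self.scale[:L]).sum(dim=-1) + self.bias   # (B,) predicted pKd
```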
## Limitations
- Performance varies significantly across RNA classes (miRNA shows lower precision)
- May not generalize well to RNA sequences or chemical scaffolds not represented in training data
- Computational requirements scale with sequence length (max 512 tokens)
- Single attention head may limit capacity to capture diverse interaction patterns
- SMILES representation may not capture all relevant molecular properties
## Files in this repository
- `modeling_dlmberta.py`: Main model implementation with cross-attention architecture
- `configuration_dlmberta.py`: Model configuration class
- `chemberta.py`: Custom tokenizer for chemical SMILES processing
- `updated_app.py`: Gradio application interface with visualization capabilities
- `analysis.py`: Visualization functions for interpretability
- `requirements.txt`: Python dependencies
- `config.json`: Model configuration file
## License
This model is released under the MIT License.
### Citations
[1]
```bibtex
@article{ahmad2022chemberta,
title={ChemBERTa-2: Towards chemical foundation models},
author={Ahmad, Walid and Simon, Elana and Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath},
journal={arXiv preprint arXiv:2209.01712},
year={2022}
}
```
[2]
```bibtex
@article{krishnan2024reliable,
title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
journal={Briefings in Bioinformatics},
volume={25},
number={2},
pages={bbae002},
year={2024},
publisher={Oxford University Press}
}
```