---
title: "DLRNA-BERTa9700 Interface"
sdk: "gradio"
sdk_version: 5.44.0
app_file: app.py
pinned: false
---

# Drug-target interaction prediction model

## Model description

This model predicts drug-target interactions using a novel cross-attention architecture that combines RNA sequence understanding with molecular representation learning. The model processes RNA target sequences and drug SMILES representations to predict binding affinity scores (pKd values).

## Architecture

The model consists of several key components:

1. **Target encoder**: RNA-BERTa model that processes RNA sequences (nucleotides A, U, G, C)
2. **Drug encoder**: ChemBERTa-77M-MTR model [1] that processes molecular SMILES representations
3. **Cross-attention layer**: Single-head attention mechanism that models interactions between drug and target representations
4. **Regression head**: Predicts binding affinity scores with learnable scaling and bias parameters

### Technical specifications

- **Backbone models**: RNA-BERTa (target encoder) combined with ChemBERTa-77M-MTR (drug encoder)
- **Cross-attention**: Single-head attention with 384-dimensional embeddings
- **Maximum sequence length**: 512 tokens for both target and drug inputs
- **Output**: Continuous binding affinity prediction (pKd values)
- **Dropout**: Configurable attention dropout and hidden dropout for regularization
- **Layer normalization**: Applied for training stability
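
The dimensions above can be illustrated with a minimal single-head cross-attention sketch in plain PyTorch. This is a hedged illustration of the mechanism, not the repository implementation (`modeling_dlmberta.py`); the layer names and the 512-to-384 target projection are assumptions based on the specifications listed here:

```python
import torch
import torch.nn as nn

class SingleHeadCrossAttention(nn.Module):
    """Sketch: drug tokens attend over target tokens (1 head, 384-d embeddings)."""
    def __init__(self, target_dim=512, drug_dim=384, embed_dim=384, dropout=0.1):
        super().__init__()
        self.q = nn.Linear(drug_dim, embed_dim)    # queries from the drug encoder
        self.k = nn.Linear(target_dim, embed_dim)  # keys: target output mapped 512 -> 384
        self.v = nn.Linear(target_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)         # configurable attention dropout
        self.norm = nn.LayerNorm(embed_dim)        # layer normalization for stability
        self.scale = embed_dim ** -0.5             # scaled dot-product

    def forward(self, drug_states, target_states, target_mask=None):
        q, k, v = self.q(drug_states), self.k(target_states), self.v(target_states)
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        if target_mask is not None:  # mask out padded target positions
            scores = scores.masked_fill(target_mask[:, None, :] == 0, float("-inf"))
        weights = self.dropout(torch.softmax(scores, dim=-1))
        return self.norm(torch.matmul(weights, v)), weights

attn = SingleHeadCrossAttention().eval()
drug = torch.randn(2, 64, 384)     # (batch, drug tokens, hidden)
target = torch.randn(2, 512, 512)  # (batch, target tokens, hidden)
out, w = attn(drug, target)
print(out.shape)  # torch.Size([2, 64, 384])
```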

## Performance metrics

Evaluated on external ROBIN test datasets [2] across different RNA classes:

| Dataset | Precision | Specificity | Recall | AUROC | F1 Score |
|---------|-----------|-------------|---------|-------|----------|
| Aptamers | 0.648 | 0.002 | 1.000 | 0.571 | 0.787 |
| Riboswitch | 0.519 | 0.035 | 0.972 | 0.577 | 0.677 |
| Viral RNA | 0.562 | 0.095 | 0.943 | 0.579 | 0.704 |
| miRNA | 0.373 | 0.028 | 0.991 | 0.596 | 0.542 |
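
As a sanity check, the F1 column is consistent with the precision and recall columns (F1 is the harmonic mean of the two; small discrepancies come from rounding of the reported values):

```python
def f1_score(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)

# (precision, recall, reported F1) from the table above
rows = [("Aptamers", 0.648, 1.000, 0.787),
        ("Riboswitch", 0.519, 0.972, 0.677),
        ("Viral RNA", 0.562, 0.943, 0.704),
        ("miRNA", 0.373, 0.991, 0.542)]
for name, p, r, f1 in rows:
    assert abs(f1_score(p, r) - f1) < 0.002, name
```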

## Usage

### Using the Gradio interface

```python
import gradio as gr
from updated_app import demo

# Launch the interactive interface
demo.launch()
```

### Programmatic usage

```python
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer, RobertaModel

from modeling_dlmberta import InteractionModelATTNForRegression, StdScaler
from configuration_dlmberta import InteractionModelATTNConfig
from chemberta import ChembertaTokenizer

# Load model configuration
config = InteractionModelATTNConfig.from_pretrained("path/to/model")

# Load encoders
target_encoder = AutoModel.from_pretrained("IlPakoZ/RNA-BERTa9700")
drug_encoder_config = AutoConfig.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
drug_encoder_config.pooler = None
drug_encoder = RobertaModel(config=drug_encoder_config, add_pooling_layer=False)

# Load tokenizers; ChembertaTokenizer is the custom SMILES tokenizer from chemberta.py
# (the vocabulary path below is illustrative -- see the repository files)
target_tokenizer = AutoTokenizer.from_pretrained("IlPakoZ/RNA-BERTa9700")
drug_tokenizer = ChembertaTokenizer("path/to/model/vocab.json")

# Load scaler (if available)
scaler = StdScaler()
scaler.load("path/to/model")

# Initialize model
model = InteractionModelATTNForRegression.from_pretrained(
    "path/to/model",
    config=config,
    target_encoder=target_encoder,
    drug_encoder=drug_encoder,
    scaler=scaler
)
model.eval()

# Example inputs
target_sequence = "AUGCGAUCGACGUACGUUAGCCGUAGCGUAGCUAGUGUAGCUAGUAGCU"
drug_smiles = "C1=CC=C(C=C1)NC(=O)C2=CC=CC=N2"

# Tokenize inputs
target_inputs = target_tokenizer(target_sequence, padding="max_length", truncation=True, max_length=512, return_tensors="pt")
drug_inputs = drug_tokenizer(drug_smiles, padding="max_length", truncation=True, max_length=512, return_tensors="pt")

# Predict (unscale back to pKd units if a scaler was loaded)
with torch.no_grad():
    prediction = model(target_inputs, drug_inputs)
    if model.scaler:
        prediction = model.unscale(prediction)
```

## Model inputs

- **Target sequence**: RNA sequence using nucleotides A, U, G, C (string)
- **Drug SMILES**: Simplified Molecular Input Line Entry System notation (string)
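
A lightweight pre-check of both inputs can catch malformed strings before tokenization. This is a hedged sketch; the model does not ship a `validate_inputs` helper, and the SMILES check below is deliberately coarse (it only screens for unexpected characters, not chemical validity):

```python
import re

def validate_inputs(target_sequence: str, drug_smiles: str) -> list:
    """Return a list of problems found in the raw inputs (empty list = OK)."""
    problems = []
    # RNA targets must use the A/U/G/C alphabet
    if not re.fullmatch(r"[AUGC]+", target_sequence):
        problems.append("target sequence contains characters outside A, U, G, C")
    # Coarse SMILES screen: non-empty, plausible SMILES characters only
    if not re.fullmatch(r"[A-Za-z0-9@+\-\[\]\(\)=#$/\\.%:]+", drug_smiles):
        problems.append("drug SMILES is empty or contains unexpected characters")
    return problems

print(validate_inputs("AUGCGAUC", "C1=CC=CC=C1"))  # []
print(validate_inputs("ATGC", ""))  # flags the DNA-style 'T' and the empty SMILES
```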

## Model outputs

- **Binding affinity**: Predicted pKd binding affinity score (float)
- **Attention weights**: Cross-attention weights for interpretability analysis (when enabled)

## Interpretability features

The model includes advanced interpretability capabilities:

- **Cross-attention visualization**: Heatmaps showing interaction patterns between drug and target tokens
- **Token-level contributions**: Visualization of individual token contributions to the final prediction
- **Unnormalized vs. normalized contributions**: Both scaled and unscaled contribution analysis
- **Interpretation mode**: Special mode for extracting attention weights and intermediate values

### Enabling interpretation mode

```python
# Enable interpretation mode (evaluation only)
model.INTERPR_ENABLE_MODE()

# Make prediction with interpretation data
prediction = model(target_inputs, drug_inputs)

# Access interpretation data
cross_attention_weights = model.model.crossattention_weights
presum_contributions = model.model.presum_layer
attention_scores = model.model.scores

# Disable interpretation mode
model.INTERPR_DISABLE_MODE()
```

## Training details

### Data processing
- **Scaling**: Uses StdScaler for target value normalization
- **Tokenization**: Separate tokenizers for RNA sequences and SMILES strings
- **Padding**: Max length padding to 512 tokens
- **Masking**: Attention masks to handle variable-length sequences
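
The target normalization step can be sketched as a plain z-score scaler. This is a minimal stand-in for `StdScaler` (the real class in `modeling_dlmberta.py` also handles saving and loading its statistics):

```python
class SimpleStdScaler:
    """Minimal z-score scaler: fit on training pKd values, unscale predictions."""
    def fit(self, values):
        n = len(values)
        self.mean = sum(values) / n
        self.std = (sum((v - self.mean) ** 2 for v in values) / n) ** 0.5
        return self

    def scale(self, v):
        """Normalize a pKd value for training."""
        return (v - self.mean) / self.std

    def unscale(self, v):
        """Map a normalized model output back to pKd units."""
        return v * self.std + self.mean

scaler = SimpleStdScaler().fit([4.0, 6.0, 8.0])
assert abs(scaler.unscale(scaler.scale(6.5)) - 6.5) < 1e-9  # round trip
```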

### Architecture details
- **Embedding dimension**: 384 for cross-attention layer
- **Target encoder output**: 512 dimensions, mapped to 384
- **Drug encoder output**: 384 dimensions (direct use)
- **Attention mechanism**: Single-head cross-attention with scaled dot-product
- **Learnable parameters**: Weighted sum with learnable scaling vector and bias
- **Padding handling**: Learnable padding value for masked positions
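
The readout in the last two bullets can be sketched as follows (a hedged illustration under the assumptions above, not the repository implementation): each attended token is reduced to a scalar contribution, padded positions are replaced by a learnable value, and the weighted sum applies a learnable scaling vector and bias:

```python
import torch
import torch.nn as nn

class WeightedSumHead(nn.Module):
    """Sketch of a regression head with learnable scaling, bias, and pad value."""
    def __init__(self, embed_dim=384, max_len=512):
        super().__init__()
        self.to_scalar = nn.Linear(embed_dim, 1)          # one contribution per token
        self.scaling = nn.Parameter(torch.ones(max_len))  # learnable scaling vector
        self.bias = nn.Parameter(torch.zeros(1))          # learnable bias
        self.pad_value = nn.Parameter(torch.zeros(1))     # learnable padding value

    def forward(self, token_states, mask):
        # token_states: (batch, seq, embed_dim); mask: (batch, seq), 1 = real token
        contrib = self.to_scalar(token_states).squeeze(-1)         # (batch, seq)
        contrib = torch.where(mask.bool(), contrib, self.pad_value)
        return (contrib * self.scaling).sum(dim=-1) + self.bias    # (batch,)

head = WeightedSumHead()
states = torch.randn(2, 512, 384)
mask = torch.ones(2, 512)
print(head(states, mask).shape)  # torch.Size([2])
```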

## Limitations

- Performance varies significantly across RNA classes (miRNA shows lower precision)
- May not generalize well to RNA sequences or chemical scaffolds not represented in training data
- Computational requirements scale with sequence length (max 512 tokens)
- Single attention head may limit capacity to capture diverse interaction patterns
- SMILES representation may not capture all relevant molecular properties

## Files in this repository

- `modeling_dlmberta.py`: Main model implementation with cross-attention architecture
- `configuration_dlmberta.py`: Model configuration class
- `chemberta.py`: Custom tokenizer for chemical SMILES processing
- `updated_app.py`: Gradio application interface with visualization capabilities
- `analysis.py`: Visualization functions for interpretability
- `requirements.txt`: Python dependencies
- `config.json`: Model configuration file

## License

This model is released under the MIT License.

### Citations
[1] 
```bibtex
@article{ahmad2022chemberta,
  title={Chemberta-2: Towards chemical foundation models},
  author={Ahmad, Walid and Simon, Elana and Chithrananda, Seyone and Grand, Gabriel and Ramsundar, Bharath},
  journal={arXiv preprint arXiv:2209.01712},
  year={2022}
}
```

[2] 
```bibtex
@article{krishnan2024reliable,
  title={Reliable method for predicting the binding affinity of RNA-small molecule interactions using machine learning},
  author={Krishnan, Sowmya R and Roy, Arijit and Gromiha, M Michael},
  journal={Briefings in Bioinformatics},
  volume={25},
  number={2},
  pages={bbae002},
  year={2024},
  publisher={Oxford University Press}
}
```