yinuozhang committed
Commit 680e7ec · 1 Parent(s): c237d58

add data_processing

data_preprocessing/data.py ADDED
@@ -0,0 +1,235 @@
+ import sys
+ import torch
+ from datasets import Dataset, DatasetDict, load_from_disk
+ from torch.utils.data import DataLoader
+ import os
+ from multiprocessing import Pool
+ from tqdm import tqdm
+ import lightning.pytorch as pl
+ sys.path.append('/home/yz927/projects/peptune/scripts/')
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
+ global_tokenizer = None
+
+
+ def init_pool(tokenizer):
+     global global_tokenizer
+     global_tokenizer = tokenizer
+
+ class SequenceDataset:
+     def __init__(self, sequences, tokenizer, max_sequence_length, num_cores=8):
+         self.sequences = sequences
+         self.tokenizer = tokenizer
+         self.max_sequence_length = max_sequence_length
+         self.num_cores = num_cores
+         self.tokenized_sequences = []
+         self.original_sequences = []
+
+     def tokenize_sequences(self):
+         print(f"Starting parallel tokenization using {self.num_cores} cores")
+         with Pool(processes=self.num_cores, initializer=init_pool, initargs=(self.tokenizer,)) as pool:
+             results = list(tqdm(
+                 pool.imap(standalone_tokenize_function, self.sequences),
+                 total=len(self.sequences)
+             ))
+
+         for result, seq in zip(results, self.sequences):
+             if result is not None and len(result['input_ids'][0]) <= self.max_sequence_length:
+                 self.tokenized_sequences.append(result)
+                 self.original_sequences.append(seq)
+
+
+     def process_sequences(self, batch_size):
+         self.tokenize_sequences()
+
+         lengths = [(len(seq['input_ids'][0]), i) for i, seq in enumerate(self.tokenized_sequences)]
+         lengths.sort()
+
+         batches = []
+         sequence_batches = []
+         current_batch = []
+         current_sequence_batch = []
+         current_length = 0
+
+         for length, idx in tqdm(lengths):
+             if current_length + length > self.max_sequence_length or len(current_batch) == batch_size:
+                 if current_batch:
+                     batches.append([self.tokenized_sequences[i] for i in current_batch])
+                     sequence_batches.append([self.original_sequences[i] for i in current_batch])
+                 current_batch = [idx]
+                 current_sequence_batch = [self.original_sequences[idx]]
+                 current_length = length
+             else:
+                 current_batch.append(idx)
+                 current_sequence_batch.append(self.original_sequences[idx])
+                 current_length += length
+
+         if current_batch:
+             batches.append([self.tokenized_sequences[i] for i in current_batch])
+             sequence_batches.append([self.original_sequences[i] for i in current_batch])
+
+         token_batch_fn = TokenizeBatch(self.tokenizer)
+         processed_batches = [token_batch_fn(batch) for batch in tqdm(batches)]
+
+         dataset = Dataset.from_dict({
+             'attention_mask': [batch['attention_mask'] for batch in processed_batches],
+             'input_ids': [batch['input_ids'] for batch in processed_batches],
+             'labels': sequence_batches
+         })
+
+         return dataset
+
+ class DynamicBatchingDataset(torch.utils.data.Dataset):
+     """
+     Map-style dataset wrapping the dynamically batched HuggingFace Dataset built above. Each row
+     is already a fully processed batch, so indexing returns a whole pre-built batch, not a single example.
+     """
+
+     def __init__(self, dataset_dict):
+         print('Initializing dataset...')
+         self.dataset_dict = {
+             'attention_mask': [torch.tensor(item) for item in dataset_dict['attention_mask']],
+             'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']],
+             'labels': dataset_dict['labels']  # store the original sequences as they are
+         }
+
+     def __len__(self):
+         return len(self.dataset_dict['attention_mask'])
+
+     def __getitem__(self, idx):
+         if isinstance(idx, int):
+             return {
+                 'attention_mask': self.dataset_dict['attention_mask'][idx],
+                 'input_ids': self.dataset_dict['input_ids'][idx],
+                 'labels': self.dataset_dict['labels'][idx]
+             }
+         elif isinstance(idx, list):
+             return {
+                 'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx],
+                 'input_ids': [self.dataset_dict['input_ids'][i] for i in idx],
+                 'labels': [self.dataset_dict['labels'][i] for i in idx]
+             }
+         else:
+             raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")
+
+     @staticmethod
+     def collate_fn(batch, verbose=False):
+         item = batch[0]  # DataLoader uses batch_size=1; each item is already a pre-padded batch
+         return {
+             'input_ids': item['input_ids'],
+             'attention_mask': item['attention_mask'],
+             'labels': item['labels']
+         }
+
+ def standalone_tokenize_function(sequence):
+     global global_tokenizer
+     try:
+         tokens = global_tokenizer(sequence)
+         # The tokenizer already returns lists of integers, so we just need to wrap them in another list
+         # to match the expected format [batch_size, sequence_length]
+         return {
+             'input_ids': [tokens['input_ids']],
+             'attention_mask': [tokens['attention_mask']]
+         }
+     except Exception as e:
+         print(f"Error tokenizing sequence '{sequence}': {e}")
+         return None
+
+ class TokenizeBatch:
+     def __init__(self, tokenizer):
+         self.pad_token_id = tokenizer.pad_token_id
+
+     def __call__(self, batches):
+         data_tokens = [torch.tensor(batch['input_ids'][0]) for batch in batches]
+         data_tokens_padded = torch.nn.utils.rnn.pad_sequence(data_tokens, batch_first=True, padding_value=self.pad_token_id)
+         attention_masks = (data_tokens_padded != self.pad_token_id).long()
+
+         return {
+             'input_ids': data_tokens_padded,
+             'attention_mask': attention_masks,
+         }
+
+ class PretrainSequenceDataModule(pl.LightningDataModule):
+     def __init__(self,
+                  tokenizer,
+                  input_dataset_path,
+                  output_dataset_path,
+                  num_workers,
+                  batch_size,
+                  max_sequence_length=512):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.input_path = input_dataset_path
+         self.output_path = output_dataset_path
+         self.num_workers = num_workers
+         self.batch_size = batch_size
+         self.max_sequence_length = max_sequence_length
+
+     def prepare_data(self):
+         if not os.path.exists(self.output_path):
+             print("Loading text files")
+             with open(f"{self.input_path}/train.txt", 'r') as f:
+                 train_sequences = [line.strip() for line in f if line.strip()]
+             with open(f"{self.input_path}/val.txt", 'r') as f:
+                 val_sequences = [line.strip() for line in f if line.strip()]
+
+             print("Processing training data")
+             train_dataset = SequenceDataset(train_sequences,
+                                             self.tokenizer,
+                                             self.max_sequence_length)
+             print("Processing validation data")
+             val_dataset = SequenceDataset(val_sequences,
+                                           self.tokenizer,
+                                           self.max_sequence_length)
+
+             processed_train = train_dataset.process_sequences(self.batch_size)
+             processed_val = val_dataset.process_sequences(self.batch_size)
+
+             print("Combining datasets")
+             combined_dataset = DatasetDict({
+                 'train': processed_train,
+                 'val': processed_val,
+             })
+
+             print(f"Saving dataset to {self.output_path}")
+             combined_dataset.save_to_disk(self.output_path)
+
+     def setup(self, stage: str):
+         print("Loading processed dataset")
+         dataset = load_from_disk(self.output_path)
+         self.train_dataset = DynamicBatchingDataset(dataset['train'])
+         self.val_dataset = DynamicBatchingDataset(dataset['val'])
+
+     def train_dataloader(self):
+         print("Creating training dataloader")
+         return DataLoader(self.train_dataset,
+                           batch_size=1,
+                           shuffle=False,
+                           num_workers=self.num_workers,
+                           collate_fn=DynamicBatchingDataset.collate_fn,
+                           pin_memory=True)
+
+     def val_dataloader(self):
+         print("Creating validation dataloader")
+         return DataLoader(self.val_dataset,
+                           batch_size=1,
+                           shuffle=False,
+                           num_workers=self.num_workers,
+                           collate_fn=DynamicBatchingDataset.collate_fn,
+                           pin_memory=True)
+
+
+ if __name__ == '__main__':
+     tokenizer = SMILES_SPE_Tokenizer('/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_vocab.txt',
+                                      '/home/st512/peptune/scripts/peptide-mdlm-mcts/tokenizer/new_splits.txt')
+     dm = PretrainSequenceDataModule(
+         tokenizer=tokenizer,
+         input_dataset_path='/home/yz927/projects/peptune/tokens/11M_smiles',
+         output_dataset_path='/home/yz927/projects/peptune/tokenized/11M_smiles_old_tokenizer_no_limit',
+         num_workers=8,
+         batch_size=2000,
+         max_sequence_length=16*1000,
+     )
+     dm.prepare_data()
+     dm.setup('fit')
+     dm.train_dataloader()
+     dm.val_dataloader()
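
Note on how these pieces fit together: the effective batch size comes from SequenceDataset.process_sequences, not from the DataLoader. Each row of the saved dataset is already a padded batch, the loaders use batch_size=1, and collate_fn simply unwraps that single row. Below is a minimal, self-contained sketch of that mechanic using toy pre-padded rows in place of real tokenizer output; the toy values, shapes, and names such as ToyPreBatchedDataset and unwrap_collate are illustrative only and not part of this commit.

import torch
from torch.utils.data import DataLoader, Dataset


class ToyPreBatchedDataset(Dataset):
    """Each __getitem__ returns a whole pre-padded batch, mirroring DynamicBatchingDataset."""

    def __init__(self, rows):
        self.rows = rows

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        return self.rows[idx]


def unwrap_collate(batch):
    # batch_size=1, so the single element is already a full batch
    return batch[0]


rows = [
    {
        'input_ids': torch.tensor([[5, 6, 7, 0], [5, 8, 0, 0]]),       # two sequences padded together
        'attention_mask': torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]),
        'labels': ['CCO', 'CC'],                                        # original sequences
    },
]

loader = DataLoader(ToyPreBatchedDataset(rows), batch_size=1, shuffle=False, collate_fn=unwrap_collate)
for batch in loader:
    print(batch['input_ids'].shape)  # torch.Size([2, 4]): two sequences in one pre-built batch

batch_size=1 plus a custom collate_fn is what lets these pre-built, variable-size batches pass through the DataLoader unchanged.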
data_preprocessing/data_split.py ADDED
@@ -0,0 +1,101 @@
+ from rdkit import Chem
+ from rdkit.Chem import AllChem
+ from rdkit import DataStructs
+ import numpy as np
+ from sklearn.cluster import MiniBatchKMeans
+ from collections import defaultdict
+ from tqdm import tqdm
+ import selfies as sf
+ from multiprocessing import Pool, cpu_count
+ from functools import partial
+ def generate_fingerprint_batch_selfies(selfies_batch):
+     fps = []
+     valid_selfies = []
+
+     for selfies in tqdm(selfies_batch, desc="Generating fingerprints", leave=False):
+         try:
+             # Convert SELFIES to SMILES then to molecule
+             smiles = sf.decoder(selfies)
+             mol = Chem.MolFromSmiles(smiles)
+             if mol is not None:
+                 fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, 2048)
+                 arr = np.zeros((1,))
+                 DataStructs.ConvertToNumpyArray(fp, arr)
+                 fps.append(arr)
+                 valid_selfies.append(selfies)
+         except Exception:
+             continue
+
+     return np.array(fps), valid_selfies
+
+ def process_batch(batch, n_clusters, seed):
+     fps, valid_selfies = generate_fingerprint_batch_selfies(batch)
+     if len(fps) > 0:
+         clusterer = MiniBatchKMeans(n_clusters=min(n_clusters, len(fps)), random_state=seed)  # guard against a final batch smaller than n_clusters
+         clusterer.fit(fps)
+         labels = clusterer.predict(fps)
+         return list(zip(labels, valid_selfies))
+     return []
+
+ def parallel_clustering_split_selfies(selfies_list, batch_size=10000, n_clusters=1000, train_ratio=0.9, seed=42):
+     np.random.seed(seed)
+
+     # Create batches
+     batches = [selfies_list[i:i + batch_size]
+                for i in range(0, len(selfies_list), batch_size)]
+
+     # Initialize parallel processing
+     n_cores = min(12, cpu_count())
+     process_batch_partial = partial(process_batch, n_clusters=n_clusters, seed=seed)
+
+     cluster_assignments = defaultdict(list)
+     with Pool(n_cores) as pool:
+         results = list(tqdm(
+             pool.imap(process_batch_partial, batches),
+             total=len(batches),
+             desc="Processing batches"
+         ))
+
+     # Combine results
+     for batch_results in results:
+         for label, selfies in batch_results:
+             cluster_assignments[label].append(selfies)
+
+     # Split into train/val
+     clusters = list(cluster_assignments.values())
+     np.random.shuffle(clusters)
+
+     train_selfies = []
+     val_selfies = []
+     total_mols = sum(len(cluster) for cluster in clusters)
+
+     for cluster in tqdm(clusters, desc="Splitting clusters"):
+         if len(train_selfies) / total_mols < train_ratio:
+             train_selfies.extend(cluster)
+         else:
+             val_selfies.extend(cluster)
+
+     print(f"Final splits: Train={len(train_selfies)}, Validation={len(val_selfies)}")
+     return train_selfies, val_selfies
+
+ try:
+     with open('/home/yz927/projects/peptune/tokens/filtered_peptides_selfies.txt', 'r') as f:
+         selfies_list = [line.strip() for line in f if line.strip()]
+     print(f"Loaded {len(selfies_list)} selfies sequences from file")
+ except FileNotFoundError:
+     raise FileNotFoundError("Could not find the input SELFIES file at /home/yz927/projects/peptune/tokens/filtered_peptides_selfies.txt")
+ except Exception as e:
+     raise RuntimeError(f"Error reading file: {e}") from e
+
+ train_selfies, val_selfies = parallel_clustering_split_selfies(
+     selfies_list,
+     batch_size=10000,
+     n_clusters=1000,
+     train_ratio=0.8
+ )
+ with open('/home/yz927/projects/peptune/tokens/11M_selfies/train_selfies.txt', 'w') as f:
+     for line in train_selfies:
+         f.write(f"{line}\n")
+ with open('/home/yz927/projects/peptune/tokens/11M_selfies/val_selfies.txt', 'w') as f:
+     for line in val_selfies:
+         f.write(f"{line}\n")
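
The split is done at the cluster level so that structurally similar molecules (as measured by Morgan fingerprints) land on the same side of the train/validation boundary, which limits near-duplicate leakage between splits. Below is a compact, self-contained sketch of that idea on a toy SELFIES list; the toy molecules, the n_clusters=3 setting, and the 0.8 ratio are illustrative stand-ins, while the RDKit and scikit-learn calls mirror the ones used above.

import numpy as np
import selfies as sf
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from sklearn.cluster import MiniBatchKMeans

# Toy SELFIES strings (hypothetical examples, not taken from the real 11M-sequence file)
toy_selfies = [sf.encoder(s) for s in ["CCO", "CCN", "CCC", "c1ccccc1", "c1ccccc1O", "CC(=O)O"]]

# Morgan fingerprints, same calls as generate_fingerprint_batch_selfies above
fps = []
for s in toy_selfies:
    mol = Chem.MolFromSmiles(sf.decoder(s))
    arr = np.zeros((1,))
    DataStructs.ConvertToNumpyArray(AllChem.GetMorganFingerprintAsBitVect(mol, 2, 2048), arr)
    fps.append(arr)

# Cluster, then assign whole clusters to one side so near-duplicates never straddle the split
labels = MiniBatchKMeans(n_clusters=3, random_state=42).fit_predict(np.array(fps))
train, val = [], []
for c in sorted(set(labels)):
    cluster = [s for s, lab in zip(toy_selfies, labels) if lab == c]
    (train if len(train) / len(toy_selfies) < 0.8 else val).extend(cluster)
print(f"train={len(train)}, val={len(val)}")

The real script applies the same per-batch fingerprinting and clustering inside a multiprocessing Pool over 10,000-sequence chunks, then shuffles the resulting clusters and fills the training split up to the requested ratio.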