internalhell commited on
Commit
d45c3c7
·
verified ·
1 Parent(s): 7561bca

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +210 -6
README.md CHANGED
@@ -67,6 +67,209 @@ The following hyperparameters were used during training:
67
  - num_epochs: 3
68
  - mixed_precision_training: Native AMP
69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  ### Training results
71
 
72
  | Training Loss | Epoch | Step | Cer | Validation Loss | Ser | Wer |
@@ -84,12 +287,13 @@ The following hyperparameters were used during training:
84
  | 0.0609 | 1.6677 | 5500 | 4.4298 | 0.2077 | 59.3355 | 16.8546 |
85
  | 0.0721 | 1.8193 | 6000 | 4.3442 | 0.2060 | 58.6592 | 16.5527 |
86
  | 0.0681 | 1.9709 | 6500 | 4.3284 | 0.2038 | 58.1692 | 16.3575 |
87
- | 0.0322 | 2.1225 | 7000 | 0.2130 | 16.2809 | 4.2709 | 57.7771 |
88
- | 0.0277 | 2.2741 | 7500 | 0.2151 | 16.1067 | 4.2543 | 57.4733 |
89
- | 0.0249 | 2.4257 | 8000 | 0.2130 | 16.0741 | 4.2513 | 57.4635 |
90
- | 0.0234 | 2.5773 | 8500 | 0.2150 | 16.2600 | 4.2832 | 57.6693 |
91
- | 0.0264 | 2.7289 | 9000 | 0.2145 | 16.1160 | 4.2645 | 57.6301 |
92
- | 0.0268 | 2.8805 | 9500 | 0.2125 | 16.0405 | 4.2321 | 57.5223 |
 
93
 
94
 
95
  ### Framework versions
 
67
  - num_epochs: 3
68
  - mixed_precision_training: Native AMP
69
 
70
+ ### Training code
71
+
72
+ ```bash
73
+ pip install transformers evaluate soundfile
74
+ pip install -q jiwer tensorboard
75
+ pip install --upgrade datasets transformers
76
+ ```
77
+
78
+ ```python
79
+ import re
80
+ import json
81
+ from datasets import load_dataset, DatasetDict, Audio
82
+ from transformers import WhisperForConditionalGeneration, WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, Seq2SeqTrainingArguments, Seq2SeqTrainer
83
+ import os, numpy as np, torch, evaluate, jiwer
84
+ from huggingface_hub import login
85
+ from dataclasses import dataclass
86
+ from typing import Any, Dict, List, Union
87
+
88
+ login("***")
89
+
90
+
91
+ common_voice = DatasetDict()
92
+ common_voice["train"] = load_dataset("mozilla-foundation/common_voice_17_0", "ru", split="train")
93
+ common_voice["test"] = load_dataset("mozilla-foundation/common_voice_17_0", "ru", split="test")
94
+
95
+ common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])
96
+ common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))
97
+
98
+ feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
99
+ tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Russian", task="transcribe")
100
+ processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Russian", task="transcribe")
101
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
102
+ model.config.forced_decoder_ids = None
103
+ model.config.suppress_tokens = []
104
+ model.config.use_cache = False
105
+
106
+ def prepare_dataset(batch):
107
+ audio = batch["audio"]
108
+
109
+ batch["input_features"] = feature_extractor(
110
+ audio["array"],
111
+ sampling_rate=audio["sampling_rate"]
112
+ ).input_features[0]
113
+
114
+ batch["labels"] = tokenizer(batch["sentence"]).input_ids
115
+ return batch
116
+
117
+ common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=2 )
118
+
119
+ common_voice
120
+
121
+ wer_metric = evaluate.load("wer")
122
+ cer_metric = evaluate.load("cer")
123
+
124
+ def compute_metrics(pred):
125
+ pred_ids = pred.predictions
126
+ label_ids = pred.label_ids
127
+
128
+ label_ids[label_ids == -100] = tokenizer.pad_token_id
129
+
130
+ pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
131
+ label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
132
+
133
+ pairs = [(ref.strip(), hyp.strip()) for ref, hyp in zip(label_str, pred_str)]
134
+ pairs = [(ref, hyp) for ref, hyp in pairs if len(ref) > 0]
135
+
136
+ label_str, pred_str = zip(*pairs)
137
+
138
+ wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
139
+ cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)
140
+
141
+ ser = 100 * (sum(p.strip() != r.strip() for p, r in zip(pred_str, label_str)) / len(pred_str))
142
+
143
+ return {
144
+ "wer": wer,
145
+ "cer": cer,
146
+ "ser": ser
147
+ }
148
+
149
+ @dataclass
150
+ class DataCollatorSpeechSeq2SeqWithPadding:
151
+ processor: Any
152
+ decoder_start_token_id: int
153
+
154
+ def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
155
+ input_features = [{"input_features": f["input_features"]} for f in features]
156
+ batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
157
+
158
+ label_features = [{"input_ids": f["labels"]} for f in features]
159
+ labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
160
+
161
+ labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
162
+
163
+ if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
164
+ labels = labels[:, 1:]
165
+
166
+ batch["labels"] = labels
167
+ return batch
168
+
169
+ data_collator = DataCollatorSpeechSeq2SeqWithPadding(
170
+ processor=processor,
171
+ decoder_start_token_id=model.config.decoder_start_token_id,
172
+ )
173
+
174
+ training_args = Seq2SeqTrainingArguments(
175
+ output_dir="/content/drive/MyDrive/models/whisper_small_ru_model_trainer_3ep",
176
+ logging_dir="/content/drive/MyDrive/models/whisper_small_ru_model_trainer_3ep",
177
+ group_by_length=True,
178
+ per_device_train_batch_size=8,
179
+ per_device_eval_batch_size=4,
180
+ eval_strategy="steps",
181
+ logging_strategy="steps",
182
+ save_strategy="steps",
183
+ num_train_epochs=3,
184
+ generation_max_length=170,
185
+ logging_steps=25,
186
+ eval_steps=500,
187
+ save_steps=500,
188
+ fp16=True,
189
+ optim="adamw_torch_fused",
190
+ torch_compile=True,
191
+ gradient_checkpointing=True,
192
+ learning_rate=1e-5,
193
+ report_to=["tensorboard"],
194
+ load_best_model_at_end=True,
195
+ metric_for_best_model="wer",
196
+ greater_is_better=False,
197
+ push_to_hub=False,
198
+ predict_with_generate=True,
199
+ )
200
+
201
+ trainer = Seq2SeqTrainer(
202
+ args=training_args,
203
+ model=model,
204
+ train_dataset=common_voice["train"],
205
+ eval_dataset=common_voice["test"],
206
+ data_collator=data_collator,
207
+ compute_metrics=compute_metrics,
208
+ tokenizer=processor.feature_extractor,
209
+ )
210
+
211
+ trainer.train()
212
+
213
+ ```
214
+
215
+ ### Test result
216
+
217
+ ```python
218
+
219
+ import os
220
+ from transformers import (WhisperProcessor,
221
+ WhisperForConditionalGeneration,
222
+ pipeline)
223
+ import torch
224
+ import torchaudio
225
+ import librosa
226
+ import numpy as np
227
+
228
+ MODEL_HUG = "internalhell/whisper_small_ru_model_trainer_3ep"
229
+
230
+ processor = None
231
+ model = None
232
+ pipe = None
233
+
234
+ def get_model_pipe():
235
+ global model, processor, pipe
236
+ if model is None or processor is None:
237
+ processor = WhisperProcessor.from_pretrained(MODEL_HUG, language="russian")
238
+ model = WhisperForConditionalGeneration.from_pretrained(MODEL_HUG)
239
+
240
+ model.generation_config.forced_decoder_ids = None
241
+ forced_decoder_ids = processor.get_decoder_prompt_ids(language="ru", task="transcribe")
242
+ model.config.forced_decoder_ids = forced_decoder_ids
243
+
244
+ pipe = pipeline(
245
+ "automatic-speech-recognition",
246
+ model=model,
247
+ tokenizer=processor.tokenizer,
248
+ feature_extractor=processor.feature_extractor,
249
+ device=0 if torch.cuda.is_available() else -1,
250
+ )
251
+
252
+ return model
253
+
254
+ def recognize_audio_pipe(audio_path):
255
+ model = get_model_pipe()
256
+
257
+ waveform, sr = torchaudio.load(audio_path)
258
+ waveform = waveform.mean(dim=0, keepdim=True) # моно
259
+
260
+ if sr != 16000:
261
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
262
+ waveform = resampler(waveform)
263
+ sr = 16000
264
+
265
+ waveform_np = waveform.squeeze(0).numpy()
266
+ return pipe({"array": waveform_np, "sampling_rate": sr})["text"]
267
+
268
+ print(recognize_audio_pipe("test.wav")) # jast .wav only
269
+
270
+
271
+ ```
272
+
273
  ### Training results
274
 
275
  | Training Loss | Epoch | Step | Cer | Validation Loss | Ser | Wer |
 
287
  | 0.0609 | 1.6677 | 5500 | 4.4298 | 0.2077 | 59.3355 | 16.8546 |
288
  | 0.0721 | 1.8193 | 6000 | 4.3442 | 0.2060 | 58.6592 | 16.5527 |
289
  | 0.0681 | 1.9709 | 6500 | 4.3284 | 0.2038 | 58.1692 | 16.3575 |
290
+ | 0.0322 | 2.1225 | 7000 | 4.2709 | 0.2130 | 57.7771 | 16.2809 |
291
+ | 0.0277 | 2.2741 | 7500 | 4.2543 | 0.2151 | 57.4733 | 16.1067 |
292
+ | 0.0249 | 2.4257 | 8000 | 4.2513 | 0.2130 | 57.4635 | 16.0741 |
293
+ | 0.0234 | 2.5773 | 8500 | 4.2832 | 0.2150 | 57.6693 | 16.2600 |
294
+ | 0.0264 | 2.7289 | 9000 | 4.2645 | 0.2145 | 57.6301 | 16.1160 |
295
+ | 0.0268 | 2.8805 | 9500 | 4.2321 | 0.2125 | 57.5223 | 16.0405 |
296
+
297
 
298
 
299
  ### Framework versions