sungo-ganpare committed on
Commit
d4575dc
·
1 Parent(s): 21b4fcb

オーディオファイルの前処理と文字起こし機能を改善し、エラーハンドリングを強化

Browse files
Files changed (1) hide show
  1. app.py +194 -137
app.py CHANGED
@@ -69,27 +69,28 @@ def get_audio_segment(audio_path, start_second, end_second):
69
  print(f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}")
70
  return None
71
 
72
- @spaces.GPU
73
- def get_transcripts_and_raw_times(audio_path, session_dir):
74
- if not audio_path:
75
- gr.Error("No audio file path provided for transcription.", duration=None)
76
- return [], [], [], None, gr.DownloadButton(visible=False)
77
 
78
- vis_data = [["N/A", "N/A", "Processing failed"]]
79
- raw_times_data = [[0.0, 0.0]]
80
- char_vis_data = []
81
- processed_audio_path = None
82
- original_path_name = Path(audio_path).name
83
- audio_name = Path(audio_path).stem
84
 
 
 
 
85
  try:
 
 
 
86
  try:
87
  gr.Info(f"Loading audio: {original_path_name}", duration=2)
88
  audio = AudioSegment.from_file(audio_path)
89
  duration_sec = audio.duration_seconds
90
  except Exception as load_e:
91
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
92
- return [["Error", "Error", "Load failed"]], [[0.0, 0.0]], [], audio_path, gr.DownloadButton(visible=False)
93
 
94
  resampled = False
95
  mono = False
@@ -101,7 +102,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
101
  resampled = True
102
  except Exception as resample_e:
103
  gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
104
- return [["Error", "Error", "Resample failed"]], [[0.0, 0.0]], [], audio_path, gr.DownloadButton(visible=False)
105
 
106
  if audio.channels == 2:
107
  try:
@@ -109,11 +110,12 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
109
  mono = True
110
  except Exception as mono_e:
111
  gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
112
- return [["Error", "Error", "Mono conversion failed"]], [[0.0, 0.0]], [], audio_path, gr.DownloadButton(visible=False)
113
  elif audio.channels > 2:
114
  gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
115
- return [["Error", "Error", f"{audio.channels}-channel audio not supported"]], [[0.0, 0.0]], [], audio_path, gr.DownloadButton(visible=False)
116
 
 
117
  if resampled or mono:
118
  try:
119
  processed_audio_path = Path(session_dir, f"{audio_name}_resampled.wav")
@@ -124,134 +126,189 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
124
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
125
  if processed_audio_path and os.path.exists(processed_audio_path):
126
  os.remove(processed_audio_path)
127
- return [["Error", "Error", "Export failed"]], [[0.0, 0.0]], [], audio_path, gr.DownloadButton(visible=False)
128
  else:
129
  transcribe_path = audio_path
130
  info_path_name = original_path_name
131
 
132
- long_audio_settings_applied = False
133
- try:
134
- model.to(device)
135
- model.to(torch.float32)
136
- gr.Info(f"Transcribing {info_path_name} on {device}...", duration=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- if duration_sec > 480:
139
- try:
140
- gr.Info("Audio longer than 8 minutes. Applying optimized settings for long transcription.", duration=3)
141
- print("Applying long audio settings: Local Attention and Chunking.")
142
- model.change_attention_model("rel_pos_local_attn", [256,256])
143
- model.change_subsampling_conv_chunking_factor(1)
144
- long_audio_settings_applied = True
145
- except Exception as setting_e:
146
- gr.Warning(f"Could not apply long audio settings: {setting_e}", duration=5)
147
- print(f"Warning: Failed to apply long audio settings: {setting_e}")
148
-
149
- model.to(torch.bfloat16)
150
- output = model.transcribe([transcribe_path], timestamps=True)
151
-
152
- if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
153
- gr.Error("Transcription failed or produced unexpected output format.", duration=None)
154
- return [["Error", "Error", "Transcription Format Issue"]], [[0.0, 0.0]], [], audio_path, gr.DownloadButton(visible=False)
155
-
156
- segment_timestamps = output[0].timestamp['segment']
157
- csv_headers = ["Start (s)", "End (s)", "Segment"]
158
- vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
159
- raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]
160
-
161
- char_timestamps_raw = output[0].timestamp.get("char", [])
162
- if not isinstance(char_timestamps_raw, list):
163
- print(f"Warning: char_timestamps_raw is not a list, but {type(char_timestamps_raw)}. Defaulting to empty.")
164
- char_timestamps_raw = []
165
- char_vis_data = [
166
- [f"{c['start']:.2f}", f"{c['end']:.2f}", c["char"]]
167
- for c in char_timestamps_raw if isinstance(c, dict) and 'start' in c and 'end' in c and 'char' in c
168
- ]
169
-
170
- word_timestamps_raw = output[0].timestamp.get("word", [])
171
- word_vis_data = [
172
- [f"{w['start']:.2f}", f"{w['end']:.2f}", w["word"]]
173
- for w in word_timestamps_raw if isinstance(w, dict) and 'start' in w and 'end' in w and 'word' in w
174
- ]
175
-
176
- button_update = gr.DownloadButton(visible=False)
177
- srt_file_path = None
178
- vtt_file_path = None
179
- json_file_path = None
180
- lrc_file_path = None
181
  try:
182
- csv_file_path = Path(session_dir, f"transcription_{audio_name}.csv")
183
- with open(csv_file_path, 'w', newline='', encoding='utf-8') as f:
184
- writer = csv.writer(f)
185
- writer.writerow(csv_headers)
186
- writer.writerows(vis_data)
187
- print(f"CSV transcript saved to temporary file: {csv_file_path}")
188
- button_update = gr.DownloadButton(value=csv_file_path.as_posix(), visible=True)
189
-
190
- srt_file_path = Path(session_dir, f"transcription_{audio_name}.srt")
191
- vtt_file_path = Path(session_dir, f"transcription_{audio_name}.vtt")
192
- json_file_path = Path(session_dir, f"transcription_{audio_name}.json")
193
- write_srt(vis_data, srt_file_path)
194
- write_vtt(vis_data, word_vis_data, vtt_file_path)
195
- write_json(vis_data, word_vis_data, json_file_path)
196
- print(f"SRT, VTT, JSON transcript saved to temporary files: {srt_file_path}, {vtt_file_path}, {json_file_path}")
197
-
198
- lrc_file_path = Path(session_dir, f"transcription_{audio_name}.lrc")
199
- write_lrc(vis_data, lrc_file_path)
200
- print(f"LRC transcript saved to temporary file: {lrc_file_path}")
201
- except Exception as csv_e:
202
- gr.Error(f"Failed to create transcript files: {csv_e}", duration=None)
203
- print(f"Error writing transcript files: {csv_e}")
204
-
205
- gr.Info("Transcription complete.", duration=2)
206
- return (
207
- vis_data,
208
- raw_times_data,
209
- word_vis_data,
210
- audio_path,
211
- gr.DownloadButton(value=csv_file_path.as_posix(), visible=True),
212
- gr.DownloadButton(value=srt_file_path.as_posix(), visible=True),
213
- gr.DownloadButton(value=vtt_file_path.as_posix(), visible=True),
214
- gr.DownloadButton(value=json_file_path.as_posix(), visible=True),
215
- gr.DownloadButton(value=lrc_file_path.as_posix(), visible=True)
216
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
- except torch.cuda.OutOfMemoryError as e:
219
- error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
220
- print(f"CUDA OutOfMemoryError: {e}")
221
- gr.Error(error_msg, duration=None)
222
- return [["OOM", "OOM", error_msg]], [[0.0, 0.0]], [], audio_path, gr.DownloadButton(visible=False)
223
-
224
- except FileNotFoundError:
225
- error_msg = f"Audio file for transcription not found: {Path(transcribe_path).name}."
226
- print(f"Error: Transcribe audio file not found at path: {transcribe_path}")
227
- gr.Error(error_msg, duration=None)
228
- return [["Error", "Error", "File not found for transcription"]], [[0.0, 0.0]], [], audio_path, gr.DownloadButton(visible=False)
229
-
230
- except Exception as e:
231
- error_msg = f"Transcription failed: {e}"
232
- print(f"Error during transcription processing: {e}")
233
- gr.Error(error_msg, duration=None)
234
- return [["Error", "Error", error_msg]], [[0.0, 0.0]], [], audio_path, gr.DownloadButton(visible=False)
235
- finally:
236
- try:
237
- if long_audio_settings_applied:
238
- try:
239
- print("Reverting long audio settings.")
240
- model.change_attention_model("rel_pos")
241
- model.change_subsampling_conv_chunking_factor(-1)
242
- except Exception as revert_e:
243
- print(f"Warning: Failed to revert long audio settings: {revert_e}")
244
- gr.Warning(f"Issue reverting model settings after long transcription: {revert_e}", duration=5)
245
-
246
- if 'model' in locals() and hasattr(model, 'cpu'):
247
- if device == 'cuda':
248
- model.cpu()
249
- gc.collect()
250
- if device == 'cuda':
251
- torch.cuda.empty_cache()
252
- except Exception as cleanup_e:
253
- print(f"Error during model cleanup: {cleanup_e}")
254
- gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
255
  finally:
256
  if processed_audio_path and os.path.exists(processed_audio_path):
257
  try:
@@ -489,4 +546,4 @@ with gr.Blocks(theme=nvidia_theme) as demo:
489
  if __name__ == "__main__":
490
  print("Launching Gradio Demo...")
491
  demo.queue()
492
- demo.launch()
 
69
  print(f"Error clipping audio {audio_path} from {start_second}s to {end_second}s: {e}")
70
  return None
71
 
72
+ def preprocess_audio(audio_path, session_dir):
73
+ """
74
+ オーディオファイルの前処理(リサンプリング、モノラル変換)を行う。
 
 
75
 
76
+ Args:
77
+ audio_path (str): 入力オーディオファイルのパス。
78
+ session_dir (str): セッションディレクトリのパス。
 
 
 
79
 
80
+ Returns:
81
+ tuple: (processed_path, info_path_name, duration_sec) のタプル、または None(処理に失敗した場合)。
82
+ """
83
  try:
84
+ original_path_name = Path(audio_path).name
85
+ audio_name = Path(audio_path).stem
86
+
87
  try:
88
  gr.Info(f"Loading audio: {original_path_name}", duration=2)
89
  audio = AudioSegment.from_file(audio_path)
90
  duration_sec = audio.duration_seconds
91
  except Exception as load_e:
92
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
93
+ return None, None, None
94
 
95
  resampled = False
96
  mono = False
 
102
  resampled = True
103
  except Exception as resample_e:
104
  gr.Error(f"Failed to resample audio: {resample_e}", duration=None)
105
+ return None, None, None
106
 
107
  if audio.channels == 2:
108
  try:
 
110
  mono = True
111
  except Exception as mono_e:
112
  gr.Error(f"Failed to convert audio to mono: {mono_e}", duration=None)
113
+ return None, None, None
114
  elif audio.channels > 2:
115
  gr.Error(f"Audio has {audio.channels} channels. Only mono (1) or stereo (2) supported.", duration=None)
116
+ return None, None, None
117
 
118
+ processed_audio_path = None
119
  if resampled or mono:
120
  try:
121
  processed_audio_path = Path(session_dir, f"{audio_name}_resampled.wav")
 
126
  gr.Error(f"Failed to export processed audio: {export_e}", duration=None)
127
  if processed_audio_path and os.path.exists(processed_audio_path):
128
  os.remove(processed_audio_path)
129
+ return None, None, None
130
  else:
131
  transcribe_path = audio_path
132
  info_path_name = original_path_name
133
 
134
+ return transcribe_path, info_path_name, duration_sec
135
+ except Exception as e:
136
+ gr.Error(f"Audio preprocessing failed: {e}", duration=None)
137
+ return None, None, None
138
+
139
def transcribe_audio(transcribe_path, model, duration_sec, device):
    """
    Transcribe an audio file and collect segment- and word-level timestamps.

    Args:
        transcribe_path (str): Path of the audio file to transcribe.
        model (ASRModel): ASR model used for the transcription.
        duration_sec (float): Length of the audio in seconds.
        device (str): Device to run on ('cuda' or 'cpu').

    Returns:
        tuple | None: ``(vis_data, raw_times_data, word_vis_data)`` on success,
        where ``vis_data`` holds ``[start, end, segment]`` rows with the times
        formatted to two decimals, ``raw_times_data`` holds the unformatted
        ``[start, end]`` pairs, and ``word_vis_data`` holds
        ``[start, end, word]`` rows. ``None`` is returned on any failure so
        that a caller's ``if not result`` check actually detects the failure
        (a tuple of three ``None`` values would be truthy).
    """
    # NOTE(review): gr.Error(...) is constructed but never raised in this
    # function; confirm that this is enough to surface the message in the UI.
    long_audio_settings_applied = False
    try:
        model.to(device)
        model.to(torch.float32)
        gr.Info(f"Transcribing on {device}...", duration=2)

        # Audio longer than 8 minutes (480 s) switches the model to local
        # attention with subsampling chunking; failing to do so is not fatal.
        if duration_sec > 480:
            try:
                gr.Info("Audio longer than 8 minutes. Applying optimized settings for long transcription.", duration=3)
                print("Applying long audio settings: Local Attention and Chunking.")
                model.change_attention_model("rel_pos_local_attn", [256,256])
                model.change_subsampling_conv_chunking_factor(1)
                long_audio_settings_applied = True
            except Exception as setting_e:
                gr.Warning(f"Could not apply long audio settings: {setting_e}", duration=5)
                print(f"Warning: Failed to apply long audio settings: {setting_e}")

        model.to(torch.bfloat16)
        output = model.transcribe([transcribe_path], timestamps=True)

        # Reject anything that does not expose segment timestamps on the
        # first hypothesis.
        if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
            gr.Error("Transcription failed or produced unexpected output format.", duration=None)
            return None

        segment_timestamps = output[0].timestamp['segment']
        vis_data = [[f"{ts['start']:.2f}", f"{ts['end']:.2f}", ts['segment']] for ts in segment_timestamps]
        raw_times_data = [[ts['start'], ts['end']] for ts in segment_timestamps]

        # Word entries are optional; malformed entries are silently skipped.
        word_timestamps_raw = output[0].timestamp.get("word", [])
        word_vis_data = [
            [f"{w['start']:.2f}", f"{w['end']:.2f}", w["word"]]
            for w in word_timestamps_raw if isinstance(w, dict) and 'start' in w and 'end' in w and 'word' in w
        ]

        gr.Info("Transcription complete.", duration=2)
        return vis_data, raw_times_data, word_vis_data

    except torch.cuda.OutOfMemoryError as e:
        error_msg = 'CUDA out of memory. Please try a shorter audio or reduce GPU load.'
        print(f"CUDA OutOfMemoryError: {e}")
        gr.Error(error_msg, duration=None)
        return None

    except Exception as e:
        error_msg = f"Transcription failed: {e}"
        print(f"Error during transcription processing: {e}")
        gr.Error(error_msg, duration=None)
        return None

    finally:
        # Always restore the model configuration and release GPU memory,
        # whether transcription succeeded or not.
        try:
            if long_audio_settings_applied:
                try:
                    print("Reverting long audio settings.")
                    model.change_attention_model("rel_pos")
                    model.change_subsampling_conv_chunking_factor(-1)
                except Exception as revert_e:
                    print(f"Warning: Failed to revert long audio settings: {revert_e}")
                    gr.Warning(f"Issue reverting model settings after long transcription: {revert_e}", duration=5)

            if device == 'cuda':
                model.cpu()
            gc.collect()
            if device == 'cuda':
                torch.cuda.empty_cache()
        except Exception as cleanup_e:
            print(f"Error during model cleanup: {cleanup_e}")
            gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
220
+
221
def save_transcripts(session_dir, audio_name, vis_data, word_vis_data):
    """
    Write the transcription to CSV, SRT, VTT, JSON and LRC files.

    Args:
        session_dir (str): Directory of the current session; files go here.
        audio_name (str): Name of the audio file, used in the output names.
        vis_data (list): Display rows of the transcription (start, end, segment).
        word_vis_data (list): Word-level timestamp rows.

    Returns:
        tuple: Five download-button updates in the order CSV, SRT, VTT, JSON,
        LRC. If anything fails while writing, all five are hidden buttons.
    """
    try:
        # One output path per format, all following the same naming scheme.
        targets = {
            ext: Path(session_dir, f"transcription_{audio_name}.{ext}")
            for ext in ("csv", "srt", "vtt", "json", "lrc")
        }
        csv_target = targets["csv"]
        srt_target = targets["srt"]
        vtt_target = targets["vtt"]
        json_target = targets["json"]
        lrc_target = targets["lrc"]

        with open(csv_target, 'w', newline='', encoding='utf-8') as handle:
            sheet = csv.writer(handle)
            sheet.writerow(["Start (s)", "End (s)", "Segment"])
            sheet.writerows(vis_data)
        print(f"CSV transcript saved to temporary file: {csv_target}")

        write_srt(vis_data, srt_target)
        write_vtt(vis_data, word_vis_data, vtt_target)
        write_json(vis_data, word_vis_data, json_target)
        print(f"SRT, VTT, JSON transcript saved to temporary files: {srt_target}, {vtt_target}, {json_target}")

        write_lrc(vis_data, lrc_target)
        print(f"LRC transcript saved to temporary file: {lrc_target}")

        return tuple(
            gr.DownloadButton(value=target.as_posix(), visible=True)
            for target in (csv_target, srt_target, vtt_target, json_target, lrc_target)
        )
    except Exception as e:
        gr.Error(f"Failed to create transcript files: {e}", duration=None)
        print(f"Error writing transcript files: {e}")
        # A single hidden button instance is reused for all five outputs.
        hidden_button = gr.DownloadButton(visible=False)
        return (hidden_button,) * 5
266
+
267
+ @spaces.GPU
268
+ def get_transcripts_and_raw_times(audio_path, session_dir):
269
+ """
270
+ オーディオファイルを処理し、文字起こし結果を生成する。
271
+
272
+ Args:
273
+ audio_path (str): 入力オーディオファイルのパス。
274
+ session_dir (str): セッションディレクトリのパス。
275
+
276
+ Returns:
277
+ tuple: 文字起こし結果と関連データを含むタプル。
278
+ """
279
+ if not audio_path:
280
+ gr.Error("No audio file path provided for transcription.", duration=None)
281
+ return [], [], [], None, gr.DownloadButton(visible=False), gr.DownloadButton(visible=False), gr.DownloadButton(visible=False), gr.DownloadButton(visible=False), gr.DownloadButton(visible=False)
282
+
283
+ audio_name = Path(audio_path).stem
284
+ processed_audio_path = None
285
+
286
+ try:
287
+ # オーディオの前処理
288
+ transcribe_path, info_path_name, duration_sec = preprocess_audio(audio_path, session_dir)
289
+ if not transcribe_path or not duration_sec:
290
+ return [], [], [], audio_path, gr.DownloadButton(visible=False), gr.DownloadButton(visible=False), gr.DownloadButton(visible=False), gr.DownloadButton(visible=False), gr.DownloadButton(visible=False)
291
+
292
+ processed_audio_path = transcribe_path if transcribe_path != audio_path else None
293
+
294
+ # 文字起こしの実行
295
+ result = transcribe_audio(transcribe_path, model, duration_sec, device)
296
+ if not result:
297
+ return [], [], [], audio_path, gr.DownloadButton(visible=False), gr.DownloadButton(visible=False), gr.DownloadButton(visible=False), gr.DownloadButton(visible=False), gr.DownloadButton(visible=False)
298
+
299
+ vis_data, raw_times_data, word_vis_data = result
300
+
301
+ # ファイルの保存
302
+ button_updates = save_transcripts(session_dir, audio_name, vis_data, word_vis_data)
303
+
304
+ return (
305
+ vis_data,
306
+ raw_times_data,
307
+ word_vis_data,
308
+ audio_path,
309
+ *button_updates
310
+ )
311
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  finally:
313
  if processed_audio_path and os.path.exists(processed_audio_path):
314
  try:
 
546
  if __name__ == "__main__":
547
  print("Launching Gradio Demo...")
548
  demo.queue()
549
+ demo.launch()