xieli committed
Commit 781d823 · 1 Parent(s): 245708d

feat: support whisper asr


feat: default enable auto transcribe
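For context, the flag is registered as store_true but its default is then forced to True, so auto transcription stays on unless the code is changed. A minimal standalone sketch of how those argparse defaults resolve (it mirrors the setup added to app.py below):

# Standalone sketch; mirrors the --enable-auto-transcribe setup in app.py below.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-auto-transcribe", action="store_true")
parser.set_defaults(enable_auto_transcribe=True)

print(parser.parse_args([]).enable_auto_transcribe)                            # True: enabled with no flag
print(parser.parse_args(["--enable-auto-transcribe"]).enable_auto_transcribe)  # True: passing the flag changes nothing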

Files changed (2)
  1. app.py +54 -1
  2. whisper_wrapper.py +75 -0
app.py CHANGED
@@ -28,6 +28,7 @@ from tokenizer import StepAudioTokenizer
 from tts import StepAudioTTS
 from model_loader import ModelSource
 from config.edit_config import get_supported_edit_types
+from whisper_wrapper import WhisperWrapper
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -36,12 +37,13 @@ logger = logging.getLogger(__name__)
 # Global variables for ZeroGPU-optimized loading
 encoder = None
 common_tts_engine = None
+whisper_asr = None
 args_global = None
 _model_lock = threading.Lock()  # Thread lock for model initialization
 
 def initialize_models():
     """Initialize models on first GPU call (ZeroGPU optimization: load inside GPU context)"""
-    global encoder, common_tts_engine, args_global
+    global encoder, common_tts_engine, whisper_asr, args_global
 
     # Fast path: check if already initialized (without lock)
     if common_tts_engine is not None:
@@ -87,6 +89,12 @@ def initialize_models():
         device_map=args_global.device_map,
     )
     logger.info("✓ StepCommonAudioTTS loaded")
+
+    # Initialize Whisper ASR (load outside GPU context, lighter model)
+    if whisper_asr is None:
+        whisper_asr = WhisperWrapper()
+        logger.info("✓ WhisperWrapper loaded")
+
     print("Models initialized inside GPU context.")
 
 if ZEROGPU_AVAILABLE:
@@ -178,6 +186,7 @@ class EditxTab:
         self.args = args
         self.edit_type_list = list(get_supported_edit_types().keys())
         self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+        self.enable_auto_transcribe = getattr(args, 'enable_auto_transcribe', False)
 
     def history_messages_to_show(self, messages):
         """Convert message history to gradio chatbot format"""
@@ -415,6 +424,14 @@ class EditxTab:
             outputs=self.edit_info,
         )
 
+        # Add audio transcription event only if enabled
+        if self.enable_auto_transcribe:
+            self.prompt_audio_input.change(
+                fn=self.transcribe_audio,
+                inputs=[self.prompt_audio_input, self.prompt_text_input],
+                outputs=self.prompt_text_input,
+            )
+
     def update_edit_info(self, category):
         """Update sub-task dropdown based on main task selection"""
         category_items = get_supported_edit_types()
@@ -422,6 +439,36 @@
         value = None if len(choices) == 0 else choices[0]
         return gr.Dropdown(label="Sub-task", choices=choices, value=value)
 
+    def transcribe_audio(self, audio_input, current_text):
+        """Transcribe audio using Whisper ASR when prompt text is empty"""
+        global whisper_asr
+
+        # Only transcribe if current text is empty
+        if current_text and current_text.strip():
+            return current_text  # Keep existing text
+
+        if not audio_input:
+            return ""  # No audio to transcribe
+
+        try:
+            # Initialize whisper if not already loaded
+            if whisper_asr is None:
+                if args_global is None:
+                    self.logger.error("Global args not set. Cannot initialize Whisper.")
+                    return ""
+
+                whisper_asr = WhisperWrapper()
+                self.logger.info("✓ WhisperWrapper initialized for ASR")
+
+            # Transcribe audio
+            transcribed_text = whisper_asr(audio_input)
+            self.logger.info(f"Audio transcribed: {transcribed_text}")
+            return transcribed_text
+
+        except Exception as e:
+            self.logger.error(f"Failed to transcribe audio: {e}")
+            return ""
+
 
 def launch_demo(args, editx_tab):
     """Launch the gradio demo"""
@@ -503,6 +550,12 @@ if __name__ == "__main__":
         default="cuda",
         help="Device mapping for model loading (default: cuda)"
     )
+    parser.add_argument(
+        "--enable-auto-transcribe",
+        action="store_true",
+        help="Enable automatic audio transcription when uploading audio files (default: enabled)"
+    )
+    parser.set_defaults(enable_auto_transcribe=True)
 
     args = parser.parse_args()
 
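Not part of the commit, but a minimal standalone sketch of the same wiring for reviewers who want to try it in isolation: a Gradio Audio.change event that fills the prompt textbox only when it is empty. The component names and the stub transcriber are hypothetical; in the app the handler is EditxTab.transcribe_audio, backed by WhisperWrapper.

# Standalone sketch of the auto-transcribe wiring; stub ASR and hypothetical component names.
import gradio as gr

def fake_transcribe(audio_path, current_text):
    # Mirrors transcribe_audio(): keep existing text, only fill the box when it is empty.
    if current_text and current_text.strip():
        return current_text
    if not audio_path:
        return ""
    return f"[transcript of {audio_path}]"  # the real handler calls whisper_asr(audio_path)

with gr.Blocks() as demo:
    prompt_audio = gr.Audio(type="filepath", label="Prompt audio")
    prompt_text = gr.Textbox(label="Prompt text")
    prompt_audio.change(fn=fake_transcribe, inputs=[prompt_audio, prompt_text], outputs=prompt_text)

demo.launch()

Uploading a clip fills the textbox once; anything the user has already typed is left untouched.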
 
whisper_wrapper.py ADDED
@@ -0,0 +1,75 @@
+import logging
+import torch
+import torchaudio
+from transformers import pipeline
+
+
+class WhisperWrapper:
+    """Simplified Whisper ASR wrapper"""
+
+    def __init__(self, model_id="openai/whisper-large-v3"):
+        """
+        Initialize WhisperWrapper
+
+        Args:
+            model_id: Whisper model ID, default uses openai/whisper-large-v3
+        """
+        self.logger = logging.getLogger(__name__)
+        self.model = None
+
+        try:
+            self.model = pipeline("automatic-speech-recognition", model=model_id)
+            self.logger.info(f"✓ Whisper model loaded successfully: {model_id}")
+        except Exception as e:
+            self.logger.error(f"❌ Failed to load Whisper model: {e}")
+            raise
+
+    def __call__(self, audio_input, sample_rate=16000):
+        """
+        Audio to text transcription
+
+        Args:
+            audio_input: Audio file path or audio tensor (tensors are assumed to be sampled at sample_rate Hz)
+
+        Returns:
+            Transcribed text
+        """
+        if self.model is None:
+            raise RuntimeError("Whisper model not loaded")
+
+        try:
+            # Load audio
+            if isinstance(audio_input, str):
+                # Audio file path
+                audio, audio_sr = torchaudio.load(audio_input)
+                audio = torchaudio.functional.resample(audio, audio_sr, 16000)
+                # Handle stereo to mono conversion (pipeline may not handle this)
+                if audio.shape[0] > 1:
+                    audio = audio.mean(dim=0, keepdim=True)  # Convert stereo to mono by averaging
+                # Convert to numpy and squeeze
+                audio = audio.squeeze(0).numpy()
+            elif isinstance(audio_input, torch.Tensor):
+                # Tensor input; resample from the caller-provided sample_rate
+                audio = audio_input.cpu()
+                audio = torchaudio.functional.resample(audio, sample_rate, 16000)
+                # Handle stereo to mono conversion
+                if audio.ndim > 1 and audio.shape[0] > 1:
+                    audio = audio.mean(dim=0, keepdim=True)
+                audio = audio.squeeze().numpy()
+            else:
+                raise ValueError(f"Unsupported audio input type: {type(audio_input)}")
+
+            # Transcribe
+            result = self.model(audio)
+            text = result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
+
+            self.logger.debug(f"Transcription result: {text}")
+            return text
+
+        except Exception as e:
+            self.logger.error(f"Audio transcription failed: {e}")
+            return ""
+
+    def is_available(self):
+        """Check if whisper model is available"""
+        return self.model is not None
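A quick way to smoke-test the new wrapper outside the app could look like the sketch below; sample.wav is a placeholder path, and constructing the wrapper downloads the openai/whisper-large-v3 weights through transformers.pipeline.

# Hypothetical smoke test; "sample.wav" is a placeholder audio file.
from whisper_wrapper import WhisperWrapper

asr = WhisperWrapper()            # loads openai/whisper-large-v3 via transformers.pipeline
if asr.is_available():
    text = asr("sample.wav")      # file-path input: loaded with torchaudio, resampled to 16 kHz mono
    print(text or "(empty transcription; check the audio file and the logs)")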