xieli committed
Commit · 781d823
Parent(s): 245708d

feat: support whisper asr
feat: default enable auto transcribe
- app.py +54 -1
- whisper_wrapper.py +75 -0
app.py CHANGED
@@ -28,6 +28,7 @@ from tokenizer import StepAudioTokenizer
 from tts import StepAudioTTS
 from model_loader import ModelSource
 from config.edit_config import get_supported_edit_types
+from whisper_wrapper import WhisperWrapper
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
@@ -36,12 +37,13 @@ logger = logging.getLogger(__name__)
 # Global variables for ZeroGPU-optimized loading
 encoder = None
 common_tts_engine = None
+whisper_asr = None
 args_global = None
 _model_lock = threading.Lock()  # Thread lock for model initialization
 
 def initialize_models():
     """Initialize models on first GPU call (ZeroGPU optimization: load inside GPU context)"""
-    global encoder, common_tts_engine, args_global
+    global encoder, common_tts_engine, whisper_asr, args_global
 
     # Fast path: check if already initialized (without lock)
     if common_tts_engine is not None:
@@ -87,6 +89,12 @@ def initialize_models():
         device_map=args_global.device_map,
     )
     logger.info("✅ StepCommonAudioTTS loaded")
+
+    # Initialize Whisper ASR (load outside GPU context, lighter model)
+    if whisper_asr is None:
+        whisper_asr = WhisperWrapper()
+        logger.info("✅ WhisperWrapper loaded")
+
     print("Models initialized inside GPU context.")
 
 if ZEROGPU_AVAILABLE:
@@ -178,6 +186,7 @@ class EditxTab:
         self.args = args
         self.edit_type_list = list(get_supported_edit_types().keys())
         self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+        self.enable_auto_transcribe = getattr(args, 'enable_auto_transcribe', False)
 
     def history_messages_to_show(self, messages):
         """Convert message history to gradio chatbot format"""
@@ -415,6 +424,14 @@ class EditxTab:
             outputs=self.edit_info,
         )
 
+        # Add audio transcription event only if enabled
+        if self.enable_auto_transcribe:
+            self.prompt_audio_input.change(
+                fn=self.transcribe_audio,
+                inputs=[self.prompt_audio_input, self.prompt_text_input],
+                outputs=self.prompt_text_input,
+            )
+
     def update_edit_info(self, category):
         """Update sub-task dropdown based on main task selection"""
         category_items = get_supported_edit_types()
@@ -422,6 +439,36 @@ class EditxTab:
         value = None if len(choices) == 0 else choices[0]
         return gr.Dropdown(label="Sub-task", choices=choices, value=value)
 
+    def transcribe_audio(self, audio_input, current_text):
+        """Transcribe audio using Whisper ASR when prompt text is empty"""
+        global whisper_asr
+
+        # Only transcribe if current text is empty
+        if current_text and current_text.strip():
+            return current_text  # Keep existing text
+
+        if not audio_input:
+            return ""  # No audio to transcribe
+
+        try:
+            # Initialize whisper if not already loaded
+            if whisper_asr is None:
+                if args_global is None:
+                    self.logger.error("Global args not set. Cannot initialize Whisper.")
+                    return ""
+
+                whisper_asr = WhisperWrapper()
+                self.logger.info("✅ WhisperWrapper initialized for ASR")
+
+            # Transcribe audio
+            transcribed_text = whisper_asr(audio_input)
+            self.logger.info(f"Audio transcribed: {transcribed_text}")
+            return transcribed_text
+
+        except Exception as e:
+            self.logger.error(f"Failed to transcribe audio: {e}")
+            return ""
+
 def launch_demo(args, editx_tab):
     """Launch the gradio demo"""
@@ -503,6 +550,12 @@ if __name__ == "__main__":
         default="cuda",
         help="Device mapping for model loading (default: cuda)"
     )
+    parser.add_argument(
+        "--enable-auto-transcribe",
+        action="store_true",
+        help="Enable automatic audio transcription when uploading audio files (default: disabled)"
+    )
+    parser.set_defaults(enable_auto_transcribe=True)
 
     args = parser.parse_args()
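Note on the new flag: argparse's store_true action normally defaults to False, but the trailing parser.set_defaults(enable_auto_transcribe=True) overrides that, so auto-transcription is enabled by default (matching the commit body, even though the help text still says "default: disabled"). Passing --enable-auto-transcribe is then a no-op, because a store_true flag can only set the value to True. A minimal standalone sketch of this behavior (illustrative, not part of the commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--enable-auto-transcribe", action="store_true")
parser.set_defaults(enable_auto_transcribe=True)

print(parser.parse_args([]).enable_auto_transcribe)                            # True: set_defaults wins
print(parser.parse_args(["--enable-auto-transcribe"]).enable_auto_transcribe)  # True: flag is redundant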
whisper_wrapper.py ADDED
@@ -0,0 +1,75 @@
+import logging
+import torch
+import torchaudio
+from transformers import pipeline
+
+
+class WhisperWrapper:
+    """Simplified Whisper ASR wrapper"""
+
+    def __init__(self, model_id="openai/whisper-large-v3"):
+        """
+        Initialize WhisperWrapper
+
+        Args:
+            model_id: Whisper model ID, default uses openai/whisper-large-v3
+        """
+        self.logger = logging.getLogger(__name__)
+        self.model = None
+
+        try:
+            self.model = pipeline("automatic-speech-recognition", model=model_id)
+            self.logger.info(f"✅ Whisper model loaded successfully: {model_id}")
+        except Exception as e:
+            self.logger.error(f"❌ Failed to load Whisper model: {e}")
+            raise
+
+    def __call__(self, audio_input, sample_rate=16000):
+        """
+        Audio-to-text transcription
+
+        Args:
+            audio_input: Audio file path or audio tensor (tensors are assumed sampled at sample_rate)
+
+        Returns:
+            Transcribed text
+        """
+        if self.model is None:
+            raise RuntimeError("Whisper model not loaded")
+
+        try:
+            # Load audio
+            if isinstance(audio_input, str):
+                # Audio file path
+                audio, audio_sr = torchaudio.load(audio_input)
+                audio = torchaudio.functional.resample(audio, audio_sr, 16000)
+                # Handle stereo to mono conversion (pipeline may not handle this)
+                if audio.shape[0] > 1:
+                    audio = audio.mean(dim=0, keepdim=True)  # Convert stereo to mono by averaging
+                # Convert to numpy and squeeze
+                audio = audio.squeeze(0).numpy()
+            elif isinstance(audio_input, torch.Tensor):
+                # Tensor input; the caller supplies the source rate via sample_rate
+                audio = audio_input.cpu()
+                audio = torchaudio.functional.resample(audio, sample_rate, 16000)
+                # Handle stereo to mono conversion
+                if audio.ndim > 1 and audio.shape[0] > 1:
+                    audio = audio.mean(dim=0, keepdim=True)
+                audio = audio.squeeze().numpy()
+            else:
+                raise ValueError(f"Unsupported audio input type: {type(audio_input)}")
+
+            # Transcribe
+            result = self.model(audio)
+            text = result.get("text", "").strip() if isinstance(result, dict) else str(result).strip()
+
+            self.logger.debug(f"Transcription result: {text}")
+            return text
+
+        except Exception as e:
+            self.logger.error(f"Audio transcription failed: {e}")
+            return ""
+
+    def is_available(self):
+        """Check if whisper model is available"""
+        return self.model is not None
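For reference, a minimal usage sketch of the new wrapper; the sample.wav path is illustrative, and the first call downloads openai/whisper-large-v3 through the transformers pipeline:

from whisper_wrapper import WhisperWrapper

asr = WhisperWrapper()  # loads the ASR pipeline in __init__ and raises if loading fails
if asr.is_available():
    # A file path is loaded with torchaudio, resampled to 16 kHz, and downmixed to mono
    print(asr("sample.wav"))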