import gradio as gr
import os
import argparse
import torch
import logging
import threading
from datetime import datetime

import torchaudio
import librosa
import soundfile as sf

# ZeroGPU support
try:
    import spaces
    ZEROGPU_AVAILABLE = True
except ImportError:
    ZEROGPU_AVAILABLE = False

    # Create a dummy decorator for non-ZeroGPU environments
    class spaces:
        @staticmethod
        def GPU(duration=10):
            def decorator(func):
                return func
            return decorator

# Project imports
from tokenizer import StepAudioTokenizer
from tts import StepAudioTTS
from model_loader import ModelSource
from config.edit_config import get_supported_edit_types
from whisper_wrapper import WhisperWrapper

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Global variables for ZeroGPU-optimized loading
encoder = None
common_tts_engine = None
whisper_asr = None
args_global = None
_model_lock = threading.Lock()  # Thread lock for model initialization


def initialize_whisper():
    global whisper_asr
    if whisper_asr is not None:
        return
    with _model_lock:
        if whisper_asr is not None:
            return
        try:
            whisper_asr = WhisperWrapper()
            logger.info("โœ“ WhisperWrapper initialized for ASR")
        except Exception as e:
            logger.error(f"โŒ Error loading Whisper ASR model: {e}")
            raise


def initialize_models():
    """Initialize models on first GPU call (ZeroGPU optimization: load inside GPU context)"""
    global encoder, common_tts_engine, args_global

    # Fast path: check if already initialized (without lock)
    if common_tts_engine is not None:
        return  # Already initialized

    # Slow path: acquire lock and double-check
    with _model_lock:
        # Double-check pattern: another thread might have initialized while waiting for lock
        if common_tts_engine is not None:
            return  # Already initialized by another thread

        if args_global is None:
            raise RuntimeError("Global args not set. Cannot initialize models.")

        try:
            logger.info("๐Ÿš€ Initializing models inside GPU context (first call)...")

            # Determine model source
            source_mapping = {
                "auto": ModelSource.AUTO,
                "local": ModelSource.LOCAL,
                "modelscope": ModelSource.MODELSCOPE,
                "huggingface": ModelSource.HUGGINGFACE
            }
            model_source = source_mapping[args_global.model_source]

            # Map the torch dtype string to the actual torch dtype (ignored downstream when quantization is enabled)
            dtype_mapping = {
                "float16": torch.float16,
                "bfloat16": torch.bfloat16,
                "float32": torch.float32
            }
            torch_dtype = dtype_mapping[args_global.torch_dtype]

            # Load StepAudioTokenizer (avoid CUDA initialization in main process)
            encoder = StepAudioTokenizer(
                os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
                model_source=model_source,
                funasr_model_id=args_global.tokenizer_model_id
            )
            logger.info("โœ“ StepAudioTokenizer loaded")

            # Initialize common TTS engine (avoid CUDA initialization in main process)
            common_tts_engine = StepAudioTTS(
                os.path.join(args_global.model_path, "Step-Audio-EditX"),
                encoder,
                model_source=model_source,
                tts_model_id=args_global.tts_model_id,
                quantization_config=args_global.quantization,
                torch_dtype=torch_dtype,
                device_map=args_global.device_map,
            )
            logger.info("โœ“ StepAudioTTS loaded")
            logger.info("Models initialized inside GPU context.")

            if ZEROGPU_AVAILABLE:
                logger.info("๐Ÿ’ก Models loaded inside GPU context - ready for inference")
            else:
                logger.info("๐Ÿ’ก Models loaded - ready for inference")

        except Exception as e:
            logger.error(f"โŒ Error loading models: {e}")
            raise


def get_model_config():
    """Get model configuration without initializing GPU models"""
    if args_global is None:
        raise RuntimeError("Global args not set. Cannot get model config.")
    return {
        "encoder_path": os.path.join(args_global.model_path, "Step-Audio-Tokenizer"),
        "tts_path": os.path.join(args_global.model_path, "Step-Audio-EditX"),
        "model_source": args_global.model_source,
        "tokenizer_model_id": args_global.tokenizer_model_id,
        "tts_model_id": args_global.tts_model_id
    }


def get_gpu_duration(audio_input, text_input, target_text, task_type, task_info):
    """Dynamic GPU duration based on whether models need initialization"""
    global common_tts_engine
    if common_tts_engine is None:
        # First call - the allocated duration must also cover model loading
        return 120
    else:
        # Subsequent calls - only inference time needed
        return 120


@spaces.GPU(duration=get_gpu_duration)  # Dynamic duration based on model state
def process_audio_with_gpu(audio_input, text_input, target_text, task_type, task_info):
    """Process audio using GPU (models are loaded inside GPU context to avoid main process errors)"""
    global common_tts_engine

    # Initialize models if not already loaded (inside GPU context to avoid main process errors)
    if common_tts_engine is None:
        logger.info("Initializing common_tts_engine inside GPU context...")
        logger.info("๐ŸŽฏ GPU allocated (first call includes model loading)...")
        initialize_models()
        logger.info("โœ… Models loaded successfully inside GPU context")
    else:
        logger.info("common_tts_engine already initialized.")
        logger.info("๐ŸŽฏ GPU allocated (inference with loaded models)...")

    try:
        # Use loaded models (first call may include loading time, subsequent calls are fast)
        if task_type == "clone":
            output_audio, sr = common_tts_engine.clone(audio_input, text_input, target_text)
        else:
            output_audio, sr = common_tts_engine.edit(audio_input, text_input, task_type, task_info, target_text)
        logger.info("โœ… Audio processing completed")
        return output_audio, sr
    except Exception as e:
        logger.error(f"โŒ Audio processing failed: {e}")
        raise
    # GPU automatically deallocated when function exits


@spaces.GPU(duration=30)
def transcribe_audio(audio_input, current_text):
    """Transcribe audio using Whisper ASR when prompt text is empty"""
    global whisper_asr

    # Only transcribe if current text is empty
    if current_text and current_text.strip():
        return current_text  # Keep existing text
    if not audio_input:
        return ""  # No audio to transcribe
    if whisper_asr is None:
        initialize_whisper()
    try:
        # Transcribe audio
        transcribed_text = whisper_asr(audio_input)
        logger.info(f"Audio transcribed: {transcribed_text}")
        return transcribed_text
    except Exception as e:
        logger.error(f"Failed to transcribe audio: {e}")
        return ""


# Save audio to temporary directory
def save_audio(audio_type, audio_data, sr, tmp_dir):
    """Save audio data to a temporary file with timestamp"""
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    save_path = os.path.join(tmp_dir, audio_type, f"{current_time}.wav")
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    try:
        if isinstance(audio_data, torch.Tensor):
            torchaudio.save(save_path, audio_data, sr)
        else:
            sf.write(save_path, audio_data, sr)
        logger.debug(f"Audio saved to: {save_path}")
        return save_path
    except Exception as e:
        logger.error(f"Failed to save audio: {e}")
        raise


class EditxTab:
    """Audio editing and voice cloning interface tab"""

    def __init__(self, args):
        self.args = args
        self.edit_type_list = list(get_supported_edit_types().keys())
        self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
        self.enable_auto_transcribe = getattr(args, 'enable_auto_transcribe', False)

    def history_messages_to_show(self, messages):
        """Convert message history to gradio chatbot format"""
        show_msgs = []
        for message in messages:
            edit_type = message['edit_type']
            edit_info = message['edit_info']
            source_text = message['source_text']
            target_text = message['target_text']
            raw_audio_part = message['raw_wave']
            edit_audio_part = message['edit_wave']
            type_str = f"{edit_type}-{edit_info}" if edit_info is not None else f"{edit_type}"
            show_msgs.extend([
                {"role": "user", "content": f"Task type: {type_str}\nText: {source_text}"},
                {"role": "user", "content": gr.Audio(value=raw_audio_part, interactive=False)},
                {"role": "assistant", "content": f"Output audio:\nText: {target_text}"},
                {"role": "assistant", "content": gr.Audio(value=edit_audio_part, interactive=False)}
            ])
        return show_msgs

    def generate_clone(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
        """Generate cloned audio (models are loaded on first GPU call)"""
        self.logger.info("Starting voice cloning process")
        state['history_audio'] = []
        state['history_messages'] = []

        # Input validation
        if not prompt_text_input or prompt_text_input.strip() == "":
            error_msg = "[Error] Uploaded text cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state
        if not prompt_audio_input:
            error_msg = "[Error] Uploaded audio cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state
        if not generated_text or generated_text.strip() == "":
            error_msg = "[Error] Clone content cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state
        if edit_type != "clone":
            error_msg = "[Error] CLONE button must use clone task."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

        try:
            # Use GPU inference with models loaded inside GPU context
            output_audio, output_sr = process_audio_with_gpu(
                prompt_audio_input, prompt_text_input, generated_text, "clone", edit_info
            )
            if output_audio is not None and output_sr is not None:
                # Convert tensor to numpy if needed
                if isinstance(output_audio, torch.Tensor):
                    audio_numpy = output_audio.cpu().numpy().squeeze()
                else:
                    audio_numpy = output_audio

                # Load original audio for comparison
                input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)

                # Create message for history
                cur_assistant_msg = {
                    "edit_type": edit_type,
                    "edit_info": edit_info,
                    "source_text": prompt_text_input,
                    "target_text": generated_text,
                    "raw_wave": (input_sample_rate, input_audio_data_numpy),
                    "edit_wave": (output_sr, audio_numpy),
                }
                state["history_audio"].append((output_sr, audio_numpy, generated_text))
                state["history_messages"].append(cur_assistant_msg)
                show_msgs = self.history_messages_to_show(state["history_messages"])
                self.logger.info("Voice cloning completed successfully")
                return show_msgs, state
            else:
                error_msg = "[Error] Clone failed"
                self.logger.error(error_msg)
                return [{"role": "user", "content": error_msg}], state
        except Exception as e:
            error_msg = f"[Error] Clone failed: {str(e)}"
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

    def generate_edit(self, prompt_text_input, prompt_audio_input, generated_text, edit_type, edit_info, state):
        """Generate edited audio (models are loaded on first GPU call)"""
        self.logger.info("Starting audio editing process")

        # Input validation
        if not prompt_audio_input:
            error_msg = "[Error] Uploaded audio cannot be empty."
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

        try:
            # Determine which audio to use
            if len(state["history_audio"]) == 0:
                # First edit - use uploaded audio
                audio_to_edit = prompt_audio_input
                text_to_use = prompt_text_input
                self.logger.debug("Using prompt audio, no history found")
            else:
                # Use previous edited audio - save it to temp file first
                sample_rate, audio_numpy, previous_text = state["history_audio"][-1]
                temp_path = save_audio("temp", audio_numpy, sample_rate, self.args.tmp_dir)
                audio_to_edit = temp_path
                text_to_use = previous_text
                self.logger.debug(f"Using previous audio from history, count: {len(state['history_audio'])}")

            # For para-linguistic, use generated_text; otherwise use source text
            if edit_type not in {"paralinguistic"}:
                generated_text = text_to_use

            # Use GPU inference with models loaded inside GPU context
            output_audio, output_sr = process_audio_with_gpu(
                audio_to_edit, text_to_use, generated_text, edit_type, edit_info
            )
            if output_audio is not None and output_sr is not None:
                # Convert tensor to numpy if needed
                if isinstance(output_audio, torch.Tensor):
                    audio_numpy = output_audio.cpu().numpy().squeeze()
                else:
                    audio_numpy = output_audio

                # Load original audio for comparison
                if len(state["history_audio"]) == 0:
                    input_audio_data_numpy, input_sample_rate = librosa.load(prompt_audio_input)
                else:
                    input_sample_rate, input_audio_data_numpy, _ = state["history_audio"][-1]

                # Create message for history
                cur_assistant_msg = {
                    "edit_type": edit_type,
                    "edit_info": edit_info,
                    "source_text": text_to_use,
                    "target_text": generated_text,
                    "raw_wave": (input_sample_rate, input_audio_data_numpy),
                    "edit_wave": (output_sr, audio_numpy),
                }
                state["history_audio"].append((output_sr, audio_numpy, generated_text))
                state["history_messages"].append(cur_assistant_msg)
                show_msgs = self.history_messages_to_show(state["history_messages"])
                self.logger.info("Audio editing completed successfully")
                return show_msgs, state
            else:
                error_msg = "[Error] Edit failed"
                self.logger.error(error_msg)
                return [{"role": "user", "content": error_msg}], state
        except Exception as e:
            error_msg = f"[Error] Edit failed: {str(e)}"
            self.logger.error(error_msg)
            return [{"role": "user", "content": error_msg}], state

    def clear_history(self, state):
        """Clear conversation history"""
        state["history_messages"] = []
        state["history_audio"] = []
        return [], state

    def init_state(self):
        """Initialize conversation state"""
        return {
            "history_messages": [],
            "history_audio": []
        }

    def register_components(self):
        """Register gradio components - maintaining exact layout from original"""
        with gr.Tab("Editx"):
            with gr.Row():
                with gr.Column():
                    self.model_input = gr.Textbox(label="Model Name", value="Step-Audio-EditX", scale=1)
                    self.prompt_text_input = gr.Textbox(label="Prompt Text", value="", scale=1)
                    self.prompt_audio_input = gr.Audio(
                        sources=["upload", "microphone"],
                        format="wav",
                        type="filepath",
                        label="Input Audio",
                    )
                    self.generated_text = gr.Textbox(label="Target Text", lines=1, max_lines=200, max_length=1000)
                with gr.Column():
                    with gr.Row():
                        self.edit_type = gr.Dropdown(label="Task", choices=self.edit_type_list, value="clone")
                        self.edit_info = gr.Dropdown(label="Sub-task", choices=[], value=None)
                    self.chat_box = gr.Chatbot(label="History", type="messages", height=480)
            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        self.button_tts = gr.Button("CLONE", variant="primary")
                        self.button_edit = gr.Button("EDIT", variant="primary")
                with gr.Column():
                    self.clean_history_submit = gr.Button("Clear History", variant="primary")
            gr.Markdown("---")
            gr.Markdown("""
            **Button Description:**
            - CLONE: Synthesizes audio from the uploaded audio and text; used only in clone mode; clears the history when used.
            - EDIT: Edits the uploaded audio, or continues stacking edit effects on the previous round of generated audio.
            """)
            gr.Markdown("""
            **Operation Workflow:**
            - Upload the audio to be edited on the left side and fill in the corresponding text content of the audio;
            - If the task requires modifying the text content (such as clone or para-linguistic), fill in the text to be synthesized in the "Target Text" field. For all other tasks, keep the uploaded audio's text content unchanged;
            - Select tasks and subtasks on the right side (some tasks have no subtasks, such as vad);
            - Click the "CLONE" or "EDIT" button on the left side, and audio will be generated in the dialog box on the right side.
            """)
            gr.Markdown("""
            **Para-linguistic Description:**
            - Supported tags include: [Breathing] [Laughter] [Surprise-oh] [Confirmation-en] [Uhm] [Surprise-ah] [Surprise-wa] [Sigh] [Question-ei] [Dissatisfaction-hnn]
            - Example:
                - Fill in the "Target Text" field: "Great, the weather is so nice today." Click the "CLONE" button to get audio.
                - Change the "Target Text" field to: "Great[Laughter], the weather is so nice today[Surprise-ah]." Click the "EDIT" button to get para-linguistic audio.
            """)

    def register_events(self):
        """Register event handlers"""
        # Create independent state for each session
        state = gr.State(self.init_state())
        self.button_tts.click(
            self.generate_clone,
            inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
            outputs=[self.chat_box, state]
        )
        self.button_edit.click(
            self.generate_edit,
            inputs=[self.prompt_text_input, self.prompt_audio_input, self.generated_text, self.edit_type, self.edit_info, state],
            outputs=[self.chat_box, state]
        )
        self.clean_history_submit.click(self.clear_history, inputs=[state], outputs=[self.chat_box, state])
        self.edit_type.change(
            fn=self.update_edit_info,
            inputs=self.edit_type,
            outputs=self.edit_info,
        )
        # Add audio transcription event only if enabled
        if self.enable_auto_transcribe:
            self.prompt_audio_input.change(
                fn=transcribe_audio,
                inputs=[self.prompt_audio_input, self.prompt_text_input],
                outputs=self.prompt_text_input,
            )

    def update_edit_info(self, category):
        """Update sub-task dropdown based on main task selection"""
        category_items = get_supported_edit_types()
        choices = category_items.get(category, [])
        value = None if len(choices) == 0 else choices[0]
        return gr.Dropdown(label="Sub-task", choices=choices, value=value)


def launch_demo(args, editx_tab):
    """Launch the gradio demo"""
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="๐ŸŽ™๏ธ Step-Audio-EditX",
        css="""
        :root {
            --font: "Helvetica Neue", Helvetica, Arial, sans-serif;
            --font-mono: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
        }
        """
    ) as demo:
        gr.Markdown("## ๐ŸŽ™๏ธ Step-Audio-EditX")
        gr.Markdown("Audio Editing and Zero-Shot Cloning using Step-Audio-EditX")

        # Register components
        editx_tab.register_components()
        # Register events
        editx_tab.register_events()

    # Launch demo
    demo.queue().launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share if hasattr(args, 'share') else False
    )


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser(description="Step-Audio Edit Demo")
    parser.add_argument("--model-path", type=str, default="stepfun-ai", help="Model path.")
    parser.add_argument("--server-name", type=str, default="0.0.0.0", help="Demo server name.")
    parser.add_argument("--server-port", type=int, default=7860, help="Demo server port.")
    parser.add_argument("--tmp-dir", type=str, default="/tmp/gradio", help="Save path.")
    parser.add_argument("--share", action="store_true", help="Share gradio app.")
    # Multi-source loading support parameters
    parser.add_argument(
        "--model-source", type=str, default="huggingface",
        choices=["auto", "local", "modelscope", "huggingface"],
        help="Model source: auto (detect automatically), local, modelscope, or huggingface"
    )
    parser.add_argument(
        "--tokenizer-model-id", type=str,
        default="dengcunqin/speech_paraformer-large_asr_nat-zh-cantonese-en-16k-vocab8501-online",
        help="Tokenizer model ID for online loading"
    )
    parser.add_argument(
        "--tts-model-id", type=str, default=None,
        help="TTS model ID for online loading (if different from model-path)"
    )
    parser.add_argument(
        "--quantization", type=str, default=None, choices=["int4", "int8"],
        help="Enable quantization for the TTS model to reduce memory usage. "
             "Choices: int4 (online), int8 (online). "
             "When quantization is enabled, data types are handled automatically by the quantization library."
    )
    parser.add_argument(
        "--torch-dtype", type=str, default="bfloat16",
        choices=["float16", "bfloat16", "float32"],
        help="PyTorch data type for model operations. This setting only applies when quantization is disabled. "
             "When quantization is enabled, data types are managed automatically."
    )
    parser.add_argument(
        "--device-map", type=str, default="cuda",
        help="Device mapping for model loading (default: cuda)"
    )
    parser.add_argument(
        "--enable-auto-transcribe", action="store_true",
        help="Enable automatic audio transcription when uploading audio files (default: disabled)"
    )
    args = parser.parse_args()

    # Store args globally for model configuration
    args_global = args

    logger.info("Configuration loaded:")
    logger.info(f"Loading models with source: {args.model_source}")
    logger.info(f"Model path: {args.model_path}")
    logger.info(f"Tokenizer model ID: {args.tokenizer_model_id}")
    logger.info(f"Torch dtype: {args.torch_dtype}")
    logger.info(f"Device map: {args.device_map}")
    if args.tts_model_id:
        logger.info(f"TTS model ID: {args.tts_model_id}")
    if args.quantization:
        logger.info(f"๐Ÿ”ง {args.quantization.upper()} quantization enabled")

    # Create EditxTab instance
    editx_tab = EditxTab(args)

    # Launch demo
    launch_demo(args, editx_tab)
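
# Example invocations (illustrative only; "app.py" stands in for whatever filename this
# script is saved under, and the model path below is a placeholder, not a required value;
# all flags shown are the ones defined by the argument parser above):
#
#   python app.py                                          # defaults: huggingface source, bfloat16, port 7860
#   python app.py --model-path /path/to/models --model-source local
#   python app.py --quantization int8 --enable-auto-transcribe --share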