import gradio as gr
import torch
import numpy as np
import io
import base64
import logging
import os
from scipy.io import wavfile
from typing import Tuple, Dict, Any
from transformers import AutoTokenizer, AutoModel
from parler_tts import ParlerTTSForConditionalGeneration
from huggingface_hub import hf_hub_download
# --- Logging ---
# basicConfig already attaches a stream handler to the root logger; adding another
# handler here would duplicate every log line.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("hi-tts")
# --- TTS Wrapper ---
class IndicParlerTTS:
    def __init__(self, model_type: str = "parler", model_name: str = "ai4bharat/indic-parler-tts"):
        self.model_type = model_type  # "parler" or "indicf5"
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        self.tokenizer = None
        self.description_tokenizer = None
        self.sample_rate = 24000  # Default for both models
        self.ref_audio_path = None
        self.ref_text = None
        self._load_model()
        # Supported languages (expanded for IndicF5)
        self.language_codes = {
            "as": "Assamese", "bn": "Bengali", "gu": "Gujarati", "hi": "Hindi",
            "kn": "Kannada", "ml": "Malayalam", "mr": "Marathi", "or": "Odia",
            "pa": "Punjabi", "ta": "Tamil", "te": "Telugu"
        }
        # Voice style mappings to descriptive terms (for Parler-TTS)
        self.voice_map = {
            "neutral": "neutral",
            "formal": "formal and clear",
            "casual": "casual and relaxed",
            "expressive": "expressive and animated",
            "emotional": "emotional and varied"
        }
        # For IndicF5, map voices to reference prompts (simplified; expand as needed)
        self.ref_map = {
            "neutral": ("prompts/PAN_F_HAPPY_00001.wav", "ਭਹੰਪੀ ਵਿੱਚ ਸਮਾਰਕਾਂ ਦੇ ਭਵਨ ਨਿਰਮਾਣ ਕਲਾ ਦੇ ਵੇਰਵੇ ਗੁੰਝਲਦਾਰ ਅਤੇ ਹੈਰਾਨ ਕਰਨ ਵਾਲੇ ਹਨ, ਜੋ ਮੈਨੂੰ ਖੁਸ਼ ਕਰਦੇ ਹਨ।"),
            # Add more mappings, e.g., for other styles/languages from prompts/
            # "formal": ("path/to/formal.wav", "ref text"),
        }
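        # Note on ref_map (an assumption following the IndicF5 reference-prompt pattern):
        # each entry should be (local_wav_path, transcript), where the transcript matches
        # the spoken content of the WAV, typically downloaded via hf_hub_download like the
        # default neutral prompt above. generate() falls back to the downloaded default
        # prompt when a mapped path does not exist on disk.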
    def _load_model(self):
        try:
            if self.model_type == "parler":
                logger.info(f"Loading Indic Parler-TTS ({self.model_name}) on {self.device}")
                self.model = ParlerTTSForConditionalGeneration.from_pretrained(
                    self.model_name
                ).to(self.device)
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                try:
                    self.description_tokenizer = AutoTokenizer.from_pretrained(
                        self.model.config.text_encoder._name_or_path
                    )
                except Exception:
                    logger.warning("Falling back to main tokenizer for descriptions")
                    self.description_tokenizer = self.tokenizer
                self.sample_rate = self.model.config.sampling_rate
                logger.info("✅ Indic Parler-TTS loaded")
            elif self.model_type == "indicf5":
                logger.info(f"Loading IndicF5 on {self.device}")
                self.model = AutoModel.from_pretrained("ai4bharat/IndicF5", trust_remote_code=True).to(self.device)
                # Download default reference prompt for the neutral voice (expand for other voices)
                default_ref_path = "prompts/PAN_F_HAPPY_00001.wav"
                self.ref_audio_path = hf_hub_download(
                    repo_id="ai4bharat/IndicF5",
                    filename=default_ref_path,
                    local_dir="./prompts"
                )
                self.ref_text = "ਭਹੰਪੀ ਵਿੱਚ ਸਮਾਰਕਾਂ ਦੇ ਭਵਨ ਨਿਰਮਾਣ ਕਲਾ ਦੇ ਵੇਰਵੇ ਗੁੰਝਲਦਾਰ ਅਤੇ ਹੈਰਾਨ ਕਰਨ ਵਾਲੇ ਹਨ, ਜੋ ਮੈਨੂੰ ਖੁਸ਼ ਕਰਦੇ ਹਨ।"
                # For other voices, override in generate()
                self.sample_rate = 24000
                logger.info("✅ IndicF5 loaded with default reference")
            else:
                raise ValueError(f"Unsupported model_type: {self.model_type}")
        except Exception:
            logger.exception(f"Failed to load {self.model_type} model")
            self.model = None
    def generate(self, text: str, language: str = "hi", voice: str = "neutral",
                 pitch: float = 1.0, speed: float = 1.0, emotion: float = 0.5,
                 reverb: float = 0.0) -> Tuple[np.ndarray, int]:
        """
        Generate speech using the selected model.
        Returns int16 numpy audio and the sample rate.
        For IndicF5: uses reference-based generation for humanized output.
        """
        if self.model is None:
            raise RuntimeError("Model not available")
        if not text.strip():
            raise ValueError("Empty text provided")
        if self.model_type == "parler":
            # Parler-TTS: style is controlled through a natural-language description
            full_lang = self.language_codes.get(language, "Indian")
            voice_desc = self.voice_map.get(voice, "neutral")  # Safe get to avoid KeyError
            pitch_desc = "high" if pitch > 1.2 else "low" if pitch < 0.8 else "balanced"
            speed_desc = "fast" if speed > 1.3 else "slow" if speed < 0.7 else "moderate"
            emotion_desc = "highly expressive" if emotion > 0.7 else "slightly expressive" if emotion > 0.3 else "neutral"
            reverb_desc = "with noticeable reverb as if in a room" if reverb > 0.5 else "clear and close-up"
            description = (
                f"A {full_lang} speaker with a {voice_desc} voice, {pitch_desc} pitch, "
                f"{speed_desc} speaking pace, {emotion_desc} delivery, {reverb_desc}."
            )
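            # Worked example of the resulting description for the default sliders plus
            # voice="formal" and language="hi" (illustrative only):
            #   "A Hindi speaker with a formal and clear voice, balanced pitch,
            #    moderate speaking pace, slightly expressive delivery, clear and close-up."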
            # Tokenize the prompt text and the style description once each
            prompt_inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
            description_inputs = self.description_tokenizer(description, return_tensors="pt").to(self.device)
            with torch.no_grad():
                audio_tensor = self.model.generate(
                    input_ids=description_inputs.input_ids,
                    attention_mask=description_inputs.attention_mask,
                    prompt_input_ids=prompt_inputs.input_ids,
                    prompt_attention_mask=prompt_inputs.attention_mask
                )
            audio = audio_tensor.cpu().numpy().squeeze()
        elif self.model_type == "indicf5":
            # IndicF5: voice style comes from a reference prompt (humanized output)
            ref_path, ref_txt = self.ref_map.get(voice, (self.ref_audio_path, self.ref_text))
            # Fall back to the downloaded default prompt if the mapped file is missing locally
            if not ref_path or not os.path.exists(ref_path):
                ref_path, ref_txt = self.ref_audio_path, self.ref_text
            with torch.no_grad():
                audio_float = self.model(
                    text,
                    ref_audio_path=ref_path,
                    ref_text=ref_txt
                )
            # Normalize int16 output to float32 in [-1, 1] if needed (per the IndicF5 example)
            if audio_float.dtype == np.int16:
                audio_float = audio_float.astype(np.float32) / 32768.0
            audio = audio_float
        # Common post-processing: float32 [-1, 1] → int16
        audio = np.clip(audio, -1.0, 1.0)
        audio_int16 = (audio * 32767).astype(np.int16)
        return audio_int16, self.sample_rate
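# Minimal usage sketch for IndicParlerTTS.generate (illustrative only, not executed at import;
# assumes the default Parler checkpoint loaded successfully; the variable and file names are
# placeholders):
#
#   engine = IndicParlerTTS(model_type="parler")
#   audio, sr = engine.generate("नमस्ते, आप कैसे हैं?", language="hi", voice="formal", speed=0.9)
#   wavfile.write("sample.wav", sr, audio)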
# --- Utility helpers ---
def wav_bytes_from_numpy(audio_np: np.ndarray, sample_rate: int) -> bytes:
    buffer = io.BytesIO()
    wavfile.write(buffer, sample_rate, audio_np)
    buffer.seek(0)
    return buffer.read()

def encode_wav_base64(audio_bytes: bytes) -> str:
    return base64.b64encode(audio_bytes).decode("utf-8")
# Instantiate TTS (default to Parler; re-instantiated when the selected model changes)
def get_tts(model_type):
    if model_type == "indicf5":
        return IndicParlerTTS(model_type="indicf5")
    else:
        return IndicParlerTTS(model_type="parler")

tts = get_tts("parler")
# --- Gradio functions / API functions ---
# The UI dropdown passes display labels; map them to the internal model ids used by get_tts()
MODEL_TYPE_BY_LABEL = {"Indic Parler-TTS": "parler", "IndicF5 (Humanized)": "indicf5"}

def synthesize_speech(text: str, model_type: str, language: str, voice: str,
                      pitch: float, speed: float, emotion: float,
                      reverb: float):
    global tts
    try:
        if not text or not text.strip():
            return None, "Please enter text to synthesize."
        if len(text) > 4000:
            return None, "Text too long. Maximum 4000 characters supported."
        model_type = MODEL_TYPE_BY_LABEL.get(model_type, model_type)
        # Re-instantiate TTS if the selected model changed
        if tts.model_type != model_type:
            tts = get_tts(model_type)
        audio_np, sr = tts.generate(text=text, language=language, voice=voice,
                                    pitch=pitch, speed=speed, emotion=emotion,
                                    reverb=reverb)
        model_note = " (Parler-TTS: style via description)" if model_type == "parler" else " (IndicF5: humanized via reference voice)"
        return (sr, audio_np), f"Speech generated successfully{model_note}."
    except Exception as e:
        logger.exception("Error in synthesize_speech:")
        return None, f"Error: {str(e)}"
def api_synthesize(text: str, model_type: str = "parler", language: str = "hi", voice: str = "neutral",
                   pitch: float = 1.0, speed: float = 1.0, emotion: float = 0.5,
                   reverb: float = 0.0) -> Dict[str, Any]:
    global tts
    try:
        if not text or not text.strip():
            return {"error": "Please provide non-empty text."}
        if tts.model_type != model_type:
            tts = get_tts(model_type)
        audio_np, sr = tts.generate(text=text, language=language, voice=voice,
                                    pitch=float(pitch), speed=float(speed),
                                    emotion=float(emotion), reverb=float(reverb))
        wav_bytes = wav_bytes_from_numpy(audio_np, sr)
        return {
            "audio": encode_wav_base64(wav_bytes),
            "sample_rate": sr,
            "message": "OK"
        }
    except Exception as e:
        logger.exception("API synthesis failed")
        return {"error": str(e)}
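# Client-side sketch (illustrative, not executed here): the "audio" field returned by
# api_synthesize is a base64-encoded WAV, so a caller can write it straight to disk.
# The output file name is a placeholder.
#
#   resp = api_synthesize("नमस्ते दुनिया", model_type="parler", language="hi")
#   if "audio" in resp:
#       with open("out.wav", "wb") as f:
#           f.write(base64.b64decode(resp["audio"]))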
def get_voice_list(language_code: str = "hi"):
    # Voice styles are currently language-independent.
    # gr.update works across Gradio versions (gr.Dropdown.update was removed in Gradio 4).
    voices = ["neutral", "formal", "casual", "expressive", "emotional"]
    return gr.update(choices=voices, value="neutral")
# --- Gradio UI --- #
with gr.Blocks(
    theme=gr.themes.Soft(),
    title="HI-TTS - Humanized Indic Text-to-Speech",
    css="""
    .gradio-container {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    }
    .gr-form {
        background: rgba(255, 255, 255, 0.08);
        backdrop-filter: blur(8px);
        border-radius: 14px;
        padding: 10px;
    }
    """
) as demo:
    gr.Markdown("# 🎤 HI-TTS — Humanized Indic TTS")
    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(label="Enter Text", placeholder="Type text here...", lines=4)
            model_dropdown = gr.Dropdown(
                choices=["Indic Parler-TTS", "IndicF5 (Humanized)"],
                value="Indic Parler-TTS",
                label="Model",
                info="Parler-TTS: style controlled by a text description. IndicF5: near-human output driven by a reference voice."
            )
            language_dropdown = gr.Dropdown(
                choices=list(tts.language_codes.keys()),
                value="hi",
                label="Language (code)",
                info="Select language code (e.g. hi, bn, ta). Model auto-detects from text."
            )
            voice_dropdown = gr.Dropdown(choices=["neutral", "formal", "casual", "expressive", "emotional"],
                                         value="neutral", label="Voice Style")
        with gr.Column(scale=1):
            pitch_slider = gr.Slider(0.5, 2.0, value=1.0, step=0.1, label="Pitch (normal: 1.0)")
            speed_slider = gr.Slider(0.3, 3.0, value=1.0, step=0.1, label="Speed (normal: 1.0)")
            emotion_slider = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Emotion (normal: 0.5)")
            reverb_slider = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Reverb (normal: 0.0)")
    with gr.Row():
        generate_btn = gr.Button("🎵 Generate Speech", variant="primary")
        clear_btn = gr.Button("🗑️ Clear")
    with gr.Row():
        audio_output = gr.Audio(label="Generated Speech", type="numpy")
        status_output = gr.Textbox(label="Status", interactive=False, value="Ready")
    # Bind UI
    generate_btn.click(
        fn=synthesize_speech,
        inputs=[text_input, model_dropdown, language_dropdown, voice_dropdown, pitch_slider, speed_slider, emotion_slider, reverb_slider],
        outputs=[audio_output, status_output]
    )
    clear_btn.click(
        fn=lambda: ("", gr.update(value=None), "Ready"),
        outputs=[text_input, audio_output, status_output]
    )
    language_dropdown.change(
        fn=get_voice_list,
        inputs=[language_dropdown],
        outputs=[voice_dropdown]
    )
    demo.load(lambda: "Ready", outputs=[status_output])
# --- Launch --- #
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)
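# Remote-call sketch (an assumption: Gradio exposes the click handler under the route it
# derives from the function name, "/synthesize_speech"; adjust the URL and api_name for
# your deployment):
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   audio, status = client.predict(
#       "नमस्ते", "Indic Parler-TTS", "hi", "neutral", 1.0, 1.0, 0.5, 0.0,
#       api_name="/synthesize_speech",
#   )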