aungkomyat commited on
Commit
f1d74e2
Β·
verified Β·
1 Parent(s): a50ce25

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -65
app.py CHANGED
@@ -1,93 +1,99 @@
1
  import os
 
2
  import gradio as gr
3
  import numpy as np
4
  import torch
5
- import scipy.io.wavfile
6
- from utils.hparams import create_hparams
7
- from train import load_model
8
- from synthesis import generate_speech
9
- from text import text_to_sequence
10
 
11
- # Path configurations
 
12
  MODEL_DIR = "trained_model"
13
- MODEL_PATH = os.path.join(MODEL_DIR, "checkpoint_latest.pth.tar")
14
- CONFIG_PATH = os.path.join(MODEL_DIR, "hparams.yml")
15
- OUTPUT_PATH = "output.wav"
16
 
17
- # Download model if it doesn't exist
18
- def download_model():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  if not os.path.exists(MODEL_DIR):
20
  os.makedirs(MODEL_DIR)
 
21
 
22
- if not os.path.exists(MODEL_PATH):
23
- print("Downloading model...")
24
- # Add model download code here
25
- # For example:
26
- # !wget -O MODEL_PATH https://path/to/model
27
- raise Exception("You need to download the model checkpoint file and place it in trained_model/checkpoint_latest.pth.tar")
28
-
29
- if not os.path.exists(CONFIG_PATH):
30
- print("Downloading config...")
31
- # Add config download code here
32
- # For example:
33
- # !wget -O CONFIG_PATH https://path/to/config
34
- raise Exception("You need to download the hparams.yml file and place it in trained_model/hparams.yml")
35
 
36
- # Initialize model
37
- def init_model():
38
  try:
39
- download_model()
 
 
 
 
 
 
 
 
 
 
40
 
41
- hparams = create_hparams(CONFIG_PATH)
 
 
 
 
 
 
 
 
42
  model = load_model(hparams)
43
- model.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu'))['state_dict'])
44
  model.eval()
45
 
46
- return model, hparams
47
- except Exception as e:
48
- print(f"Error initializing model: {str(e)}")
49
- return None, None
50
-
51
- # Generate speech
52
- def synthesize(text, model, hparams):
53
- try:
54
  sequence = np.array(text_to_sequence(text, ['burmese_cleaners']))[None, :]
55
  sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cpu().long()
56
 
 
57
  mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)
58
 
 
59
  with torch.no_grad():
60
  waveform = generate_speech(mel_outputs_postnet, hparams)
61
 
62
- scipy.io.wavfile.write(OUTPUT_PATH, hparams.sampling_rate, waveform)
 
 
63
 
64
- return OUTPUT_PATH, None
 
65
  except Exception as e:
66
- return None, str(e)
67
 
68
- # Gradio interface
69
  def tts_interface(text):
70
  if not text.strip():
71
  return None, "Please enter some text."
72
 
73
- global MODEL, HPARAMS
74
- if MODEL is None or HPARAMS is None:
75
- MODEL, HPARAMS = init_model()
76
-
77
- if MODEL is None:
78
- return None, "Model could not be initialized. Please check logs."
79
-
80
- audio_path, error = synthesize(text, MODEL, HPARAMS)
81
-
82
- if error:
83
- return None, f"Error generating speech: {error}"
84
-
85
- return audio_path, "Speech generated successfully!"
86
 
87
- # Initialize global model variables
88
- MODEL, HPARAMS = None, None
 
89
 
90
- # Create Gradio interface
91
  demo = gr.Interface(
92
  fn=tts_interface,
93
  inputs=[
@@ -106,9 +112,12 @@ demo = gr.Interface(
106
  This is a demo of the Myanmar Text-to-Speech system developed by hpbyte.
107
  Enter Burmese text in the box below and click 'Submit' to generate speech.
108
 
 
 
 
 
109
  GitHub Repository: https://github.com/hpbyte/myanmar-tts
110
  """,
111
- allow_flagging="never",
112
  examples=[
113
  ["α€™α€„α€Ία€Ήα€‚α€œα€¬α€•α€«"],
114
  ["မြန်မာစကားပြောစနစ်ကို α€€α€Όα€­α€―α€†α€­α€―α€•α€«α€α€šα€Ί"],
@@ -116,13 +125,6 @@ demo = gr.Interface(
116
  ]
117
  )
118
 
119
- # Initialize model at startup
120
- try:
121
- MODEL, HPARAMS = init_model()
122
- print("Model initialized successfully!")
123
- except Exception as e:
124
- print(f"Error initializing model: {str(e)}")
125
-
126
  # Launch the app
127
  if __name__ == "__main__":
128
  demo.launch()
 
1
  import os
2
+ import sys
3
  import gradio as gr
4
  import numpy as np
5
  import torch
6
+ import subprocess
7
+ import shutil
8
+ from pathlib import Path
 
 
9
 
10
+ # Model repository information
11
+ REPO_URL = "https://github.com/hpbyte/myanmar-tts.git"
12
  MODEL_DIR = "trained_model"
13
+ REPO_DIR = "myanmar-tts"
 
 
14
 
15
+ # Check and install the package if not already installed
16
+ def setup_environment():
17
+ status_msg = ""
18
+
19
+ # Clone the repository if it doesn't exist
20
+ if not os.path.exists(REPO_DIR):
21
+ status_msg += "Cloning repository...\n"
22
+ subprocess.run(["git", "clone", REPO_URL], check=True)
23
+
24
+ # Add the repository to Python path
25
+ repo_path = os.path.abspath(REPO_DIR)
26
+ if repo_path not in sys.path:
27
+ sys.path.append(repo_path)
28
+ status_msg += f"Added {repo_path} to Python path\n"
29
+
30
+ # Create model directory if it doesn't exist
31
  if not os.path.exists(MODEL_DIR):
32
  os.makedirs(MODEL_DIR)
33
+ status_msg += f"Created {MODEL_DIR} directory\n"
34
 
35
+ return status_msg + "Environment setup complete"
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ # Function to synthesize speech
38
+ def synthesize_speech(text):
39
  try:
40
+ # Import necessary modules from the repository
41
+ sys.path.append(REPO_DIR)
42
+ from myanmar_tts.text import text_to_sequence
43
+ from myanmar_tts.utils.hparams import create_hparams
44
+ from myanmar_tts.train import load_model
45
+ from myanmar_tts.synthesis import generate_speech
46
+ import scipy.io.wavfile
47
+
48
+ # Check if model exists, if not provide instructions
49
+ checkpoint_path = os.path.join(MODEL_DIR, "checkpoint_latest.pth.tar")
50
+ config_path = os.path.join(MODEL_DIR, "hparams.yml")
51
 
52
+ if not os.path.exists(checkpoint_path) or not os.path.exists(config_path):
53
+ return None, f"""Model files not found. Please upload:
54
+ 1. The checkpoint file at: {checkpoint_path}
55
+ 2. The hparams.yml file at: {config_path}
56
+
57
+ You can obtain these files from the original repository or by training the model."""
58
+
59
+ # Load the model and hyperparameters
60
+ hparams = create_hparams(config_path)
61
  model = load_model(hparams)
62
+ model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu'))['state_dict'])
63
  model.eval()
64
 
65
+ # Process text input
 
 
 
 
 
 
 
66
  sequence = np.array(text_to_sequence(text, ['burmese_cleaners']))[None, :]
67
  sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cpu().long()
68
 
69
+ # Generate mel spectrograms
70
  mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)
71
 
72
+ # Generate waveform
73
  with torch.no_grad():
74
  waveform = generate_speech(mel_outputs_postnet, hparams)
75
 
76
+ # Save and return the audio
77
+ output_path = "output.wav"
78
+ scipy.io.wavfile.write(output_path, hparams.sampling_rate, waveform)
79
 
80
+ return output_path, "Speech generated successfully!"
81
+
82
  except Exception as e:
83
+ return None, f"Error: {str(e)}\n\nMake sure you have uploaded the model files to the {MODEL_DIR} directory."
84
 
85
+ # Function for the Gradio interface
86
  def tts_interface(text):
87
  if not text.strip():
88
  return None, "Please enter some text."
89
 
90
+ return synthesize_speech(text)
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
+ # Set up the environment
93
+ setup_message = setup_environment()
94
+ print(setup_message)
95
 
96
+ # Create the Gradio interface
97
  demo = gr.Interface(
98
  fn=tts_interface,
99
  inputs=[
 
112
  This is a demo of the Myanmar Text-to-Speech system developed by hpbyte.
113
  Enter Burmese text in the box below and click 'Submit' to generate speech.
114
 
115
+ **Note:** You need to upload the model files to the 'trained_model' directory:
116
+ - checkpoint_latest.pth.tar
117
+ - hparams.yml
118
+
119
  GitHub Repository: https://github.com/hpbyte/myanmar-tts
120
  """,
 
121
  examples=[
122
  ["α€™α€„α€Ία€Ήα€‚α€œα€¬α€•α€«"],
123
  ["မြန်မာစကားပြောစနစ်ကို α€€α€Όα€­α€―α€†α€­α€―α€•α€«α€α€šα€Ί"],
 
125
  ]
126
  )
127
 
 
 
 
 
 
 
 
128
  # Launch the app
129
  if __name__ == "__main__":
130
  demo.launch()