File size: 4,064 Bytes
74b32ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python3
import os
import sys
import requests
import tarfile
import zipfile
from pathlib import Path

# Directory holding the model files; a relative path, so it is resolved
# against the current working directory whenever os.path.join() uses it.
MODEL_DIR: str = "trained_model"
# File name of the model checkpoint expected inside MODEL_DIR.
MODEL_CHECKPOINT: str = "checkpoint_latest.pth.tar"
# File name of the hyperparameters file expected inside MODEL_DIR.
CONFIG_FILE: str = "hparams.yml"

def download_file(url, destination, *, timeout=30, block_size=64 * 1024):
    """Download *url* to *destination*, showing a text progress bar.

    The body is streamed into ``destination + ".part"`` and only renamed onto
    *destination* after the transfer completed, so an interrupted download
    never leaves a truncated file that later looks like a finished one.

    Args:
        url: URL to fetch with ``requests.get``.
        destination: Path of the file to create; replaced if it exists.
        timeout: Seconds passed to ``requests.get`` as its timeout (``None``
            waits indefinitely, which was the previous behaviour).
        block_size: Number of bytes requested per ``iter_content`` chunk.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on connection errors or timeouts.
        OSError: if the temporary or final file cannot be written.
    """
    print(f"Downloading {url} to {destination}")
    partial_path = f"{os.fspath(destination)}.part"
    try:
        # The timeout keeps a stalled server from hanging the script forever;
        # the context manager releases the streamed connection even on error.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0

            with open(partial_path, 'wb') as file:
                for data in response.iter_content(block_size):
                    downloaded += len(data)
                    file.write(data)

                    # Content-Length counts encoded bytes while iter_content
                    # yields decoded ones, so clamp to keep the bar 50 wide.
                    done = min(50, int(50 * downloaded / total_size)) if total_size > 0 else 0
                    sys.stdout.write(f"\r[{'=' * done}{' ' * (50 - done)}] {downloaded}/{total_size} bytes")
                    sys.stdout.flush()

        # Atomic on the same filesystem: readers see either nothing or the
        # complete file.
        os.replace(partial_path, destination)
    finally:
        # After a successful rename this path is gone; otherwise it holds a
        # partial download that must not be left behind.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    print("\nDownload complete!")


def _safe_tar_extract(tar_ref, extract_to):
    """Extract every member of *tar_ref* into *extract_to*, refusing escapes.

    Raises:
        tarfile.TarError: if a member (or link target) would land outside
            *extract_to*.
    """
    if hasattr(tarfile, "data_filter"):
        # Python 3.12+ and the security backports ship the 'data' filter,
        # which rejects absolute paths, '..' traversal, links leaving the
        # destination and special files.
        tar_ref.extractall(extract_to, filter="data")
        return

    # Fallback for older interpreters: validate member paths by hand.
    root = os.path.realpath(extract_to)
    for member in tar_ref.getmembers():
        target = os.path.realpath(os.path.join(root, member.name))
        if os.path.commonpath([root, target]) != root:
            raise tarfile.TarError(
                f"Refusing to extract {member.name!r}: outside {extract_to}")
        if member.issym() or member.islnk():
            # Symlink targets are relative to the link's directory; hard-link
            # targets are archive member names, i.e. relative to the root.
            base = os.path.dirname(target) if member.issym() else root
            link_target = os.path.realpath(os.path.join(base, member.linkname))
            if os.path.commonpath([root, link_target]) != root:
                raise tarfile.TarError(
                    f"Refusing to extract link {member.name!r}: points outside {extract_to}")
    tar_ref.extractall(extract_to)


def extract_archive(archive_path, extract_to):
    """Extract a zip or tar archive into *extract_to*.

    The format is chosen from the file suffix (case-insensitive): ``.zip``,
    ``.tar.gz``/``.tgz``, ``.tar.bz2``/``.tbz2``, ``.tar.xz``/``.txz`` or
    ``.tar``. Tar members that would be written outside *extract_to* are
    rejected.

    Args:
        archive_path: Path (``str`` or ``os.PathLike``) of the archive.
        extract_to: Directory to extract into.

    Returns:
        True after a successful extraction, False for an unsupported suffix.

    Raises:
        tarfile.TarError: for a malformed tar or one containing unsafe members.
        zipfile.BadZipFile: for a malformed zip archive.
    """
    print(f"Extracting {archive_path} to {extract_to}")

    # os.fspath lets pathlib.Path arguments work with the suffix checks below.
    name = os.fspath(archive_path)
    lowered = name.lower()

    tar_modes = (
        (('.tar.gz', '.tgz'), 'r:gz'),
        (('.tar.bz2', '.tbz2'), 'r:bz2'),
        (('.tar.xz', '.txz'), 'r:xz'),
        (('.tar',), 'r:'),
    )

    if lowered.endswith('.zip'):
        # zipfile.extractall already drops absolute paths and '..' parts.
        with zipfile.ZipFile(name, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
    else:
        mode = next((m for suffixes, m in tar_modes if lowered.endswith(suffixes)), None)
        if mode is None:
            print(f"Unsupported archive format: {archive_path}")
            return False
        with tarfile.open(name, mode) as tar_ref:
            _safe_tar_extract(tar_ref, extract_to)

    print("Extraction complete!")
    return True


# Leading bytes of the stub files setup_model() writes when model files are
# missing; used to tell those stubs apart from real model files.
_PLACEHOLDER_PREFIX = b"This is a placeholder."


def _is_real_file(path):
    """Return True if *path* is a regular file that is not one of our stubs.

    Earlier runs left placeholder text at the checkpoint/config paths; a plain
    existence check would accept them and report the model as installed.
    """
    if not os.path.isfile(path):
        return False
    with open(path, 'rb') as f:
        head = f.read(len(_PLACEHOLDER_PREFIX))
    return head != _PLACEHOLDER_PREFIX


def setup_model(model_url=None, config_url=None):
    """Make sure the model checkpoint and hparams file are present in MODEL_DIR.

    Args:
        model_url: Optional URL to download the checkpoint from when it is
            missing (or only a placeholder).
        config_url: Optional URL to download the hparams file from when it is
            missing (or only a placeholder).

    Returns:
        True when both files are present and are not placeholders; False
        otherwise. On False, manual-setup instructions are printed and a
        placeholder is written for each file that does not exist yet.
        Existing files are never overwritten by placeholders.
    """
    # Create model directory if it doesn't exist
    os.makedirs(MODEL_DIR, exist_ok=True)

    model_path = os.path.join(MODEL_DIR, MODEL_CHECKPOINT)
    config_path = os.path.join(MODEL_DIR, CONFIG_FILE)

    if _is_real_file(model_path) and _is_real_file(config_path):
        print("Model files already exist. Skipping download.")
        return True

    # Fetch whatever is missing when the caller supplied a URL for it.
    for url, path in ((model_url, model_path), (config_url, config_path)):
        if not url or _is_real_file(path):
            continue
        try:
            download_file(url, path)
        except (requests.RequestException, OSError) as err:
            print(f"Failed to download {url}: {err}")
            # A truncated download must not pass for a real file next run.
            # Only missing files or placeholders reach this point, so no user
            # data is removed here.
            if os.path.isfile(path):
                os.remove(path)

    if _is_real_file(model_path) and _is_real_file(config_path):
        return True

    print(f"""
    =================================================================
    IMPORTANT: Model files need to be manually added
    =================================================================
    
    This demo requires the following files from the Myanmar TTS model:
    1. The model checkpoint: {MODEL_CHECKPOINT}
    2. The hyperparameters file: {CONFIG_FILE}
    
    Please obtain these files from the model creator and place them in:
    - {MODEL_DIR}/{MODEL_CHECKPOINT}
    - {MODEL_DIR}/{CONFIG_FILE}
    
    Alternatively, call setup_model(model_url=..., config_url=...) with
    the correct download URLs.
    
    Model repository: https://github.com/hpbyte/myanmar-tts
    =================================================================
    """)
    print("Model files are missing. Please add them manually as described above.")

    placeholders = (
        (model_path, "This is a placeholder. Replace with actual model file."),
        (config_path, "This is a placeholder. Replace with actual hparams.yml file."),
    )
    for path, text in placeholders:
        # Only fill gaps: a real file the user already added must survive.
        if not os.path.exists(path):
            with open(path, 'w', encoding='utf-8') as f:
                f.write(text)
    return False


def _main():
    """Run setup_model() as the script entry point.

    The boolean returned by setup_model() is not turned into an exit status.
    """
    setup_model()


if __name__ == "__main__":
    _main()