Spaces:
Paused
Paused
Update scripts/download_models.py
Browse files- scripts/download_models.py +283 -31
scripts/download_models.py
CHANGED
|
@@ -1,45 +1,297 @@
|
|
| 1 |
-
|
| 2 |
-
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from huggingface_hub import snapshot_download, login
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
try:
|
| 11 |
-
|
|
|
|
| 12 |
except Exception:
|
| 13 |
-
|
| 14 |
|
| 15 |
-
|
|
|
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
|
| 19 |
|
| 20 |
-
def
|
| 21 |
-
|
| 22 |
-
if isinstance(d, dict):
|
| 23 |
-
for k, v in d.items():
|
| 24 |
if isinstance(v, dict):
|
| 25 |
mid = v.get("model_id")
|
| 26 |
if isinstance(mid, str) and mid.strip():
|
| 27 |
ids.add(mid.strip())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
return sorted(ids)
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import os
import sys
import json
import time
import logging
import shutil
import traceback
from pathlib import Path
from typing import Dict, List, Set, Tuple

import yaml
from huggingface_hub import snapshot_download, login
|
| 13 |
|
| 14 |
+
# -------------------------
# Logging configuration
# -------------------------
VERBOSE = int(os.environ.get("VERBOSE", "1"))
if VERBOSE >= 2:
    LOG_LEVEL = logging.DEBUG
elif VERBOSE == 1:
    LOG_LEVEL = logging.INFO
else:
    LOG_LEVEL = logging.WARNING
logging.basicConfig(
    level=LOG_LEVEL,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("download_models")

# -------------------------
# Environment variables
# -------------------------
CFG_PATH = os.environ.get("CONFIG_PATH", "/app/config.yaml")
MODELS_DIR = os.environ.get("MODELS_DIR", "/app/models")
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
MAX_WORKERS = int(os.environ.get("MAX_WORKERS", "4"))  # kept low to save RAM
FAIL_FAST = os.environ.get("FAIL_FAST", "0") == "1"
DRY_RUN = os.environ.get("DRY_RUN", "0") == "1"
SKIP_EXISTING = os.environ.get("SKIP_EXISTING", "1") == "1"  # skip when target dir exists and is non-empty
INCLUDE_DIFFUSERS_DIRS = os.environ.get("INCLUDE_DIFFUSERS_DIRS", "1") == "1"
ALLOW_BIN = os.environ.get("ALLOW_BIN", "0") == "1"  # heavy .bin weights are excluded by default
EXTRA_ALLOW = [pat.strip() for pat in os.environ.get("EXTRA_ALLOW_PATTERNS", "").split(",") if pat.strip()]
EXTRA_IGNORE = [pat.strip() for pat in os.environ.get("EXTRA_IGNORE_PATTERNS", "").split(",") if pat.strip()]

# Suggested by the HF docs to speed up downloads when hf_transfer is available
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
| 44 |
+
|
| 45 |
+
def bytes_to_human(n: int) -> str:
    """Format a byte count as a human-readable string (1024-based units)."""
    value = float(n)
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if value < 1024:
            return f"{value:.1f} {unit}"
        value /= 1024
    # Anything past TB collapses into petabytes.
    return f"{value:.1f} PB"
|
| 51 |
+
|
| 52 |
+
def dir_size(path: Path) -> int:
    """Return the total size in bytes of all files under *path* (0 if absent)."""
    if not path.exists():
        return 0
    total = 0
    for entry in path.rglob("*"):
        if not entry.is_file():
            continue
        try:
            total += entry.stat().st_size
        except Exception:
            # File may vanish or be unreadable mid-walk; treat it as size 0.
            continue
    return total
|
| 63 |
+
|
| 64 |
+
def disk_free(dirname: Path) -> int:
    """Free space in bytes on the filesystem holding *dirname*; 0 on any error."""
    try:
        return shutil.disk_usage(dirname).free
    except Exception:
        # Missing path or unsupported filesystem: report no free space.
        return 0
|
| 70 |
|
| 71 |
+
def load_yaml(path: str) -> Dict:
    """Parse the YAML file at *path* and return the loaded document.

    Returns None for an empty file (yaml.safe_load semantics).
    Raises OSError if the file cannot be opened.
    """
    # Explicit utf-8: the locale default encoding is not reliable in containers.
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
|
| 74 |
|
| 75 |
+
def collect_model_ids(cfg: Dict) -> List[str]:
    """Collect every non-empty 'model_id' string from the config, sorted.

    Scans the top-level mapping (each value that is a dict may carry a
    "model_id" key) and, when present, the nested "specialists" section.
    Returns [] when *cfg* is not a mapping (e.g. an empty YAML file loads
    as None) — the original crashed on ``cfg.get`` in that case.
    """
    ids: Set[str] = set()

    def collect_from_section(d: Dict) -> None:
        # One level deep: each dict-valued entry may declare a model_id.
        for v in d.values():
            if isinstance(v, dict):
                mid = v.get("model_id")
                if isinstance(mid, str) and mid.strip():
                    ids.add(mid.strip())

    if isinstance(cfg, dict):
        # Top level (layout of the shipped config.yaml).
        collect_from_section(cfg)
        # Optional "specialists" section, nested one level further.
        if isinstance(cfg.get("specialists"), dict):
            collect_from_section(cfg["specialists"])

    return sorted(ids)
|
| 94 |
|
| 95 |
+
def build_patterns() -> Tuple[List[str], List[str]]:
    """Build the (allow_patterns, ignore_patterns) pair for snapshot_download.

    Keeps safetensors weights, configs and tokenizer files while skipping
    heavy or dispensable artifacts. Behavior is tuned by the module-level
    INCLUDE_DIFFUSERS_DIRS / ALLOW_BIN flags and the EXTRA_ALLOW_PATTERNS /
    EXTRA_IGNORE_PATTERNS environment overrides. Both lists are returned
    deduplicated with original order preserved.

    Note: the original annotated the return as ``(List[str], List[str])``,
    which is a runtime tuple, not a valid type — fixed to ``Tuple[...]``.
    """
    allow_patterns = [
        # Weights and lightweight indexes
        "*.safetensors",
        "model.safetensors",
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",  # small; does not pull the big .bin shards
        # Metadata / configs
        "*.json",
        "config.json",
        "generation_config.json",
        "model_index.json",
        "configs/*.yaml",
        # Tokenizer
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        "*.model",  # sentencepiece
        "tokenizer/*",
    ]

    if INCLUDE_DIFFUSERS_DIRS:
        allow_patterns += [
            # Typical Diffusers repo layout
            "unet/*",
            "vae/*",
            "text_encoder/*",
            "text_encoder_2/*",
            "scheduler/*",
            "feature_extractor/*",
            "processor/*",
            "preprocessor_config.json",
        ]

    # Optionally allow .bin (disabled by default to reduce payload)
    if ALLOW_BIN:
        allow_patterns += [
            "*.bin",  # caution: large
        ]

    # Skip heavy / dispensable content
    ignore_patterns = [
        "*.h5",
        "*.msgpack",
        "*.onnx",
        "*.npz",
        "*.tar",
        "*.zip",
        "*.ckpt",
        "*.pt",
        "*.tflite",
        "*.onnx_data",
        "flax_model.msgpack",
        "rust_model.ot",
        # media and docs
        "*.png", "*.jpg", "*.jpeg", "*.gif", "*.webp", "*.bmp", "*.svg",
        "*.md", "README*", "LICENSE*", "docs/*", "images/*", "samples/*", "assets/*",
        ".gitattributes", ".gitignore",
        # heavy quantized variants (adjust per use case)
        "int8/*", "int4/*", "fp16/*",
    ]

    # Extra patterns supplied via environment
    allow_patterns += EXTRA_ALLOW
    ignore_patterns += EXTRA_IGNORE

    def dedup(seq: List[str]) -> List[str]:
        # dict.fromkeys preserves insertion order (Python 3.7+) and drops dupes.
        return list(dict.fromkeys(seq))

    return dedup(allow_patterns), dedup(ignore_patterns)
|
| 174 |
+
|
| 175 |
+
def main():
    """Download every model referenced in the YAML config into MODELS_DIR.

    Flow: optional HF login -> collect model ids from the config ->
    build allow/ignore patterns -> download each repo (honoring the
    SKIP_EXISTING, DRY_RUN and FAIL_FAST toggles) -> print a JSON
    report to stdout. Exits with status 1 only when FAIL_FAST is set
    and at least one download failed.
    """
    start = time.time()
    Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)

    # Environment info
    log.info(f"Config: CFG_PATH={CFG_PATH} MODELS_DIR={MODELS_DIR} MAX_WORKERS={MAX_WORKERS} DRY_RUN={DRY_RUN} FAIL_FAST={FAIL_FAST} SKIP_EXISTING={SKIP_EXISTING} VERBOSE={VERBOSE}")  # noqa
    log.info(f"HF_HUB_ENABLE_HF_TRANSFER={os.environ.get('HF_HUB_ENABLE_HF_TRANSFER')}")  # noqa

    # Optional login with token (best-effort: only gated repos need it)
    if HF_TOKEN:
        try:
            login(token=HF_TOKEN)
            log.info("Autenticado no Hugging Face Hub (token fornecido).")
        except Exception as e:
            log.warning(f"Falha no login HF: {e}")

    # Read config.yaml and collect model_ids
    cfg = load_yaml(CFG_PATH)
    model_ids = collect_model_ids(cfg)
    if not model_ids:
        log.warning("Nenhum model_id encontrado no config.yaml; nada a baixar.")
        return

    allow_patterns, ignore_patterns = build_patterns()
    log.info("Allow patterns:")
    for p in allow_patterns:
        log.info(f" - {p}")
    log.info("Ignore patterns:")
    for p in ignore_patterns:
        log.info(f" - {p}")

    # Disk-space report (before downloads)
    free_before = disk_free(Path(MODELS_DIR))
    log.info(f"Espaço livre antes: {bytes_to_human(free_before)}")

    downloaded: Dict[str, Dict] = {}  # per-model success/skip metadata
    errors: Dict[str, str] = {}       # per-model failure message

    for mid in model_ids:
        # "org/name" -> "org__name": one flat local directory per repo
        safe_dir = mid.replace("/", "__")
        local_dir = Path(MODELS_DIR) / safe_dir

        if SKIP_EXISTING and local_dir.exists():
            try:
                non_empty = any(local_dir.iterdir())
            except Exception:
                non_empty = False
            if non_empty:
                log.info(f"[skip] {mid} -> {local_dir} já existe e não está vazio.")
                downloaded[mid] = {
                    "local_dir": str(local_dir),
                    "skipped": True,
                    "size_bytes": dir_size(local_dir),
                }
                continue

        log.info(f"[start] {mid} -> {local_dir}")
        if DRY_RUN:
            # Record the plan without touching the network.
            downloaded[mid] = {
                "local_dir": str(local_dir),
                "dry_run": True,
                "size_bytes": 0,
            }
            continue

        try:
            t0 = time.time()
            out_dir = snapshot_download(
                repo_id=mid,
                local_dir=str(local_dir),
                local_dir_use_symlinks=False,
                resume_download=True,
                token=HF_TOKEN,
                max_workers=MAX_WORKERS,
                allow_patterns=allow_patterns,
                ignore_patterns=ignore_patterns,
            )
            elapsed = time.time() - t0
            size_b = dir_size(Path(out_dir))
            log.info(f"[done] {mid} baixado em {elapsed:.1f}s | tamanho {bytes_to_human(size_b)} | destino {out_dir}")
            downloaded[mid] = {
                "local_dir": out_dir,
                "elapsed_sec": elapsed,
                "size_bytes": size_b,
            }
        except Exception as e:
            # Record the failure and keep going unless FAIL_FAST is set.
            tb = traceback.format_exc(limit=1)
            msg = f"Erro ao baixar {mid}: {e} | {tb}"
            log.error(msg)
            errors[mid] = str(e)
            if FAIL_FAST:
                break

    free_after = disk_free(Path(MODELS_DIR))
    log.info(f"Espaço livre depois: {bytes_to_human(free_after)}")

    # Machine-readable summary on stdout (logging goes to stderr by default).
    report = {
        "models_requested": model_ids,
        "downloaded": downloaded,
        "errors": errors,
        "free_before": free_before,
        "free_after": free_after,
        "elapsed_total_sec": time.time() - start,
        "patterns": {
            "allow": allow_patterns,
            "ignore": ignore_patterns,
        },
        "env": {
            "MAX_WORKERS": MAX_WORKERS,
            "ALLOW_BIN": ALLOW_BIN,
            "INCLUDE_DIFFUSERS_DIRS": INCLUDE_DIFFUSERS_DIRS,
            "SKIP_EXISTING": SKIP_EXISTING,
            "DRY_RUN": DRY_RUN,
            "VERBOSE": VERBOSE,
        },
    }
    print(json.dumps(report, indent=2, ensure_ascii=False))

    if errors and FAIL_FAST:
        sys.exit(1)
|
| 295 |
+
|
| 296 |
+
# Script entry point: run only when executed directly, not on import.
if __name__ == "__main__":
    main()
|