carlex3321 committed
Commit 82e33b7 · verified · 1 Parent(s): e9eef92

Update scripts/download_models.py

Files changed (1)
  1. scripts/download_models.py +283 -31
scripts/download_models.py CHANGED
@@ -1,45 +1,297 @@
- # scripts/download_models.py
- import os, yaml
  from huggingface_hub import snapshot_download, login

- cfg_path = os.environ.get("CONFIG_PATH", "config.yaml")
- models_dir = os.environ.get("MODELS_DIR", "/app/models")
- hf_token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")

- if hf_token:
      try:
-         login(token=hf_token)
      except Exception:
-         pass

- os.makedirs(models_dir, exist_ok=True)

- with open(cfg_path, "r") as f:
-     data = yaml.safe_load(f)

- def collect_ids(d):
-     ids = set()
-     if isinstance(d, dict):
-         for k, v in d.items():
              if isinstance(v, dict):
                  mid = v.get("model_id")
                  if isinstance(mid, str) and mid.strip():
                      ids.add(mid.strip())
      return sorted(ids)

- # Try to collect at the root and/or under "specialists"
- model_ids = collect_ids(data)
- if not model_ids and isinstance(data.get("specialists"), dict):
-     model_ids = collect_ids(data["specialists"])
-
- for mid in model_ids:
-     local_dir = os.path.join(models_dir, mid.replace("/", "__"))
-     snapshot_download(
-         repo_id=mid,
-         local_dir=local_dir,
-         # local_dir_use_symlinks=False,
-         # resume_download=True,
-         token=hf_token
-     )
-
- print("Downloaded:", model_ids)
+ #!/usr/bin/env python3
+ import os
+ import sys
+ import json
+ import time
+ import yaml
+ import logging
+ import shutil
+ import traceback
+ from pathlib import Path
+ from typing import List, Dict, Set, Tuple
  from huggingface_hub import snapshot_download, login

+ # -------------------------
+ # Logging configuration
+ # -------------------------
+ VERBOSE = int(os.environ.get("VERBOSE", "1"))
+ LOG_LEVEL = logging.DEBUG if VERBOSE >= 2 else (logging.INFO if VERBOSE == 1 else logging.WARNING)
+ logging.basicConfig(
+     level=LOG_LEVEL,
+     format="%(asctime)s | %(levelname)s | %(message)s",
+     datefmt="%H:%M:%S",
+ )
+ log = logging.getLogger("download_models")

+ # -------------------------
+ # Environment variables
+ # -------------------------
+ CFG_PATH = os.environ.get("CONFIG_PATH", "/app/config.yaml")
+ MODELS_DIR = os.environ.get("MODELS_DIR", "/app/models")
+ HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN")
+ MAX_WORKERS = int(os.environ.get("MAX_WORKERS", "4"))  # reduced to save RAM
+ FAIL_FAST = os.environ.get("FAIL_FAST", "0") == "1"
+ DRY_RUN = os.environ.get("DRY_RUN", "0") == "1"
+ SKIP_EXISTING = os.environ.get("SKIP_EXISTING", "1") == "1"  # skip if the dir already exists and is not empty
+ INCLUDE_DIFFUSERS_DIRS = os.environ.get("INCLUDE_DIFFUSERS_DIRS", "1") == "1"
+ ALLOW_BIN = os.environ.get("ALLOW_BIN", "0") == "1"  # avoids heavy .bin files by default
+ EXTRA_ALLOW = [p.strip() for p in os.environ.get("EXTRA_ALLOW_PATTERNS", "").split(",") if p.strip()]
+ EXTRA_IGNORE = [p.strip() for p in os.environ.get("EXTRA_IGNORE_PATTERNS", "").split(",") if p.strip()]
+
+ # Suggested by the docs to speed things up when available
+ if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER") is None:
+     os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+
+ def bytes_to_human(n: int) -> str:
+     for unit in ["B", "KB", "MB", "GB", "TB"]:
+         if n < 1024:
+             return f"{n:.1f} {unit}"
+         n /= 1024
+     return f"{n:.1f} PB"
+
+ def dir_size(path: Path) -> int:
+     total = 0
+     if not path.exists():
+         return 0
+     for p in path.rglob("*"):
+         if p.is_file():
+             try:
+                 total += p.stat().st_size
+             except Exception:
+                 pass
+     return total
+
+ def disk_free(dirname: Path) -> int:
      try:
+         usage = shutil.disk_usage(dirname)
+         return usage.free
      except Exception:
+         return 0

+ def load_yaml(path: str) -> Dict:
+     with open(path, "r") as f:
+         return yaml.safe_load(f)

+ def collect_model_ids(cfg: Dict) -> List[str]:
+     ids: Set[str] = set()

+     def collect_from_section(d: Dict):
+         for _, v in d.items():
              if isinstance(v, dict):
                  mid = v.get("model_id")
                  if isinstance(mid, str) and mid.strip():
                      ids.add(mid.strip())
+
+     # Top level (as in the provided config.yaml)
+     if isinstance(cfg, dict):
+         collect_from_section(cfg)
+
+     # If a "specialists" section exists
+     if isinstance(cfg.get("specialists"), dict):
+         collect_from_section(cfg["specialists"])
+
      return sorted(ids)

+ def build_patterns() -> Tuple[List[str], List[str]]:
+     allow_patterns = [
+         # Weights and lightweight indexes
+         "*.safetensors",
+         "model.safetensors",
+         "model.safetensors.index.json",
+         "pytorch_model.bin.index.json",  # small, without pulling the large .bin files
+         # Metadata/configs
+         "*.json",
+         "config.json",
+         "generation_config.json",
+         "model_index.json",
+         "configs/*.yaml",
+         # Tokenizer
+         "tokenizer.json",
+         "tokenizer_config.json",
+         "special_tokens_map.json",
+         "vocab.json",
+         "merges.txt",
+         "*.model",  # sentencepiece
+         "tokenizer/*",
+     ]
+
+     if INCLUDE_DIFFUSERS_DIRS:
+         allow_patterns += [
+             # Typical Diffusers layout
+             "unet/*",
+             "vae/*",
+             "text_encoder/*",
+             "text_encoder_2/*",
+             "scheduler/*",
+             "feature_extractor/*",
+             "processor/*",
+             "preprocessor_config.json",
+         ]
+
+     # Optionally allow .bin (disabled by default to reduce the payload)
+     if ALLOW_BIN:
+         allow_patterns += [
+             "*.bin",  # careful: large
+         ]
+
+     # Ignore heavy/dispensable content
+     ignore_patterns = [
+         "*.h5",
+         "*.msgpack",
+         "*.onnx",
+         "*.npz",
+         "*.tar",
+         "*.zip",
+         "*.ckpt",
+         "*.pt",
+         "*.tflite",
+         "*.onnx_data",
+         "flax_model.msgpack",
+         "rust_model.ot",
+         # media and docs
+         "*.png", "*.jpg", "*.jpeg", "*.gif", "*.webp", "*.bmp", "*.svg",
+         "*.md", "README*", "LICENSE*", "docs/*", "images/*", "samples/*", "assets/*",
+         ".gitattributes", ".gitignore",
+         # heavy quantized variants (adjust as needed)
+         "int8/*", "int4/*", "fp16/*",
+     ]
+
+     # Extra patterns via environment
+     allow_patterns += EXTRA_ALLOW
+     ignore_patterns += EXTRA_IGNORE
+
+     # Deduplicate, preserving order
+     def dedup(seq):
+         seen = set()
+         out = []
+         for x in seq:
+             if x not in seen:
+                 seen.add(x)
+                 out.append(x)
+         return out
+
+     return dedup(allow_patterns), dedup(ignore_patterns)
+
+ def main():
+     start = time.time()
+     Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)
+
+     # Environment info
+     log.info(f"Config: CFG_PATH={CFG_PATH} MODELS_DIR={MODELS_DIR} MAX_WORKERS={MAX_WORKERS} DRY_RUN={DRY_RUN} FAIL_FAST={FAIL_FAST} SKIP_EXISTING={SKIP_EXISTING} VERBOSE={VERBOSE}")  # noqa
+     log.info(f"HF_HUB_ENABLE_HF_TRANSFER={os.environ.get('HF_HUB_ENABLE_HF_TRANSFER')}")  # noqa
+
+     # Optional login with token
+     if HF_TOKEN:
+         try:
+             login(token=HF_TOKEN)
+             log.info("Authenticated with the Hugging Face Hub (token provided).")
+         except Exception as e:
+             log.warning(f"HF login failed: {e}")
+
+     # Read config.yaml and collect model_ids
+     cfg = load_yaml(CFG_PATH)
+     model_ids = collect_model_ids(cfg)
+     if not model_ids:
+         log.warning("No model_id found in config.yaml; nothing to download.")
+         return
+
+     allow_patterns, ignore_patterns = build_patterns()
+     log.info("Allow patterns:")
+     for p in allow_patterns:
+         log.info(f"  - {p}")
+     log.info("Ignore patterns:")
+     for p in ignore_patterns:
+         log.info(f"  - {p}")
+
+     # Disk report
+     free_before = disk_free(Path(MODELS_DIR))
+     log.info(f"Free space before: {bytes_to_human(free_before)}")
+
+     downloaded: Dict[str, Dict] = {}
+     errors: Dict[str, str] = {}
+
+     for mid in model_ids:
+         safe_dir = mid.replace("/", "__")
+         local_dir = Path(MODELS_DIR) / safe_dir
+
+         if SKIP_EXISTING and local_dir.exists():
+             try:
+                 non_empty = any(local_dir.iterdir())
+             except Exception:
+                 non_empty = False
+             if non_empty:
+                 log.info(f"[skip] {mid} -> {local_dir} already exists and is not empty.")
+                 downloaded[mid] = {
+                     "local_dir": str(local_dir),
+                     "skipped": True,
+                     "size_bytes": dir_size(local_dir),
+                 }
+                 continue
+
+         log.info(f"[start] {mid} -> {local_dir}")
+         if DRY_RUN:
+             downloaded[mid] = {
+                 "local_dir": str(local_dir),
+                 "dry_run": True,
+                 "size_bytes": 0,
+             }
+             continue
+
+         try:
+             t0 = time.time()
+             out_dir = snapshot_download(
+                 repo_id=mid,
+                 local_dir=str(local_dir),
+                 local_dir_use_symlinks=False,
+                 resume_download=True,
+                 token=HF_TOKEN,
+                 max_workers=MAX_WORKERS,
+                 allow_patterns=allow_patterns,
+                 ignore_patterns=ignore_patterns,
+             )
+             elapsed = time.time() - t0
+             size_b = dir_size(Path(out_dir))
+             log.info(f"[done] {mid} downloaded in {elapsed:.1f}s | size {bytes_to_human(size_b)} | destination {out_dir}")
+             downloaded[mid] = {
+                 "local_dir": out_dir,
+                 "elapsed_sec": elapsed,
+                 "size_bytes": size_b,
+             }
+         except Exception as e:
+             tb = traceback.format_exc(limit=1)
+             msg = f"Error downloading {mid}: {e} | {tb}"
+             log.error(msg)
+             errors[mid] = str(e)
+             if FAIL_FAST:
+                 break
+
+     free_after = disk_free(Path(MODELS_DIR))
+     log.info(f"Free space after: {bytes_to_human(free_after)}")
+
+     report = {
+         "models_requested": model_ids,
+         "downloaded": downloaded,
+         "errors": errors,
+         "free_before": free_before,
+         "free_after": free_after,
+         "elapsed_total_sec": time.time() - start,
+         "patterns": {
+             "allow": allow_patterns,
+             "ignore": ignore_patterns,
+         },
+         "env": {
+             "MAX_WORKERS": MAX_WORKERS,
+             "ALLOW_BIN": ALLOW_BIN,
+             "INCLUDE_DIFFUSERS_DIRS": INCLUDE_DIFFUSERS_DIRS,
+             "SKIP_EXISTING": SKIP_EXISTING,
+             "DRY_RUN": DRY_RUN,
+             "VERBOSE": VERBOSE,
+         },
+     }
+     print(json.dumps(report, indent=2, ensure_ascii=False))
+
+     if errors and FAIL_FAST:
+         sys.exit(1)
+
+ if __name__ == "__main__":
+     main()
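
For reference, a minimal sketch of the config.yaml shape that collect_model_ids() picks up: any mapping value carrying a non-empty string "model_id", both at the top level and under "specialists". The section names and model IDs below are illustrative assumptions, not taken from the repository's actual config.yaml:

import yaml

# Hypothetical config; "text_generator" and "coder" are placeholder section
# names, and the model IDs are examples only.
SAMPLE_CFG = """
text_generator:
  model_id: gpt2
specialists:
  coder:
    model_id: bigcode/starcoderbase-1b
"""

cfg = yaml.safe_load(SAMPLE_CFG)

# Equivalent to the script's two passes: top level, then "specialists".
ids = set()
for section in (cfg, cfg.get("specialists") or {}):
    for v in section.values():
        if isinstance(v, dict):
            mid = v.get("model_id")
            if isinstance(mid, str) and mid.strip():
                ids.add(mid.strip())

print(sorted(ids))  # ['bigcode/starcoderbase-1b', 'gpt2']

With a real config in place, the new flags allow a no-download check of what would be fetched, e.g. DRY_RUN=1 CONFIG_PATH=./config.yaml python scripts/download_models.py.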