carlex3321 committed on
Commit 0113d93 · verified · 1 Parent(s): 0f53712

Update info.sh

Files changed (1)
  1. info.sh +25 -15
info.sh CHANGED
@@ -20,17 +20,31 @@ fi
 echo
 echo "[PyTorch / CUDA backend]"
 python - <<'PY'
-import json, os, torch
+import json, os, torch, inspect
+def to_bool(x):
+    try:
+        if callable(x):
+            try:
+                sig = inspect.signature(x)
+                if len(sig.parameters)==0:
+                    return bool(x())
+            except Exception:
+                pass
+            return True
+        return bool(x)
+    except Exception:
+        return None
+
 info = {
     "torch": getattr(torch, "__version__", None),
     "cuda_available": torch.cuda.is_available(),
     "cuda_device_count": torch.cuda.device_count(),
     "cuda_runtime_version": getattr(torch.version, "cuda", None),
     "cudnn_version": torch.backends.cudnn.version() if torch.backends.cudnn.is_available() else None,
-    "tf32": torch.backends.cuda.matmul.allow_tf32 if torch.cuda.is_available() else None,
-    "flash_sdp": torch.backends.cuda.enable_flash_sdp if torch.cuda.is_available() else None,
-    "mem_efficient_sdp": torch.backends.cuda.enable_mem_efficient_sdp if torch.cuda.is_available() else None,
-    "math_sdp": torch.backends.cuda.enable_math_sdp if torch.cuda.is_available() else None,
+    "tf32": (torch.backends.cuda.matmul.allow_tf32 if torch.cuda.is_available() else None),
+    "flash_sdp": (to_bool(getattr(torch.backends.cuda, "enable_flash_sdp", None)) if torch.cuda.is_available() else None),
+    "mem_efficient_sdp": (to_bool(getattr(torch.backends.cuda, "enable_mem_efficient_sdp", None)) if torch.cuda.is_available() else None),
+    "math_sdp": (to_bool(getattr(torch.backends.cuda, "enable_math_sdp", None)) if torch.cuda.is_available() else None),
 }
 print(json.dumps(info, indent=2))
 for i in range(min(torch.cuda.device_count(), 8)):
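Note on the SDP fields: in recent PyTorch, enable_flash_sdp / enable_mem_efficient_sdp / enable_math_sdp are setter functions, so the old dict serialized bound methods instead of flag values, and the new to_bool() helper only works around that. A minimal alternative sketch, assuming a PyTorch 2.x build where the *_sdp_enabled() query functions are available:

import json
import torch

def sdp_flags():
    # Return None for every flag when CUDA is absent, mirroring the script above.
    if not torch.cuda.is_available():
        return {"flash_sdp": None, "mem_efficient_sdp": None, "math_sdp": None}
    cuda = torch.backends.cuda
    return {
        "flash_sdp": cuda.flash_sdp_enabled(),  # query function, not the enable_* setter
        "mem_efficient_sdp": cuda.mem_efficient_sdp_enabled(),
        "math_sdp": cuda.math_sdp_enabled(),
    }

print(json.dumps(sdp_flags(), indent=2))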
@@ -51,9 +65,8 @@ PY
 echo
 echo "[FlashAttention]"
 python - <<'PY'
-import importlib, sys
-mods = ["flash_attn", "flash_attn_2_cuda"]
-for m in mods:
+import importlib
+for m in ("flash_attn","flash_attn_2_cuda"):
     try:
         importlib.import_module(m); print(f"{m}: OK")
     except Exception as e:
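Note: the probe only reports whether the modules import. If they do, the installed version can be reported as well; a small sketch, assuming Python 3.8+ (importlib.metadata) and that the wheel is distributed under the name flash-attn:

import importlib
from importlib import metadata

for m in ("flash_attn", "flash_attn_2_cuda"):
    try:
        importlib.import_module(m)
        try:
            ver = metadata.version("flash-attn")  # distribution name assumed
        except metadata.PackageNotFoundError:
            ver = "unknown"
        print(f"{m}: OK (flash-attn {ver})")
    except Exception as e:
        print(f"{m}: FAIL -> {e}")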
@@ -95,11 +108,8 @@ names_to_try = [
 ]
 ok = ok_import(names_to_try)
 if not ok:
-    print(" Hint: faltam kernels de layer_norm/RMSNorm do FlashAttention.")
-    print(" Aceleração ficará reduzida; para instalar:")
-    print(" - Rodar builder para compilar e instalar flash_attn e salvar wheel para reuso;")
-    print(" - Ou instalar manualmente a tag compatível: Dao-AILab/flash-attention (csrc/layer_norm).")
-    print(" Doc: https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm")
+    print(" Hint: faltam kernels LN/RMSNorm do FlashAttention (performance reduzida).")
+    print(" Use builder.sh para compilar flash_attn e reutilizar a wheel.")
 PY

 echo
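Note: this hunk relies on an ok_import() helper and a names_to_try list defined earlier in the same heredoc (not shown in this diff). A hypothetical sketch of such a probe, for context only; the real helper's semantics may differ:

import importlib

def ok_import(names):
    # Treat the check as passing if any candidate extension module imports cleanly.
    for name in names:
        try:
            importlib.import_module(name)
            print(f"{name}: OK")
            return True
        except Exception as e:
            print(f"{name}: FAIL -> {e}")
    return False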
@@ -110,7 +120,7 @@ try:
     print("triton:", triton.__version__)
     try:
         import triton.ops as _; print("triton.ops: OK")
-    except Exception as e:
+    except Exception:
         print("triton.ops: not present (ok on Triton>=3.x)")
 except Exception as e:
     print("triton: FAIL ->", e)
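Note: the change only drops an exception variable that was never used; the probe still treats any import failure as "triton.ops absent". A version-gated variant, sketched here with a major-version cutoff that simply mirrors the script's own "ok on Triton>=3.x" message:

try:
    import triton
    print("triton:", triton.__version__)
    if int(triton.__version__.split(".")[0]) < 3:
        try:
            import triton.ops  # noqa: F401  -- module removed upstream in Triton 3.x
            print("triton.ops: OK")
        except Exception as e:
            print("triton.ops: FAIL ->", e)
    else:
        print("triton.ops: not present (ok on Triton>=3.x)")
except Exception as e:
    print("triton: FAIL ->", e)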
@@ -149,7 +159,7 @@ env | grep -E '^(CUDA_VISIBLE_DEVICES|NCCL_|TORCH_|ENABLE_.*SDP|HF_HUB_.*|CUDA_|

 echo
 echo "[Caminhos e permissões de saída]"
-OUT="/app/output"
+OUT="/app/outputs"
 echo "OUT dir: $OUT"
 mkdir -p "$OUT"
 ls -la "$OUT" || true
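Note: the output path changes from /app/output to /app/outputs, but ls -la only lists the directory; it does not confirm the process can write there. A small writability probe, sketched in Python to match the script's other heredocs (assumes OUT is exported, falling back to /app/outputs):

import os, tempfile

out = os.environ.get("OUT", "/app/outputs")
os.makedirs(out, exist_ok=True)
try:
    # Creating and deleting a temp file proves write permission in the directory.
    with tempfile.NamedTemporaryFile(dir=out):
        print(f"{out}: writable")
except OSError as e:
    print(f"{out}: NOT writable -> {e}")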
 