# Dump the activations of selected SAE features for a list of prompts, using the
# Neuronpedia inference stack, and save everything to a single JSON file.

import os, sys, json, shutil, time, re, traceback
from typing import Any

# Model and SAE set to analyze
MODEL_ID = "gemma-2-2b"
SOURCE_SET = "clt-hp"

# Input/output paths
PROMPTS_JSON_PATH = "/content/prompts.json"
FEATURES_JSON_PATH = "/content/features.json"
OUT_JSON_PATH = "/content/activations_dump.json"

# Process one SAE layer at a time (lower peak memory, slower) instead of loading all layers at once
CHUNK_BY_LAYER = True

# Emit all-zero activation records for requested features that never fire on a prompt
INCLUDE_ZERO_ACTIVATIONS = True

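# Illustrative only: if you don't already have input files, the block below writes minimal
# example prompts.json / features.json in the formats accepted by load_prompts() and
# load_features() defined further down. The prompts, layer numbers and feature indices are
# placeholders, not values from any real experiment. WRITE_EXAMPLE_INPUTS is an extra flag
# added here for convenience; set it to True to generate quick test inputs.
WRITE_EXAMPLE_INPUTS = False
if WRITE_EXAMPLE_INPUTS:
    example_prompts = {"prompts": [
        "The Eiffel Tower is in Paris.",
        {"id": "custom-1", "text": "The capital of France is"},
    ]}
    example_features = {"features": [
        {"layer": 5, "index": 123},                    # becomes {"source": "5-clt-hp", "index": 123}
        {"source": f"12-{SOURCE_SET}", "index": 456},
    ]}
    with open(PROMPTS_JSON_PATH, "w", encoding="utf-8") as f:
        json.dump(example_prompts, f, ensure_ascii=False, indent=2)
    with open(FEATURES_JSON_PATH, "w", encoding="utf-8") as f:
        json.dump(example_features, f, ensure_ascii=False, indent=2)
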
if "HF_TOKEN" not in os.environ:
    print("⚠️ HF_TOKEN not found. If you use gated models (e.g. Gemma), set the HF_TOKEN environment variable")

if not os.path.exists(PROMPTS_JSON_PATH):
    raise FileNotFoundError(f"Prompts file not found: {PROMPTS_JSON_PATH}")
if not os.path.exists(FEATURES_JSON_PATH):
    raise FileNotFoundError(f"Features file not found: {FEATURES_JSON_PATH}")
print(f"✓ Input files verified:\n - {PROMPTS_JSON_PATH}\n - {FEATURES_JSON_PATH}")

# Clone the Neuronpedia repo and make its inference packages importable
REPO_URL = "https://github.com/hijohnnylin/neuronpedia.git"
REPO_DIR = "/content/neuronpedia"

if not os.path.exists(REPO_DIR):
    import subprocess
    subprocess.run(["git", "clone", "-q", REPO_URL, REPO_DIR], check=True)

sys.path.append(f"{REPO_DIR}/apps/inference")
sys.path.append(f"{REPO_DIR}/packages/python/neuronpedia-inference-client")

from neuronpedia_inference.config import Config
from neuronpedia_inference.shared import Model
from neuronpedia_inference.sae_manager import SAEManager
from neuronpedia_inference.endpoints.activation.all import ActivationProcessor

from neuronpedia_inference_client.models.activation_all_post_request import ActivationAllPostRequest
from neuronpedia_inference_client.models.activation_all_post200_response import ActivationAllPost200Response
from neuronpedia_inference_client.models.activation_all_post200_response_activations_inner import ActivationAllPost200ResponseActivationsInner

try:
    from transformer_lens import HookedSAETransformer
    USE_SAE_TRANSFORMER = True
except ImportError:
    from transformer_lens import HookedTransformer
    USE_SAE_TRANSFORMER = False
from neuronpedia_inference.shared import STR_TO_DTYPE

import torch
import gc

print("Cleaning the Hugging Face cache (disk space)...")
HF_CACHE_DIR = os.path.expanduser("~/.cache/huggingface/hub")
if os.path.exists(HF_CACHE_DIR):

    def get_dir_size(path):
        total = 0
        try:
            for entry in os.scandir(path):
                if entry.is_file(follow_symlinks=False):
                    total += entry.stat().st_size
                elif entry.is_dir(follow_symlinks=False):
                    total += get_dir_size(entry.path)
        except (PermissionError, FileNotFoundError):
            pass
        return total

    cache_size_gb = get_dir_size(HF_CACHE_DIR) / 1024**3
    print(f" Current HF cache: {cache_size_gb:.2f} GB")

    # Remove any previously downloaded clt-hp SAE weights to free disk space
    cleaned_gb = 0
    for item in os.listdir(HF_CACHE_DIR):
        item_path = os.path.join(HF_CACHE_DIR, item)
        if "mntss--clt-" in item and os.path.isdir(item_path):
            item_size = get_dir_size(item_path) / 1024**3
            print(f" Removing SAE cache: {item} ({item_size:.2f} GB)")
            shutil.rmtree(item_path, ignore_errors=True)
            cleaned_gb += item_size

    if cleaned_gb > 0:
        print(f" ✓ Freed {cleaned_gb:.2f} GB of disk space")
    else:
        print(" SAE cache already clean")

if hasattr(shutil, 'disk_usage'):
    disk = shutil.disk_usage("/")
    free_gb = disk.free / 1024**3
    total_gb = disk.total / 1024**3
    print(f" Disk space: {free_gb:.2f} GB free / {total_gb:.2f} GB total")

    if free_gb < 15.0:
        print(f"\n⚠️ WARNING: limited disk space ({free_gb:.2f} GB)")
        print(" The clt-hp SAEs need ~10-15 GB of temporary space per layer during download")
        print(" The script will proceed anyway - layers are downloaded and deleted one at a time")

if torch.cuda.is_available():
    print("\nCleaning GPU memory before starting...")
    torch.cuda.empty_cache()
    gc.collect()

    mem_free_gb = torch.cuda.mem_get_info()[0] / 1024**3
    mem_total_gb = torch.cuda.mem_get_info()[1] / 1024**3
    print(f" GPU: {torch.cuda.get_device_name(0)}")
    print(f" Free memory: {mem_free_gb:.2f} GB / {mem_total_gb:.2f} GB total")

    if mem_free_gb < 4.0:
        print(f"\n⚠️ WARNING: insufficient free GPU memory ({mem_free_gb:.2f} GB)")
        print(" Gemma-2-2b needs at least 4-5 GB free.")
        print(" FIX: Runtime > Restart session, then re-run this cell")
        raise RuntimeError(f"Insufficient GPU memory: {mem_free_gb:.2f} GB free, at least 4 GB required")

device_guess = "cuda" if torch.cuda.is_available() else "cpu"
os.environ.setdefault("MODEL_ID", MODEL_ID)
os.environ.setdefault("SAE_SETS", json.dumps([SOURCE_SET]))
os.environ.setdefault("DEVICE", device_guess)
os.environ.setdefault("TOKEN_LIMIT", "4096")
os.environ.setdefault("MODEL_DTYPE", "bfloat16" if device_guess == "cuda" else "float32")
os.environ.setdefault("SAE_DTYPE", "float32")

print("Environment configuration:")
print(f" MODEL_ID: {os.environ['MODEL_ID']}")
print(f" DEVICE: {os.environ['DEVICE']}")
print(f" MODEL_DTYPE: {os.environ['MODEL_DTYPE']}")

# Build the inference Config singleton manually and register it
Config._instance = None
cfg = Config.__new__(Config)
cfg.__init__(
    model_id=MODEL_ID,
    sae_sets=[SOURCE_SET],
    device=device_guess,
    model_dtype=os.environ["MODEL_DTYPE"],
    sae_dtype=os.environ["SAE_DTYPE"],
    token_limit=int(os.environ["TOKEN_LIMIT"]),
)
Config._instance = cfg

print(f"Config initialized: device={cfg.device}, dtype={cfg.model_dtype}, token_limit={cfg.token_limit}")

valid_models = cfg.get_valid_model_ids()
if MODEL_ID not in valid_models and cfg.custom_hf_model_id not in valid_models:
    print(f"\n⚠️ WARNING: SAE set '{SOURCE_SET}' was not found in the standard SAELens registry")
    print(f" Registered models for '{SOURCE_SET}': {valid_models if valid_models else '(none)'}")
    print("\n Proceeding anyway - if the SAE set exists it will be loaded dynamically.")
    print(" If SAE loading fails, check that MODEL_ID and SOURCE_SET are compatible")
    print("\n Known standard combinations:")
    print(" - gpt2-small → res-jb")
    print(" - gemma-2-2b → gemmascope-res-16k, gemmascope-transcoder-16k, clt-hp")
    print(" - gemma-2-2b-it → gemmascope-res-16k, gemmascope-transcoder-16k")
else:
    print(f"✓ Validation: {MODEL_ID} + {SOURCE_SET} found in the SAELens registry")

print(f"\nLoading model {MODEL_ID} on {cfg.device}...")
if USE_SAE_TRANSFORMER:
    print(" (using HookedSAETransformer for SAE optimization)")
    model = HookedSAETransformer.from_pretrained(
        MODEL_ID,
        device=cfg.device,
        dtype=STR_TO_DTYPE[cfg.model_dtype],
        **cfg.model_kwargs
    )
else:
    model = HookedTransformer.from_pretrained(
        MODEL_ID,
        device=cfg.device,
        dtype=STR_TO_DTYPE[cfg.model_dtype],
        **cfg.model_kwargs
    )
Model.set_instance(model)
print(f"✓ Model loaded: {model.cfg.n_layers} layers")

# Build the SAEManager singleton manually and register it
SAEManager._instance = None
sae_mgr = SAEManager.__new__(SAEManager)
sae_mgr.__init__(num_layers=model.cfg.n_layers, device=cfg.device)
SAEManager._instance = sae_mgr

print(f"SAEManager configured: device={sae_mgr.device}, layers={sae_mgr.num_layers}")

if CHUNK_BY_LAYER:
    print(f"⚠️ SAE set '{SOURCE_SET}' - loading SAEs on demand to save memory")
    print(" (layers will be loaded one at a time, as needed)")

    sae_mgr.setup_neuron_layers()

    # Register the SAEs available for this set without loading their weights yet
    from neuronpedia_inference.config import get_saelens_neuronpedia_directory_df, config_to_json
    directory_df = get_saelens_neuronpedia_directory_df()
    config_json = config_to_json(directory_df, selected_sets_neuronpedia=[SOURCE_SET], selected_model=MODEL_ID)
    for sae_set in config_json:
        sae_mgr.valid_sae_sets.append(sae_set["set"])
        sae_mgr.sae_set_to_saes[sae_set["set"]] = sae_set["saes"]
    print("✓ SAE manager ready (on-demand loading)")
else:
    print(f"Loading the full SAE set '{SOURCE_SET}'...")
    sae_mgr.load_saes()
    print("✓ SAE manager ready")

def load_prompts(path: str) -> list[dict]:
    """
    Accepts:
    - a list of strings: ["text1", "text2", ...]
    - a list of objects: [{"id": "...", "text": "..."}, ...]
    - an object with a "prompts" key containing either of the above
    Normalizes everything into a list of dicts: [{"id": str, "text": str}, ...]
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, dict) and "prompts" in data:
        data = data["prompts"]
    out = []
    if isinstance(data, list):
        for i, item in enumerate(data):
            if isinstance(item, str):
                out.append({"id": f"p{i}", "text": item})
            elif isinstance(item, dict):
                text = item.get("text", item.get("prompt", None))
                if not isinstance(text, str):
                    raise ValueError(f"Invalid prompt #{i}: {item}")
                pid = str(item.get("id", f"p{i}"))
                out.append({"id": pid, "text": text})
            else:
                raise ValueError(f"Unrecognized prompt element: {type(item)}")
    else:
        raise ValueError("Invalid prompts.json format")
    return out

def load_features(path: str, source_set: str) -> list[dict]:
    """
    Accepts:
    - a list of objects: [{"source": "L-source_set", "index": int}, ...]
    - or [{"layer": int, "index": int}, ...] -> converted to {"source": f"{layer}-{source_set}", "index": int}
    - or an object {"features": [...]} wrapping either of the above
    Checks that every source has a suffix equal to source_set (consistent with the request in all.py).
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if isinstance(data, dict) and "features" in data:
        data = data["features"]
    if not isinstance(data, list):
        raise ValueError("Invalid features.json format")
    norm = []
    for i, item in enumerate(data):
        if not isinstance(item, dict):
            raise ValueError(f"feature #{i}: expected an object, got {type(item)}")
        if "source" in item and "index" in item:
            source = str(item["source"])
            idx = int(item["index"])
            if "-" not in source:
                # Source given without the set suffix: extract the layer number and append it
                m = re.search(r"(\d+)", source)
                if not m:
                    raise ValueError(f"Unrecognized source: {source}")
                layer = int(m.group(1))
                source = f"{layer}-{source_set}"
            else:
                suff = source.split("-", 1)[1]
                if suff != source_set:
                    raise ValueError(f"feature #{i}: source_set '{suff}' != expected '{source_set}'")
            norm.append({"source": source, "index": idx})
        elif "layer" in item and "index" in item:
            layer = int(item["layer"])
            idx = int(item["index"])
            norm.append({"source": f"{layer}-{source_set}", "index": idx})
        else:
            raise ValueError(f"feature #{i}: expected fields ('source','index') or ('layer','index')")
    return norm

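# Illustrative only (hypothetical values): both loaders normalize their inputs, e.g.
#   load_prompts(...) on ["The sky is blue"]                      -> [{"id": "p0", "text": "The sky is blue"}]
#   load_features(..., "clt-hp") on [{"layer": 5, "index": 123}]  -> [{"source": "5-clt-hp", "index": 123}]
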
def run_all_for_prompt(prompt_text: str, wanted_features: list[dict]) -> dict[str, Any]:
    """
    Builds an ActivationAllPostRequest as in all.py:
    - selected_sources = every layer present in the requested features
    - no cross-layer feature_filter (we filter afterwards)
    Runs ActivationProcessor.process_activations() and then keeps only the requested features.
    Returns a dict {"tokens": [...], "counts": [[...]], "activations": [...requested only...]}.
    """
    proc = ActivationProcessor()

    # One source per distinct layer appearing in the requested features
    layers = sorted({int(f["source"].split("-")[0]) for f in wanted_features})
    selected_sources = [f"{L}-{SOURCE_SET}" for L in layers]

    req = ActivationAllPostRequest(
        prompt=prompt_text,
        model=MODEL_ID,
        source_set=SOURCE_SET,
        selected_sources=selected_sources,
        ignore_bos=False,
        sort_by_token_indexes=[],
        num_results=100_000
    )

    resp: ActivationAllPost200Response = proc.process_activations(req)

    # Keep only the requested (source, index) pairs
    want = {(f["source"], int(f["index"])) for f in wanted_features}
    found_features = set()
    filtered = []
    for a in resp.activations:
        src = a.source
        idx = int(a.index)
        if (src, idx) in want:
            obj = {
                "source": src,
                "index": idx,
                "values": list(a.values),
                "sum_values": float(a.sum_values) if a.sum_values is not None else None,
                "max_value": float(a.max_value),
                "max_value_index": int(a.max_value_index),
            }
            if getattr(a, "dfa_values", None) is not None:
                obj["dfa_values"] = list(a.dfa_values)
                obj["dfa_target_index"] = int(a.dfa_target_index)
                obj["dfa_max_value"] = float(a.dfa_max_value)
            filtered.append(obj)
            found_features.add((src, idx))

    # Optionally add all-zero records for requested features that never fired
    if INCLUDE_ZERO_ACTIVATIONS:
        num_tokens = len(resp.tokens)
        for f in wanted_features:
            src = f["source"]
            idx = int(f["index"])
            if (src, idx) not in found_features:
                obj = {
                    "source": src,
                    "index": idx,
                    "values": [0.0] * num_tokens,
                    "sum_values": 0.0,
                    "max_value": 0.0,
                    "max_value_index": 0,
                }
                filtered.append(obj)

    return {
        "tokens": list(resp.tokens),
        "counts": [[float(x) for x in row] for row in resp.counts],
        "activations": filtered
    }

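# Illustrative only: one entry of the returned "activations" list has roughly this shape
# (one float per prompt token; the dfa_* fields appear only when the processor returns
# DFA attribution data). The numbers below are made up.
#
# {
#     "source": "5-clt-hp",
#     "index": 123,
#     "values": [0.0, 1.7, 0.3],
#     "sum_values": 2.0,
#     "max_value": 1.7,
#     "max_value_index": 1,
# }
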
def run_per_layer_for_prompt(prompt_text: str, wanted_features: list[dict]) -> dict[str, Any]:
    """
    Chunked variant: processes one layer at a time using the same ActivationProcessor.
    Produces the same result as run_all_for_prompt, but uses less memory because it never
    computes all layers at once. Merges the per-layer results and then filters.
    IMPORTANT: unloads each SAE after its layer to free GPU memory.
    """
    proc = ActivationProcessor()
    sae_mgr = SAEManager.get_instance()

    # Group the requested features by layer
    by_layer: dict[int, list[dict]] = {}
    for f in wanted_features:
        L = int(f["source"].split("-")[0])
        by_layer.setdefault(L, []).append(f)

    tokens_ref = None
    counts_accum = None
    activations_all = []

    for idx, (L, feats) in enumerate(sorted(by_layer.items())):
        sae_id = f"{L}-{SOURCE_SET}"
        print(f" Layer {L} [{idx+1}/{len(by_layer)}]...", end=" ", flush=True)

        req = ActivationAllPostRequest(
            prompt=prompt_text,
            model=MODEL_ID,
            source_set=SOURCE_SET,
            selected_sources=[sae_id],
            ignore_bos=False,
            sort_by_token_indexes=[],
            num_results=100_000
        )
        resp = proc.process_activations(req)

        if tokens_ref is None:
            tokens_ref = list(resp.tokens)

        if counts_accum is None:
            counts_accum = [[float(x) for x in row] for row in resp.counts]
        else:
            # Accumulate counts across layers, padding with zero rows if needed
            max_rows = max(len(counts_accum), len(resp.counts))
            if len(counts_accum) < max_rows:
                counts_accum += [[0.0]*len(tokens_ref) for _ in range(max_rows-len(counts_accum))]
            for r in range(len(resp.counts)):
                for c in range(len(resp.counts[r])):
                    counts_accum[r][c] += float(resp.counts[r][c])

        # Keep only the requested features from this layer
        want = {(f"{L}-{SOURCE_SET}", int(f["index"])) for f in feats}
        found_in_layer = set()
        for a in resp.activations:
            src = a.source
            idx = int(a.index)
            if (src, idx) in want:
                obj = {
                    "source": src,
                    "index": idx,
                    "values": list(a.values),
                    "sum_values": float(a.sum_values) if a.sum_values is not None else None,
                    "max_value": float(a.max_value),
                    "max_value_index": int(a.max_value_index),
                }
                if getattr(a, "dfa_values", None) is not None:
                    obj["dfa_values"] = list(a.dfa_values)
                    obj["dfa_target_index"] = int(a.dfa_target_index)
                    obj["dfa_max_value"] = float(a.dfa_max_value)
                activations_all.append(obj)
                found_in_layer.add((src, idx))

        # Optionally add all-zero records for requested features that never fired on this layer
        if INCLUDE_ZERO_ACTIVATIONS and tokens_ref:
            for f in feats:
                src = f"{L}-{SOURCE_SET}"
                idx = int(f["index"])
                if (src, idx) not in found_in_layer:
                    obj = {
                        "source": src,
                        "index": idx,
                        "values": [0.0] * len(tokens_ref),
                        "sum_values": 0.0,
                        "max_value": 0.0,
                        "max_value_index": 0,
                    }
                    activations_all.append(obj)

        # Free GPU memory before moving on to the next layer
        if sae_id in sae_mgr.loaded_saes:
            sae_mgr.unload_sae(sae_id)

        # For clt-hp, also delete the downloaded SAE weights from the HF cache to free disk space
        if SOURCE_SET == "clt-hp":
            try:
                HF_CACHE_DIR = os.path.expanduser("~/.cache/huggingface/hub")
                for item in os.listdir(HF_CACHE_DIR):
                    if "mntss--clt-" in item:
                        item_path = os.path.join(HF_CACHE_DIR, item)
                        if os.path.isdir(item_path):
                            shutil.rmtree(item_path, ignore_errors=True)
                torch.cuda.empty_cache()
                gc.collect()
            except Exception:
                pass

        print("OK (unloaded + cleaned)", flush=True)

    return {
        "tokens": tokens_ref or [],
        "counts": counts_accum or [],
        "activations": activations_all
    }

def run_layer_by_layer_all_prompts(prompts: list[dict], wanted_features: list[dict]) -> list[dict]:
    """
    OPTIMIZATION: process all prompts layer-by-layer instead of prompt-by-prompt.
    This minimizes SAE re-downloads (each layer is downloaded exactly once).

    Strategy:
    1. Group the features by layer
    2. For each layer:
       - download the SAE once
       - process ALL prompts against that layer
       - unload the SAE and clean the cache
    3. Reorganize the results per prompt
    """
    proc = ActivationProcessor()
    sae_mgr = SAEManager.get_instance()

    # Group the requested features by layer
    by_layer: dict[int, list[dict]] = {}
    for f in wanted_features:
        L = int(f["source"].split("-")[0])
        by_layer.setdefault(L, []).append(f)

    # One partial result per prompt, filled in layer by layer
    results_by_prompt: dict[str, dict] = {p["id"]: {"tokens": None, "counts": None, "activations": []} for p in prompts}

    layers_sorted = sorted(by_layer.keys())
    print(f"\n⚡ OPTIMIZATION: processing {len(layers_sorted)} layers for {len(prompts)} prompts")
    print(" (each layer is downloaded only once)\n")

    for idx, L in enumerate(layers_sorted, 1):
        sae_id = f"{L}-{SOURCE_SET}"
        feats = by_layer[L]

        print(f" Layer {L} [{idx}/{len(layers_sorted)}] - processing {len(prompts)} prompts...", end=" ", flush=True)

        for p in prompts:
            pid, text = p["id"], p["text"]

            req = ActivationAllPostRequest(
                prompt=text,
                model=MODEL_ID,
                source_set=SOURCE_SET,
                selected_sources=[sae_id],
                ignore_bos=False,
                sort_by_token_indexes=[],
                num_results=100_000
            )
            resp = proc.process_activations(req)

            # Store tokens/counts the first time this prompt is processed
            if results_by_prompt[pid]["tokens"] is None:
                results_by_prompt[pid]["tokens"] = list(resp.tokens)
                results_by_prompt[pid]["counts"] = [[float(x) for x in row] for row in resp.counts]

            # Keep only the requested features from this layer
            want = {(f"{L}-{SOURCE_SET}", int(f["index"])) for f in feats}
            found_features = set()

            for a in resp.activations:
                src = a.source
                idx_feat = int(a.index)
                if (src, idx_feat) in want:
                    obj = {
                        "source": src,
                        "index": idx_feat,
                        "values": list(a.values),
                        "sum_values": float(a.sum_values) if a.sum_values is not None else None,
                        "max_value": float(a.max_value),
                        "max_value_index": int(a.max_value_index),
                    }
                    if getattr(a, "dfa_values", None) is not None:
                        obj["dfa_values"] = list(a.dfa_values)
                        obj["dfa_target_index"] = int(a.dfa_target_index)
                        obj["dfa_max_value"] = float(a.dfa_max_value)
                    results_by_prompt[pid]["activations"].append(obj)
                    found_features.add((src, idx_feat))

            # Optionally add all-zero records for requested features that never fired
            if INCLUDE_ZERO_ACTIVATIONS:
                num_tokens = len(results_by_prompt[pid]["tokens"]) if results_by_prompt[pid]["tokens"] else 0
                for f in feats:
                    src = f"{L}-{SOURCE_SET}"
                    idx_feat = int(f["index"])
                    if (src, idx_feat) not in found_features and num_tokens > 0:
                        obj = {
                            "source": src,
                            "index": idx_feat,
                            "values": [0.0] * num_tokens,
                            "sum_values": 0.0,
                            "max_value": 0.0,
                            "max_value_index": 0,
                        }
                        results_by_prompt[pid]["activations"].append(obj)

        # Free GPU memory before moving on to the next layer
        if sae_id in sae_mgr.loaded_saes:
            sae_mgr.unload_sae(sae_id)

        # For clt-hp, also delete the downloaded SAE weights from the HF cache to free disk space
        if SOURCE_SET == "clt-hp":
            try:
                HF_CACHE_DIR = os.path.expanduser("~/.cache/huggingface/hub")
                for item in os.listdir(HF_CACHE_DIR):
                    if "mntss--clt-" in item:
                        item_path = os.path.join(HF_CACHE_DIR, item)
                        if os.path.isdir(item_path):
                            shutil.rmtree(item_path, ignore_errors=True)
                torch.cuda.empty_cache()
                gc.collect()
            except Exception:
                pass

        print("✓ (cleaned)", flush=True)

    # Reorganize into one record per prompt
    return [
        {
            "probe_id": p["id"],
            "prompt": p["text"],
            "tokens": results_by_prompt[p["id"]]["tokens"] or [],
            "counts": results_by_prompt[p["id"]]["counts"] or [],
            "activations": results_by_prompt[p["id"]]["activations"],
        }
        for p in prompts
    ]

try:
    print(f"\n{'='*60}")
    print("Loading input files...")
    prompts = load_prompts(PROMPTS_JSON_PATH)
    features = load_features(FEATURES_JSON_PATH, SOURCE_SET)
    print(f"✓ {len(prompts)} prompt(s), {len(features)} feature(s)")

    if CHUNK_BY_LAYER:
        # Layer-by-layer over all prompts: each SAE layer is downloaded/loaded only once
        results = run_layer_by_layer_all_prompts(prompts, features)

        print(f"\n{'='*60}")
        print("Results summary:")
        for i, res in enumerate(results, 1):
            print(f" [{i}/{len(results)}] {res['probe_id']}: {len(res['activations'])} activations, {len(res['tokens'])} tokens")
    else:
        # Prompt-by-prompt with all requested layers processed together
        results = []
        for i, p in enumerate(prompts, 1):
            pid, text = p["id"], p["text"]
            print(f"\n[{i}/{len(prompts)}] Processing prompt '{pid}'...")
            print(f" Text: {text[:60]}{'...' if len(text) > 60 else ''}")

            res = run_all_for_prompt(text, features)

            print(f" ✓ {len(res['activations'])} activations found, {len(res['tokens'])} tokens")
            results.append({
                "probe_id": pid,
                "prompt": text,
                "tokens": res["tokens"],
                "counts": res["counts"],
                "activations": res["activations"],
            })

    out = {
        "model": MODEL_ID,
        "source_set": SOURCE_SET,
        "device": Config.get_instance().device,
        "n_prompts": len(results),
        "n_features_requested": len(features),
        "results": results,
    }

    with open(OUT_JSON_PATH, "w", encoding="utf-8") as f:
        json.dump(out, f, ensure_ascii=False, indent=2)

    print(f"\n{'='*60}")
    print("✔ DONE!")
    print(f"{'='*60}")
    print(f"File saved: {OUT_JSON_PATH}")
    print("Statistics:")
    print(f" - Prompts processed: {len(results)}")
    print(f" - Features requested: {len(features)}")
    print(f" - Model: {MODEL_ID}")
    print(f" - SAE set: {SOURCE_SET}")
    print(f" - Device: {out['device']}")

    file_size_mb = os.path.getsize(OUT_JSON_PATH) / (1024 * 1024)
    print(f" - Output size: {file_size_mb:.2f} MB")
    print(f"{'='*60}")

except Exception as e:
    print(f"\n{'='*60}")
    print("✗ ERROR!")
    print(f"{'='*60}")
    print(f"Message: {e}")
    print("\nFull stack trace:")
    traceback.print_exc()
    print(f"{'='*60}")

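# Optional sanity check (a minimal sketch, not part of the original pipeline): reload the
# dump just written and, for each prompt, show the top-activating token of every requested
# feature. It assumes the run above completed and that each "values" list is aligned with
# the prompt's "tokens" list, as produced by the code above.
if os.path.exists(OUT_JSON_PATH):
    with open(OUT_JSON_PATH, "r", encoding="utf-8") as f:
        dump = json.load(f)
    for rec in dump["results"]:
        print(f"\nPrompt '{rec['probe_id']}': {len(rec['tokens'])} tokens")
        for act in rec["activations"]:
            tok_i = act["max_value_index"]
            tok = rec["tokens"][tok_i] if rec["tokens"] else "?"
            print(f"  {act['source']}#{act['index']}: max {act['max_value']:.3f} at token {tok_i} ({tok!r})")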
|