peppinob-ol
Initial deployment: Attribution Graph Probing app
cb8a7e5
# ========================= Neuronpedia Activations: multi-prompt one-shot (Colab) =========================
# COME USARE IN COLAB:
# 1. IMPORTANTE - Setup runtime:
# - Vai a Runtime > Change runtime type > Hardware accelerator: T4 GPU (o superiore)
# - Se ri-esegui questo codice, prima fai: Runtime > Restart session (per liberare GPU)
# 2. Carica questo file in una cella Colab (copia/incolla o usa "Upload to session storage")
# 3. Prepara i file JSON di input (vedi esempi sotto)
# 4. Esegui la cella - il codice:
# - clona automaticamente il repo neuronpedia
# - inizializza Model e SAEManager
# - processa tutti i prompt
# - salva un unico JSON con tutti i risultati
#
# INPUT FILES:
# 1) prompts.json - Formati accettati:
# ["prompt1", "prompt2", ...]
# oppure: [{"id": "p1", "text": "prompt1"}, {"id": "p2", "text": "prompt2"}]
# oppure: {"prompts": [...come sopra...]}
#
# 2) features.json - Formati accettati:
# [{"source": "10-gemmascope-res-16k", "index": 1234}, ...]
# oppure: [{"layer": 10, "index": 1234}, ...] (converte automaticamente)
# oppure: {"features": [...come sopra...]}
#
# OUTPUT:
# activations_dump.json con formato:
# {
# "model": "gemma-2-2b",
# "source_set": "gemmascope-transcoder-16k",
# "n_prompts": 5,
# "n_features_requested": 55,
# "results": [
# {
# "probe_id": "p1",
# "prompt": "testo prompt",
# "tokens": [...],
# "counts": [[...]],
# "activations": [{"source": "10-...", "index": 123, "values": [...], ...}, ...]
# },
# ...
# ]
# }
#
# NOTA IMPORTANTE:
# - Con INCLUDE_ZERO_ACTIVATIONS=True: include TUTTE le feature richieste (anche con valore 0)
# - Con INCLUDE_ZERO_ACTIVATIONS=False: include SOLO le feature che si sono attivate (più compatto)
# - Per analisi complete è consigliato True, per visualizzazioni è consigliato False
#
# NOTA: Usa la stessa pipeline di /activation/all (ActivationProcessor), garantendo coerenza.
# =========================================================================================================
import os, sys, json, shutil, time, re, traceback
from typing import Any
# --------------------------- CONFIGURAZIONE BASE (EDITA QUI) ---------------------------------------------
# Modello e SAE set (devono essere compatibili!)
# COMPATIBILITÀ PRINCIPALI:
# gpt2-small → "res-jb"
# gemma-2-2b → "gemmascope-res-16k", "gemmascope-transcoder-16k", "clt-hp" (Circuit Tracer)
# gemma-2-2b-it → "gemmascope-res-16k", "gemmascope-transcoder-16k"
# NOTA: altri SAE set potrebbero funzionare anche se non nel registry standard di SAELens
# RIFERIMENTI: https://www.neuronpedia.org/transcoders-hp (il set si chiama "clt-hp", non "transcoders-hp")
MODEL_ID = "gemma-2-2b" # "gpt2-small" | "gemma-2-2b" | "gemma-2-2b-it" | ecc.
SOURCE_SET = "clt-hp" # "res-jb" | "gemmascope-res-16k" | "gemmascope-transcoder-16k" | "clt-hp"
# File di input (caricali su Colab/Drive; vedi schema sopra e example_*.json)
PROMPTS_JSON_PATH = "/content/prompts.json"
FEATURES_JSON_PATH = "/content/features.json"
# File di output
OUT_JSON_PATH = "/content/activations_dump.json"
# Se True, processa layer-by-layer per tutti i prompt (ottimizzato per SAE pesanti)
# IMPORTANTE: per SAE pesanti come clt-hp, impostare True per evitare OOM
# STRATEGIA OTTIMIZZATA:
# - Scarica ogni layer UNA SOLA VOLTA
# - Processa tutti i prompt con quel layer
# - Pulisce cache HF e libera GPU
# - Passa al layer successivo
# Risultato: ~5x più veloce rispetto a processare prompt-by-prompt con re-download
CHUNK_BY_LAYER = True # ← True per clt-hp (consigliato), False per res-jb/gemmascope (leggeri)
# Se True, include nell'output anche le feature richieste con attivazione = 0
# Se False, include solo le feature che si sono effettivamente attivate (più compatto)
INCLUDE_ZERO_ACTIVATIONS = True # ← True per vedere tutte le 55 feature, False solo quelle attive
# Eventuale token HF per modelli gated (es. Gemma)
# IMPORTANTE: Non inserire mai token direttamente nel codice!
# Opzioni per fornire il token:
# 1. Variabile d'ambiente: export HF_TOKEN="your_token_here" (prima di eseguire lo script)
# 2. Colab Secrets: from google.colab import userdata; os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')
# 3. File .env locale (NON committare il file!)
# Se non è già impostato, lo script proseguirà senza token (funziona per modelli pubblici)
if "HF_TOKEN" not in os.environ:
print("⚠️ HF_TOKEN non trovato. Se usi modelli gated (es. Gemma), imposta la variabile d'ambiente HF_TOKEN")
# ---------------------------------------------------------------------------------------------------------
# ========================= verifica preliminare file input (prima di caricare modello) ===================
if not os.path.exists(PROMPTS_JSON_PATH):
raise FileNotFoundError(f"File prompts non trovato: {PROMPTS_JSON_PATH}")
if not os.path.exists(FEATURES_JSON_PATH):
raise FileNotFoundError(f"File features non trovato: {FEATURES_JSON_PATH}")
print(f"✓ File input verificati:\n - {PROMPTS_JSON_PATH}\n - {FEATURES_JSON_PATH}")
# ========================= clona repo neuronpedia e setta sys.path =======================================
REPO_URL = "https://github.com/hijohnnylin/neuronpedia.git"
REPO_DIR = "/content/neuronpedia"
if not os.path.exists(REPO_DIR):
import subprocess
subprocess.run(["git", "clone", "-q", REPO_URL, REPO_DIR], check=True)
# Percorsi per importare i moduli interni (inference) e i modelli OpenAPI (client Python)
sys.path.append(f"{REPO_DIR}/apps/inference")
sys.path.append(f"{REPO_DIR}/packages/python/neuronpedia-inference-client")
# ========================= import: fedeli a all.py =======================================================
# (vedi all.py: usa Config/Model/SAEManager e ActivationProcessor)
# NOTA: i moduli sotto non sono in locale, verranno trovati dopo il clone del repo
from neuronpedia_inference.config import Config # type: ignore
from neuronpedia_inference.shared import Model # type: ignore
from neuronpedia_inference.sae_manager import SAEManager # type: ignore
from neuronpedia_inference.endpoints.activation.all import ActivationProcessor # type: ignore
from neuronpedia_inference_client.models.activation_all_post_request import ActivationAllPostRequest # type: ignore
from neuronpedia_inference_client.models.activation_all_post200_response import ActivationAllPost200Response # type: ignore
from neuronpedia_inference_client.models.activation_all_post200_response_activations_inner import ActivationAllPost200ResponseActivationsInner # type: ignore
# Import per inizializzazione Model
try:
from transformer_lens import HookedSAETransformer # type: ignore
USE_SAE_TRANSFORMER = True
except ImportError:
from transformer_lens import HookedTransformer # type: ignore
USE_SAE_TRANSFORMER = False
from neuronpedia_inference.shared import STR_TO_DTYPE # type: ignore
# ========================= cleanup memoria GPU + DISCO (importante per SAE pesanti come clt-hp) ===========
import torch
import gc
# CLEANUP DISCO: cache Hugging Face (importante per clt-hp che scarica 8-12 GB per layer!)
print("Pulizia cache Hugging Face (spazio disco)...")
HF_CACHE_DIR = os.path.expanduser("~/.cache/huggingface/hub")
if os.path.exists(HF_CACHE_DIR):
# Calcola spazio usato prima
def get_dir_size(path):
total = 0
try:
for entry in os.scandir(path):
if entry.is_file(follow_symlinks=False):
total += entry.stat().st_size
elif entry.is_dir(follow_symlinks=False):
total += get_dir_size(entry.path)
except (PermissionError, FileNotFoundError):
pass
return total
cache_size_gb = get_dir_size(HF_CACHE_DIR) / 1024**3
print(f" Cache HF attuale: {cache_size_gb:.2f} GB")
# Rimuovi solo i modelli SAE vecchi (non il modello base che è già scaricato)
# Pattern: models--mntss--clt-* (i SAE clt-hp)
cleaned_gb = 0
for item in os.listdir(HF_CACHE_DIR):
item_path = os.path.join(HF_CACHE_DIR, item)
if "mntss--clt-" in item and os.path.isdir(item_path):
item_size = get_dir_size(item_path) / 1024**3
print(f" Rimuovo cache SAE: {item} ({item_size:.2f} GB)")
shutil.rmtree(item_path, ignore_errors=True)
cleaned_gb += item_size
if cleaned_gb > 0:
print(f" ✓ Liberati {cleaned_gb:.2f} GB di spazio disco")
else:
print(f" Cache SAE già pulita")
# Check spazio disco disponibile
if hasattr(shutil, 'disk_usage'):
disk = shutil.disk_usage("/")
free_gb = disk.free / 1024**3
total_gb = disk.total / 1024**3
print(f" Spazio disco: {free_gb:.2f} GB liberi / {total_gb:.2f} GB totali")
# Warn se poco spazio (serve almeno 15 GB per un layer clt-hp)
if free_gb < 15.0:
print(f"\n⚠️ ATTENZIONE: Spazio disco limitato ({free_gb:.2f} GB)")
print(f" I SAE clt-hp richiedono ~10-15 GB temporanei per layer durante il download")
print(f" Lo script procederà comunque - i layer verranno scaricati/eliminati uno alla volta")
# CLEANUP GPU
if torch.cuda.is_available():
print("\nPulizia memoria GPU prima di iniziare...")
torch.cuda.empty_cache()
gc.collect()
mem_free_gb = torch.cuda.mem_get_info()[0] / 1024**3
mem_total_gb = torch.cuda.mem_get_info()[1] / 1024**3
print(f" GPU: {torch.cuda.get_device_name(0)}")
print(f" Memoria libera: {mem_free_gb:.2f} GB / {mem_total_gb:.2f} GB totali")
# Check se c'è abbastanza memoria (gemma-2-2b richiede ~5-6 GB)
if mem_free_gb < 4.0:
print(f"\n⚠️ ATTENZIONE: Memoria GPU insufficiente ({mem_free_gb:.2f} GB liberi)")
print(" Gemma-2-2b richiede almeno 4-5 GB liberi.")
print(" SOLUZIONE: Runtime > Restart session, poi ri-esegui questa cella")
raise RuntimeError(f"GPU memoria insufficiente: {mem_free_gb:.2f} GB liberi, servono almeno 4 GB")
# ========================= init environment come nel server ==============================================
# (all.py si aspetta che i singleton siano configurati via env Config)
device_guess = "cuda" if torch.cuda.is_available() else "cpu"
os.environ.setdefault("MODEL_ID", MODEL_ID)
os.environ.setdefault("SAE_SETS", json.dumps([SOURCE_SET])) # lista di SAE set caricabili
os.environ.setdefault("DEVICE", device_guess)
os.environ.setdefault("TOKEN_LIMIT", "4096")
os.environ.setdefault("MODEL_DTYPE", "bfloat16" if device_guess == "cuda" else "float32")
os.environ.setdefault("SAE_DTYPE", "float32")
print(f"Configurazione environment:")
print(f" MODEL_ID: {os.environ['MODEL_ID']}")
print(f" DEVICE: {os.environ['DEVICE']}")
print(f" MODEL_DTYPE: {os.environ['MODEL_DTYPE']}")
# ========================= init Config con parametri espliciti (non legge env vars!) ======================
# IMPORTANTE: Config.__init__() ha solo defaults, dobbiamo passare i parametri esplicitamente
Config._instance = None # reset singleton
cfg = Config.__new__(Config)
cfg.__init__(
model_id=MODEL_ID,
sae_sets=[SOURCE_SET],
device=device_guess,
model_dtype=os.environ["MODEL_DTYPE"],
sae_dtype=os.environ["SAE_DTYPE"],
token_limit=int(os.environ["TOKEN_LIMIT"]),
)
Config._instance = cfg # Registra come singleton
print(f"Config inizializzato: device={cfg.device}, dtype={cfg.model_dtype}, token_limit={cfg.token_limit}")
# ========================= validazione MODEL_ID <-> SOURCE_SET =======================================
valid_models = cfg.get_valid_model_ids()
if MODEL_ID not in valid_models and cfg.custom_hf_model_id not in valid_models:
print(f"\n⚠️ WARNING: SAE set '{SOURCE_SET}' non trovato nel registry standard di SAELens")
print(f" Modelli registrati per '{SOURCE_SET}': {valid_models if valid_models else '(nessuno)'}")
print(f"\n Proseguo comunque - se il SAE set esiste verrà caricato dinamicamente.")
print(f" Se ottieni errori di caricamento SAE, verifica la compatibilità MODEL_ID <-> SOURCE_SET")
print(f"\n Combinazioni standard note:")
print(f" - gpt2-small → res-jb")
print(f" - gemma-2-2b → gemmascope-res-16k, gemmascope-transcoder-16k, clt-hp")
print(f" - gemma-2-2b-it → gemmascope-res-16k, gemmascope-transcoder-16k")
else:
print(f"✓ Validazione: {MODEL_ID} + {SOURCE_SET} trovati nel registry SAELens")
# ========================= init Model (TransformerLens) ==============================================
print(f"\nCaricamento modello {MODEL_ID} su {cfg.device}...")
if USE_SAE_TRANSFORMER:
print(" (usando HookedSAETransformer per ottimizzazione SAE)")
model = HookedSAETransformer.from_pretrained(
MODEL_ID,
device=cfg.device,
dtype=STR_TO_DTYPE[cfg.model_dtype],
**cfg.model_kwargs
)
else:
model = HookedTransformer.from_pretrained(
MODEL_ID,
device=cfg.device,
dtype=STR_TO_DTYPE[cfg.model_dtype],
**cfg.model_kwargs
)
Model.set_instance(model)
print(f"✓ Modello caricato: {model.cfg.n_layers} layer")
# ========================= init SAEManager e carica SAE ==============================================
# IMPORTANTE: inizializza SAEManager con i parametri corretti PRIMA di chiamare get_instance()
SAEManager._instance = None # reset singleton per coerenza
# Crea l'istanza manualmente con i parametri corretti invece di usare il costruttore di default
sae_mgr = SAEManager.__new__(SAEManager)
sae_mgr.__init__(num_layers=model.cfg.n_layers, device=cfg.device)
SAEManager._instance = sae_mgr # Registra come singleton
print(f"SAEManager configurato: device={sae_mgr.device}, layers={sae_mgr.num_layers}")
# IMPORTANTE: per SAE pesanti come clt-hp, NON caricare tutto subito (OOM)
# Il caricamento avverrà on-demand quando serve (tramite get_sae())
if CHUNK_BY_LAYER:
print(f"⚠️ SAE set '{SOURCE_SET}' - caricamento on-demand per risparmiare memoria")
print(f" (i layer verranno caricati uno alla volta quando necessario)")
# Setup solo metadati senza caricare i pesi
sae_mgr.setup_neuron_layers()
# Configura sae_set_to_saes per i metadati
from neuronpedia_inference.config import get_saelens_neuronpedia_directory_df, config_to_json
directory_df = get_saelens_neuronpedia_directory_df()
config_json = config_to_json(directory_df, selected_sets_neuronpedia=[SOURCE_SET], selected_model=MODEL_ID)
for sae_set in config_json:
sae_mgr.valid_sae_sets.append(sae_set["set"])
sae_mgr.sae_set_to_saes[sae_set["set"]] = sae_set["saes"]
print(f"✓ SAE manager pronto (on-demand loading)")
else:
print(f"Caricamento completo SAE set '{SOURCE_SET}'...")
sae_mgr.load_saes()
print(f"✓ SAE manager pronto")
# ========================= helper: lettura input robusta ================================================
def load_prompts(path: str) -> list[dict]:
"""
Accetta:
- lista di stringhe: ["text1", "text2", ...]
- lista di oggetti: [{"id": "...", "text": "..."}, ...]
- oggetto con chiave "prompts": come sopra
Normalizza in lista di dict: [{"id": str, "text": str}, ...]
"""
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
# unwrap "prompts"
if isinstance(data, dict) and "prompts" in data:
data = data["prompts"]
out = []
if isinstance(data, list):
for i, item in enumerate(data):
if isinstance(item, str):
out.append({"id": f"p{i}", "text": item})
elif isinstance(item, dict):
# campi tollerati: "text" o "prompt" + opzionale "id"
text = item.get("text", item.get("prompt", None))
if not isinstance(text, str):
raise ValueError(f"Prompt #{i} non valido: {item}")
pid = str(item.get("id", f"p{i}"))
out.append({"id": pid, "text": text})
else:
raise ValueError(f"Elemento prompt non riconosciuto: {type(item)}")
else:
raise ValueError("Formato prompts.json non valido")
return out
def load_features(path: str, source_set: str) -> list[dict]:
"""
Accetta:
- lista di oggetti: [{"source":"L-source_set","index":int}, ...]
- oppure [{"layer":int,"index":int}, ...] -> converto a {"source": f"{layer}-{source_set}", "index": int}
- oppure oggetto {"features":[...]} come sopra
Verifica che tutte le source abbiano il suffisso == source_set (coerenza col request all.py).
"""
with open(path, "r", encoding="utf-8") as f:
data = json.load(f)
if isinstance(data, dict) and "features" in data:
data = data["features"]
if not isinstance(data, list):
raise ValueError("Formato features.json non valido")
norm = []
for i, item in enumerate(data):
if not isinstance(item, dict):
raise ValueError(f"feature #{i}: atteso oggetto, trovato {type(item)}")
if "source" in item and "index" in item:
source = str(item["source"])
idx = int(item["index"])
# controllo coerenza del suffisso con SOURCE_SET richiesto
# Esempio: "10-res-jb" → suffisso "res-jb"
if "-" not in source:
# accetto anche formati tipo "blocks.10.hook_resid_post" → estraggo layer
m = re.search(r"(\d+)", source)
if not m:
raise ValueError(f"source non riconosciuta: {source}")
layer = int(m.group(1))
source = f"{layer}-{source_set}"
else:
suff = source.split("-", 1)[1]
if suff != source_set:
raise ValueError(f"feature #{i}: source_set '{suff}' != atteso '{source_set}'")
norm.append({"source": source, "index": idx})
elif "layer" in item and "index" in item:
layer = int(item["layer"])
idx = int(item["index"])
norm.append({"source": f"{layer}-{source_set}", "index": idx})
else:
raise ValueError(f"feature #{i}: campi attesi ('source','index') o ('layer','index')")
return norm
# ========================= core: una chiamata ActivationProcessor per prompt =============================
def run_all_for_prompt(prompt_text: str, wanted_features: list[dict]) -> dict[str, Any]:
"""
Costruisce un ActivationAllPostRequest come in all.py:
- selected_sources = tutti i layer presenti nelle feature
- niente feature_filter cross-layer (filtriamo dopo)
Esegue ActivationProcessor.process_activations() e poi filtra le feature richieste.
Ritorna dict con {"tokens":[...], "counts":[[...]], "activations":[...solo richieste...] }.
"""
proc = ActivationProcessor()
# layer distinti dalle feature richieste
layers = sorted({int(f["source"].split("-")[0]) for f in wanted_features})
selected_sources = [f"{L}-{SOURCE_SET}" for L in layers]
# richieste identiche allo schema di /activation/all
req = ActivationAllPostRequest(
prompt=prompt_text,
model=MODEL_ID,
source_set=SOURCE_SET,
selected_sources=selected_sources,
ignore_bos=False,
sort_by_token_indexes=[], # puoi riempire se vuoi sommare su token specifici
num_results=100_000
)
resp: ActivationAllPost200Response = proc.process_activations(req) # metodo all.py (fedelissimo)
# filtro: tengo solo le feature richieste
want = {(f["source"], int(f["index"])) for f in wanted_features}
found_features = set()
filtered = []
for a in resp.activations:
src = a.source
idx = int(a.index)
if (src, idx) in want:
obj = {
"source": src,
"index": idx,
"values": list(a.values),
"sum_values": float(a.sum_values) if a.sum_values is not None else None,
"max_value": float(a.max_value),
"max_value_index": int(a.max_value_index),
}
if getattr(a, "dfa_values", None) is not None:
obj["dfa_values"] = list(a.dfa_values)
obj["dfa_target_index"] = int(a.dfa_target_index)
obj["dfa_max_value"] = float(a.dfa_max_value)
filtered.append(obj)
found_features.add((src, idx))
# Aggiungi feature richieste ma non attivate (valore 0) se richiesto
if INCLUDE_ZERO_ACTIVATIONS:
num_tokens = len(resp.tokens)
for f in wanted_features:
src = f["source"]
idx = int(f["index"])
if (src, idx) not in found_features:
obj = {
"source": src,
"index": idx,
"values": [0.0] * num_tokens,
"sum_values": 0.0,
"max_value": 0.0,
"max_value_index": 0,
}
filtered.append(obj)
return {
"tokens": list(resp.tokens),
"counts": [[float(x) for x in row] for row in resp.counts], # tabella attivazioni>0 per layer x token
"activations": filtered
}
def run_per_layer_for_prompt(prompt_text: str, wanted_features: list[dict]) -> dict[str, Any]:
"""
Variante chunking: processa un layer per volta usando lo stesso ActivationProcessor.
Identico al metodo, ma riduce memoria perché non calcola tutti i layer insieme.
Unisce i risultati e poi filtra.
IMPORTANTE: unload del SAE dopo ogni layer per liberare memoria GPU.
"""
proc = ActivationProcessor()
sae_mgr = SAEManager.get_instance()
# raggruppa wanted_features per layer
by_layer: dict[int, list[dict]] = {}
for f in wanted_features:
L = int(f["source"].split("-")[0])
by_layer.setdefault(L, []).append(f)
tokens_ref = None
counts_accum = None
activations_all = []
for idx, (L, feats) in enumerate(sorted(by_layer.items())):
sae_id = f"{L}-{SOURCE_SET}"
print(f" Layer {L} [{idx+1}/{len(by_layer)}]...", end=" ", flush=True)
req = ActivationAllPostRequest(
prompt=prompt_text,
model=MODEL_ID,
source_set=SOURCE_SET,
selected_sources=[sae_id], # un solo layer
ignore_bos=False,
sort_by_token_indexes=[],
num_results=100_000
)
resp = proc.process_activations(req)
if tokens_ref is None:
tokens_ref = list(resp.tokens)
# somma "counts" per layer (sono già per-layer; qui copio/riallineo)
if counts_accum is None:
counts_accum = [[float(x) for x in row] for row in resp.counts]
else:
# espandi counts_accum se necessario
max_rows = max(len(counts_accum), len(resp.counts))
if len(counts_accum) < max_rows:
counts_accum += [[0.0]*len(tokens_ref) for _ in range(max_rows-len(counts_accum))]
for r in range(len(resp.counts)):
for c in range(len(resp.counts[r])):
counts_accum[r][c] += float(resp.counts[r][c])
# prendi solo le feature desiderate di questo layer
want = {(f"{L}-{SOURCE_SET}", int(f["index"])) for f in feats}
found_in_layer = set()
for a in resp.activations:
src = a.source
idx = int(a.index)
if (src, idx) in want:
obj = {
"source": src,
"index": idx,
"values": list(a.values),
"sum_values": float(a.sum_values) if a.sum_values is not None else None,
"max_value": float(a.max_value),
"max_value_index": int(a.max_value_index),
}
if getattr(a, "dfa_values", None) is not None:
obj["dfa_values"] = list(a.dfa_values)
obj["dfa_target_index"] = int(a.dfa_target_index)
obj["dfa_max_value"] = float(a.dfa_max_value)
activations_all.append(obj)
found_in_layer.add((src, idx))
# Aggiungi feature richieste ma non attivate (valore 0) se richiesto
if INCLUDE_ZERO_ACTIVATIONS and tokens_ref:
for f in feats:
src = f"{L}-{SOURCE_SET}"
idx = int(f["index"])
if (src, idx) not in found_in_layer:
obj = {
"source": src,
"index": idx,
"values": [0.0] * len(tokens_ref),
"sum_values": 0.0,
"max_value": 0.0,
"max_value_index": 0,
}
activations_all.append(obj)
# Unload SAE per liberare memoria GPU (importante per SAE pesanti)
if sae_id in sae_mgr.loaded_saes:
sae_mgr.unload_sae(sae_id)
# PULIZIA CACHE DISCO dopo ogni layer per liberare spazio (critico per clt-hp!)
# Rimuove i file del layer appena processato dalla cache HF
if SOURCE_SET == "clt-hp":
try:
HF_CACHE_DIR = os.path.expanduser("~/.cache/huggingface/hub")
for item in os.listdir(HF_CACHE_DIR):
if "mntss--clt-" in item:
item_path = os.path.join(HF_CACHE_DIR, item)
if os.path.isdir(item_path):
shutil.rmtree(item_path, ignore_errors=True)
torch.cuda.empty_cache()
gc.collect()
except Exception:
pass # Ignora errori di cleanup
print(f"OK (unloaded + cleaned)", flush=True)
return {
"tokens": tokens_ref or [],
"counts": counts_accum or [],
"activations": activations_all
}
def run_layer_by_layer_all_prompts(prompts: list[dict], wanted_features: list[dict]) -> list[dict]:
"""
OTTIMIZZAZIONE: processa tutti i prompt layer-by-layer invece di prompt-by-prompt.
Questo minimizza i re-download dei SAE (ogni layer viene scaricato 1 sola volta).
Strategia:
1. Raggruppa features per layer
2. Per ogni layer:
- Scarica il SAE una sola volta
- Processa TUTTI i prompt con quel layer
- Scarica il SAE e pulisce la cache
3. Riorganizza i risultati per prompt
"""
proc = ActivationProcessor()
sae_mgr = SAEManager.get_instance()
# Raggruppa features per layer
by_layer: dict[int, list[dict]] = {}
for f in wanted_features:
L = int(f["source"].split("-")[0])
by_layer.setdefault(L, []).append(f)
# Dizionario per accumulare risultati per prompt
# prompt_id -> {"tokens": [...], "counts": [...], "activations": [...]}
results_by_prompt: dict[str, dict] = {p["id"]: {"tokens": None, "counts": None, "activations": []} for p in prompts}
layers_sorted = sorted(by_layer.keys())
print(f"\n⚡ OTTIMIZZAZIONE: processando {len(layers_sorted)} layer per {len(prompts)} prompt")
print(f" (ogni layer viene scaricato 1 sola volta)\n")
for idx, L in enumerate(layers_sorted, 1):
sae_id = f"{L}-{SOURCE_SET}"
feats = by_layer[L]
print(f" Layer {L} [{idx}/{len(layers_sorted)}] - processando {len(prompts)} prompt...", end=" ", flush=True)
# Processa tutti i prompt con questo layer
for p in prompts:
pid, text = p["id"], p["text"]
req = ActivationAllPostRequest(
prompt=text,
model=MODEL_ID,
source_set=SOURCE_SET,
selected_sources=[sae_id],
ignore_bos=False,
sort_by_token_indexes=[],
num_results=100_000
)
resp = proc.process_activations(req)
# Salva tokens e counts (solo la prima volta per questo prompt)
if results_by_prompt[pid]["tokens"] is None:
results_by_prompt[pid]["tokens"] = list(resp.tokens)
results_by_prompt[pid]["counts"] = [[float(x) for x in row] for row in resp.counts]
# Aggiungi activations filtrate
want = {(f"{L}-{SOURCE_SET}", int(f["index"])) for f in feats}
found_features = set() # traccia quali feature sono state trovate
for a in resp.activations:
src = a.source
idx_feat = int(a.index)
if (src, idx_feat) in want:
obj = {
"source": src,
"index": idx_feat,
"values": list(a.values),
"sum_values": float(a.sum_values) if a.sum_values is not None else None,
"max_value": float(a.max_value),
"max_value_index": int(a.max_value_index),
}
if getattr(a, "dfa_values", None) is not None:
obj["dfa_values"] = list(a.dfa_values)
obj["dfa_target_index"] = int(a.dfa_target_index)
obj["dfa_max_value"] = float(a.dfa_max_value)
results_by_prompt[pid]["activations"].append(obj)
found_features.add((src, idx_feat))
# Aggiungi feature richieste ma non attivate (valore 0) se richiesto
if INCLUDE_ZERO_ACTIVATIONS:
num_tokens = len(results_by_prompt[pid]["tokens"]) if results_by_prompt[pid]["tokens"] else 0
for f in feats:
src = f"{L}-{SOURCE_SET}"
idx_feat = int(f["index"])
if (src, idx_feat) not in found_features and num_tokens > 0:
# Feature richiesta ma non attivata - aggiungi con valori a zero
obj = {
"source": src,
"index": idx_feat,
"values": [0.0] * num_tokens,
"sum_values": 0.0,
"max_value": 0.0,
"max_value_index": 0,
}
results_by_prompt[pid]["activations"].append(obj)
# Unload SAE e pulisci cache (una sola volta dopo aver processato tutti i prompt)
if sae_id in sae_mgr.loaded_saes:
sae_mgr.unload_sae(sae_id)
if SOURCE_SET == "clt-hp":
try:
HF_CACHE_DIR = os.path.expanduser("~/.cache/huggingface/hub")
for item in os.listdir(HF_CACHE_DIR):
if "mntss--clt-" in item:
item_path = os.path.join(HF_CACHE_DIR, item)
if os.path.isdir(item_path):
shutil.rmtree(item_path, ignore_errors=True)
torch.cuda.empty_cache()
gc.collect()
except Exception:
pass
print(f"✓ (cleaned)", flush=True)
# Riorganizza risultati nell'ordine originale dei prompt
return [
{
"probe_id": p["id"],
"prompt": p["text"],
"tokens": results_by_prompt[p["id"]]["tokens"] or [],
"counts": results_by_prompt[p["id"]]["counts"] or [],
"activations": results_by_prompt[p["id"]]["activations"],
}
for p in prompts
]
# ========================= carica input, esegui per tutti i prompt, salva JSON ===========================
try:
print(f"\n{'='*60}")
print(f"Caricamento input files...")
prompts = load_prompts(PROMPTS_JSON_PATH)
features = load_features(FEATURES_JSON_PATH, SOURCE_SET)
print(f"✓ {len(prompts)} prompt(s), {len(features)} feature(s)")
if CHUNK_BY_LAYER:
# OTTIMIZZAZIONE: processa layer-by-layer per tutti i prompt insieme
# Questo minimizza i re-download (ogni layer viene scaricato 1 sola volta)
results = run_layer_by_layer_all_prompts(prompts, features)
# Stampa riepilogo per prompt
print(f"\n{'='*60}")
print(f"Riepilogo risultati:")
for i, res in enumerate(results, 1):
print(f" [{i}/{len(results)}] {res['probe_id']}: {len(res['activations'])} attivazioni, {len(res['tokens'])} token")
else:
# Metodo classico: tutti i layer insieme per ogni prompt
results = []
for i, p in enumerate(prompts, 1):
pid, text = p["id"], p["text"]
print(f"\n[{i}/{len(prompts)}] Processando prompt '{pid}'...")
print(f" Text: {text[:60]}{'...' if len(text) > 60 else ''}")
res = run_all_for_prompt(text, features)
print(f" ✓ {len(res['activations'])} attivazioni trovate, {len(res['tokens'])} token")
results.append({
"probe_id": pid,
"prompt": text,
"tokens": res["tokens"],
"counts": res["counts"],
"activations": res["activations"],
})
out = {
"model": MODEL_ID,
"source_set": SOURCE_SET,
"device": Config.get_instance().device,
"n_prompts": len(results),
"n_features_requested": len(features),
"results": results,
}
with open(OUT_JSON_PATH, "w", encoding="utf-8") as f:
json.dump(out, f, ensure_ascii=False, indent=2)
print(f"\n{'='*60}")
print(f"✔ COMPLETATO!")
print(f"{'='*60}")
print(f"File salvato: {OUT_JSON_PATH}")
print(f"Statistiche:")
print(f" - Prompt processati: {len(results)}")
print(f" - Features richieste: {len(features)}")
print(f" - Modello: {MODEL_ID}")
print(f" - SAE set: {SOURCE_SET}")
print(f" - Device: {out['device']}")
# Dimensione file
file_size_mb = os.path.getsize(OUT_JSON_PATH) / (1024 * 1024)
print(f" - Dimensione output: {file_size_mb:.2f} MB")
print(f"{'='*60}")
except Exception as e:
print(f"\n{'='*60}")
print(f"✗ ERRORE!")
print(f"{'='*60}")
print(f"Messaggio: {e}")
print(f"\nStack trace completo:")
traceback.print_exc()
print(f"{'='*60}")