from typing import Callable, Any from functools import partial import warnings import torch import torch.nn as nn """ Sparse Autoencoder (SAE) Implementation This module implements various sparse autoencoder architectures and activation functions designed to learn interpretable features in high-dimensional data. """ def normalize_data(x: torch.Tensor, eps: float = 1e-5) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Normalize input data to zero mean and unit variance. Args: x (torch.Tensor): Input tensor to normalize eps (float, optional): Small constant for numerical stability. Defaults to 1e-5. Returns: tuple: (normalized_data, mean, std) - normalized_data: Data normalized to zero mean and unit variance - mean: Mean of the original data (for denormalization) - std: Standard deviation of the original data (for denormalization) """ mu = x.mean(dim=-1, keepdim=True) x = x - mu std = x.std(dim=-1, keepdim=True) x = x / (std + eps) return x, mu, std class SoftCapping(nn.Module): """ Soft capping layer to prevent latent activations from growing excessively large. This layer applies a scaled tanh transformation that smoothly saturates values without hard truncation, helping stabilize training. Args: soft_cap (float): The scale factor for the tanh transformation """ def __init__(self, soft_cap): super(SoftCapping, self).__init__() self.soft_cap = soft_cap def forward(self, logits): """ Apply soft capping to input values. Args: logits (torch.Tensor): Input tensor Returns: torch.Tensor: Soft-capped values with range approximately [-soft_cap, soft_cap] """ return self.soft_cap * torch.tanh(logits / self.soft_cap) class TopK(nn.Module): """ Top-K activation function that only keeps the K largest activations per sample. This activation enforces sparsity by zeroing out all but the k highest values in each input vector. Can optionally use absolute values for selection and apply a subsequent activation function. Args: k (int): Number of activations to keep act_fn (Callable, optional): Secondary activation function to apply to the kept values. Defaults to nn.ReLU(). use_abs (bool, optional): If True, selection is based on absolute values. Defaults to False. """ def __init__(self, k: int, act_fn: Callable = nn.ReLU(), use_abs: bool = False) -> None: super().__init__() self.k = k self.act_fn = act_fn self.use_abs = use_abs def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass that keeps only the top-k activations for each sample. Args: x (torch.Tensor): Input tensor of shape [batch_size, features] Returns: torch.Tensor: Sparse output tensor with same shape as input, where all but the top k values (per sample) are zero """ if self.use_abs: x = torch.abs(x) # Get indices of top-k values along feature dimension _, indices = torch.topk(x, k=self.k, dim=-1) # Gather the corresponding values from the original input values = torch.gather(x, -1, indices) # Apply the activation function to the selected values activated_values = self.act_fn(values) # Create a tensor of zeros and place the activated values at the correct positions result = torch.zeros_like(x) result.scatter_(-1, indices, activated_values) # Verify sparsity constraint is met assert (result != 0.0).sum(dim=-1).max() <= self.k return result def forward_eval(self, x: torch.Tensor) -> torch.Tensor: """ Evaluation mode forward pass that doesn't enforce sparsity. Used for computing full activations during evaluation or visualization. Args: x (torch.Tensor): Input tensor Returns: torch.Tensor: Output after applying activation function (without top-k filtering) """ if self.use_abs: x = torch.abs(x) x = self.act_fn(x) return x # Mapping of activation function names to their corresponding classes ACTIVATIONS_CLASSES = { "ReLU": nn.ReLU, "Identity": nn.Identity, "TopK": partial(TopK, act_fn=nn.Identity()), "TopKReLU": partial(TopK, act_fn=nn.ReLU()), "TopKabs": partial(TopK, use_abs=True, act_fn=nn.Identity()), "TopKabsReLU": partial(TopK, use_abs=True, act_fn=nn.ReLU()), } def get_activation(activation: str) -> nn.Module: """ Factory function to create activation function instances by name. Handles special cases like parameterized activations (e.g., TopK_64). Args: activation (str): Name of the activation function, with optional parameter (e.g., "TopKReLU_64" for TopKReLU with k=64) Returns: nn.Module: Instantiated activation function """ if "_" in activation: activation, arg = activation.split("_") if "TopK" in activation: return ACTIVATIONS_CLASSES[activation](k=int(arg)) elif "JumpReLU" in activation: return ACTIVATIONS_CLASSES[activation](hidden_dim=int(arg)) return ACTIVATIONS_CLASSES[activation]() class Autoencoder(nn.Module): """ Sparse autoencoder base class. Implements the standard sparse autoencoder architecture: latents = activation(encoder(x - pre_bias) + latent_bias) recons = decoder(latents) + pre_bias Includes various options for controlling activation functions, weight initialization, and feature normalization. Attributes: n_latents (int): Number of latent features (neurons) n_inputs (int): Dimensionality of the input data tied (bool): Whether decoder weights are tied to encoder weights normalize (bool): Whether to normalize input data encoder (nn.Parameter): Encoder weight matrix [n_inputs, n_latents] decoder (nn.Parameter): Decoder weight matrix [n_latents, n_inputs] (if not tied) pre_bias (nn.Parameter): Input bias/offset [n_inputs] latent_bias (nn.Parameter): Latent bias [n_latents] activation (nn.Module): Activation function for the latent layer latents_activation_frequency (torch.Tensor): Tracks how often neurons activate """ def __init__( self, n_latents: int, n_inputs: int, activation: Callable = nn.ReLU(), tied: bool = False, normalize: bool = False, bias_init: torch.Tensor | float = 0.0, init_method: str = "kaiming", latent_soft_cap: float = 30.0, threshold: torch.Tensor | None = None, *args, **kwargs ) -> None: """ Initialize the sparse autoencoder. Args: n_latents (int): Dimension of the autoencoder latent space n_inputs (int): Dimensionality of the original data activation (Callable or str): Activation function or name tied (bool, optional): Whether to tie encoder and decoder weights. Defaults to False. normalize (bool, optional): Whether to normalize input data. Defaults to False. bias_init (torch.Tensor | float, optional): Initial bias value. Defaults to 0.0. init_method (str, optional): Weight initialization method. Defaults to "kaiming". latent_soft_cap (float, optional): Soft cap value for latent activations. Defaults to 30.0. threshold (torch.Tensor, optional): Threshold for JumpReLU. Defaults to None. """ super().__init__() if isinstance(activation, str): activation = get_activation(activation) # Store configuration self.tied = tied self.n_latents = n_latents self.n_inputs = n_inputs self.init_method = init_method self.bias_init = bias_init self.normalize = normalize # Initialize parameters self.pre_bias = nn.Parameter(torch.full((n_inputs,), bias_init) if isinstance(bias_init, float) else bias_init) self.encoder = nn.Parameter(torch.zeros((n_inputs, n_latents))) self.latent_bias = nn.Parameter(torch.zeros(n_latents,)) # For tied weights, decoder is derived from encoder if tied: self.register_parameter('decoder', None) else: self.decoder = nn.Parameter(torch.zeros((n_latents, n_inputs))) # Set up activation functions self.latent_soft_cap = SoftCapping(latent_soft_cap) if latent_soft_cap > 0 else nn.Identity() self.activation = activation self.dead_activations = activation # Initialize weights self._init_weights() # Set up activation tracking self.latents_activation_frequency: torch.Tensor self.register_buffer( "latents_activation_frequency", torch.zeros(n_latents, dtype=torch.int64, requires_grad=False) ) self.num_updates = 0 self.dead_latents = [] def get_and_reset_stats(self) -> torch.Tensor: """ Get activation statistics and reset the counters. Returns: torch.Tensor: Proportion of samples that activated each neuron """ activations = self.latents_activation_frequency.detach().cpu().float() / self.num_updates self.latents_activation_frequency.zero_() self.num_updates = 0 return activations @torch.no_grad() def _init_weights(self, norm=0.1, neuron_indices: list[int] | None = None) -> None: """ Initialize network weights. Args: norm (float, optional): Target norm for the weights. Defaults to 0.1. neuron_indices (list[int] | None, optional): Indices of neurons to initialize. If None, initialize all neurons. Raises: ValueError: If invalid initialization method is specified """ if self.init_method not in ["kaiming", "xavier", "uniform", "normal"]: raise ValueError(f"Invalid init_method: {self.init_method}") # Use transposed encoder if weights are tied if self.tied: decoder_weight = self.encoder.t() else: decoder_weight = self.decoder # Initialize with specified method if self.init_method == "kaiming": new_W_dec = (nn.init.kaiming_uniform_(torch.zeros_like(decoder_weight), nonlinearity='relu')) elif self.init_method == "xavier": new_W_dec = (nn.init.xavier_uniform_(torch.zeros_like(decoder_weight), gain=nn.init.calculate_gain('relu'))) elif self.init_method == "uniform": new_W_dec = (nn.init.uniform_(torch.zeros_like(decoder_weight), a=-1, b=1)) elif self.init_method == "normal": new_W_dec = (nn.init.normal_(torch.zeros_like(decoder_weight))) else: raise ValueError(f"Invalid init_method: {self.init_method}") # Normalize to target norm new_W_dec *= (norm / new_W_dec.norm(p=2, dim=-1, keepdim=True)) # Initialize bias to zero new_l_bias = (torch.zeros_like(self.latent_bias)) # Transpose for encoder new_W_enc = new_W_dec.t().clone() # Apply initialization to all or specific neurons if neuron_indices is None: if not self.tied: self.decoder.data = new_W_dec self.encoder.data = new_W_enc self.latent_bias.data = new_l_bias else: if not self.tied: self.decoder.data[neuron_indices] = new_W_dec[neuron_indices] self.encoder.data[:, neuron_indices] = new_W_enc[:, neuron_indices] self.latent_bias.data[neuron_indices] = new_l_bias[neuron_indices] @torch.no_grad() def project_grads_decode(self): """ Project out components of decoder gradient that would change its norm. This helps maintain normalized decoder norms during training. """ if self.tied: weights = self.encoder.data.T grad = self.encoder.grad.T else: weights = self.decoder.data grad = self.decoder.grad # Project out the component parallel to weights grad_proj = (grad * weights).sum(dim=-1, keepdim=True) * weights # Update gradients if self.tied: self.encoder.grad -= grad_proj.T else: self.decoder.grad -= grad_proj @torch.no_grad() def scale_to_unit_norm(self) -> None: """ Scale decoder rows to unit norm, and adjust other parameters accordingly. This normalization helps with feature interpretability and training stability. """ eps = torch.finfo(self.decoder.dtype).eps # Normalize tied or untied weights if self.tied: norm = self.encoder.data.T.norm(p=2, dim=-1, keepdim=True) + eps self.encoder.data.T /= norm else: norm = self.decoder.data.norm(p=2, dim=-1, keepdim=True) + eps self.decoder.data /= norm self.encoder.data *= norm.t() # Scale biases accordingly self.latent_bias.data *= norm.squeeze() def encode_pre_act(self, x: torch.Tensor) -> torch.Tensor: """ Compute pre-activation latent values. Args: x (torch.Tensor): Input data [batch, n_inputs] Returns: torch.Tensor: Pre-activation latent values [batch, n_latents] """ x = x - self.pre_bias latents_pre_act_full = x @ self.encoder + self.latent_bias return latents_pre_act_full def preprocess(self, x: torch.Tensor) -> tuple[torch.Tensor, dict[str, Any]]: """ Preprocess input data, optionally normalizing it. Args: x (torch.Tensor): Input data [batch, n_inputs] Returns: tuple: (preprocessed_data, normalization_info) - preprocessed_data: Processed input data - normalization_info: Dict with normalization parameters (if normalize=True) """ if not self.normalize: return x, dict() x_processed, mu, std = normalize_data(x) return x_processed, dict(mu=mu, std=std) def encode(self, x: torch.Tensor, topk_number: int | None = None) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: """ Encode input data to latent representations. Args: x (torch.Tensor): Input data [batch, n_inputs] topk_number (int | None, optional): Number of top-k activations to keep (for inference). Defaults to None. Returns: tuple: (encoded, full_encoded, info) - encoded: Latent activations with sparsity constraints [batch, n_latents] - full_encoded: Latent activations without sparsity (for analysis) [batch, n_latents] - info: Normalization information dictionary """ x, info = self.preprocess(x) pre_encoded = self.encode_pre_act(x) encoded = self.activation(pre_encoded) # Get full activations (for analysis) depending on activation type if isinstance(self.activation, TopK): full_encoded = self.activation.forward_eval(pre_encoded) else: full_encoded = torch.clone(encoded) # Apply topk filtering for inference if requested if topk_number is not None: _, indices = torch.topk(full_encoded, k=topk_number, dim=-1) values = torch.gather(full_encoded, -1, indices) full_encoded = torch.zeros_like(full_encoded) full_encoded.scatter_(-1, indices, values) # Apply soft capping to both outputs caped_encoded = self.latent_soft_cap(encoded) capped_full_encoded = self.latent_soft_cap(full_encoded) return caped_encoded, capped_full_encoded, info def decode(self, latents: torch.Tensor, info: dict[str, Any] | None = None) -> torch.Tensor: """ Decode latent representations to reconstructed inputs. Args: latents (torch.Tensor): Latent activations [batch, n_latents] info (dict[str, Any] | None, optional): Normalization information. Defaults to None. Returns: torch.Tensor: Reconstructed input data [batch, n_inputs] """ # Decode using tied or untied weights if self.tied: ret = latents @ self.encoder.t() + self.pre_bias else: ret = latents @ self.decoder + self.pre_bias # Denormalize if needed if self.normalize: assert info is not None ret = ret * info["std"] + info["mu"] return ret @torch.no_grad() def update_latent_statistics(self, latents: torch.Tensor) -> None: """ Update statistics on latent activations. Args: latents (torch.Tensor): Latent activations [batch, n_latents] """ self.num_updates += latents.shape[0] current_activation_frequency = (latents != 0).to(torch.int64).sum(dim=0) self.latents_activation_frequency += current_activation_frequency def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass through the autoencoder. Args: x (torch.Tensor): Input data [batch, n_inputs] Returns: tuple: (recons, latents, all_recons, all_latents) - recons: Reconstructed data [batch, n_inputs] - latents: Latent activations [batch, n_latents] - all_recons: Reconstructed data without sparsity constraints (for analysis) - all_latents: Latent activations without sparsity constraints (for analysis) """ # Preprocess data x_processed, info = self.preprocess(x) # Compute pre-activations latents_pre_act = self.encode_pre_act(x_processed) # Apply activation function latents = self.activation(latents_pre_act) latents_caped = self.latent_soft_cap(latents) # Decode to reconstruction recons = self.decode(latents_caped, info) # Update activation statistics self.update_latent_statistics(latents_caped) # Handle different activation function types for analysis outputs if isinstance(self.activation, TopK): # For TopK, return both sparse and full activations all_latents = self.activation.forward_eval(latents_pre_act) all_latents_caped = self.latent_soft_cap(all_latents) all_recons = self.decode(all_latents_caped, info) return recons, latents_caped, all_recons, all_latents_caped else: # For other activations, return the same for both return recons, latents_caped, recons, latents_caped class MatryoshkaAutoencoder(Autoencoder): """ Matryoshka Sparse Autoencoder. This extends the base Autoencoder with a nested structure of latent representations, where different numbers of features can be used depending on computational budget or desired level of detail. The model uses multiple TopK activations with different k values and maintains relative importance weights for each level of the hierarchy. """ def __init__( self, n_latents: int, n_inputs: int, activation: str = "TopKReLU", tied: bool = False, normalize: bool = False, bias_init: torch.Tensor | float = 0.0, init_method: str = "kaiming", latent_soft_cap: float = 30.0, nesting_list: list[int] = [16, 32], relative_importance: list[float] | None = None, *args, **kwargs ) -> None: """ Initialize the Matryoshka Sparse Autoencoder. Args: n_latents (int): Dimension of the autoencoder latent space n_inputs (int): Dimensionality of the original data activation (str, optional): Base activation function name. Defaults to "TopKReLU". tied (bool, optional): Whether to tie encoder and decoder weights. Defaults to False. normalize (bool, optional): Whether to normalize input data. Defaults to False. bias_init (torch.Tensor | float, optional): Initial bias value. Defaults to 0.0. init_method (str, optional): Weight initialization method. Defaults to "kaiming". latent_soft_cap (float, optional): Soft cap value for latent activations. Defaults to 30.0. nesting_list (list[int], optional): List of k values for nested representations. Defaults to [16, 32]. relative_importance (list[float] | None, optional): Importance weights for each nesting level. Defaults to equal weights. """ # Initialize nesting hierarchy self.nesting_list = sorted(nesting_list) self.relative_importance = relative_importance if relative_importance is not None else [1.0] * len(nesting_list) assert len(self.relative_importance) == len(self.nesting_list) # Ensure activation is TopK-based if "TopK" not in activation: warnings.warn(f"MatryoshkaAutoencoder: activation {activation} is not a TopK activation. We are changing it to TopKReLU") activation = "TopKReLU" # Initialize with base activation base_activation = activation + f"_{self.nesting_list[0]}" super().__init__(n_latents, n_inputs, base_activation, tied, normalize, bias_init, init_method, latent_soft_cap) # Create multiple activations with different k values self.activation = nn.ModuleList( [get_activation(activation + f"_{nesting}") for nesting in self.nesting_list] ) def encode(self, x: torch.Tensor, topk_number: int | None = None) -> tuple[list[torch.Tensor], torch.Tensor, dict[str, Any]]: """ Encode input data to multiple latent representations with different sparsity levels. Args: x (torch.Tensor): Input data [batch, n_inputs] topk_number (int | None, optional): Number of top-k activations to keep (for inference). Defaults to None. Returns: tuple: (encoded_list, last_encoded, info) - encoded_list: List of latent activations with different sparsity levels - last_encoded: The least sparse latent activations (from largest k value) - info: Normalization information dictionary """ x, info = self.preprocess(x) pre_encoded = self.encode_pre_act(x) # Apply each activation function in the hierarchy encoded = [activation(pre_encoded) for activation in self.activation] caped_encoded = [self.latent_soft_cap(enc) for enc in encoded] # Apply additional top-k filtering for inference if requested if topk_number is not None: last_encoded = caped_encoded[-1] _, indices = torch.topk(last_encoded, k=topk_number, dim=-1) values = torch.gather(last_encoded, -1, indices) last_encoded = torch.zeros_like(last_encoded) last_encoded.scatter_(-1, indices, values) else: last_encoded = caped_encoded[-1] return caped_encoded, last_encoded, info def decode(self, latents: list[torch.Tensor], info: dict[str, Any] | None = None) -> list[torch.Tensor]: """ Decode multiple latent representations to reconstructions. Args: latents (list[torch.Tensor]): List of latent activations at different sparsity levels info (dict[str, Any] | None, optional): Normalization information. Defaults to None. Returns: list[torch.Tensor]: List of reconstructed inputs at different sparsity levels """ # Decode each latent representation if self.tied: ret = [latent @ self.encoder.t() + self.pre_bias for latent in latents] else: ret = [latent @ self.decoder + self.pre_bias for latent in latents] # Denormalize if needed if self.normalize: assert info is not None ret = [re * info["std"] + info["mu"] for re in ret] return ret def forward(self, x: torch.Tensor) -> tuple[list[torch.Tensor], list[torch.Tensor], torch.Tensor, torch.Tensor]: """ Forward pass through the Matryoshka autoencoder. Args: x (torch.Tensor): Input data [batch, n_inputs] Returns: tuple: (recons_list, latents_list, final_recon, final_latent) - recons_list: List of reconstructions at different sparsity levels - latents_list: List of latent activations at different sparsity levels - final_recon: Reconstruction from the largest k value - final_latent: Latent activations from the largest k value """ # Preprocess data x_processed, info = self.preprocess(x) latents_pre_act = self.encode_pre_act(x_processed) # Apply each activation in the hierarchy latents = [activation(latents_pre_act) for activation in self.activation] assert len(latents) == len(self.activation) latents_caped = [self.latent_soft_cap(latent) for latent in latents] # Decode each level recons = self.decode(latents_caped, info) assert len(recons) == len(latents) # Update activation statistics using the largest k self.update_latent_statistics(latents_caped[-1]) # Get full activations for analysis all_latents = self.activation[0].forward_eval(latents_pre_act) all_latents_caped = self.latent_soft_cap(all_latents) all_recons = self.decode([all_latents_caped], info)[0] # Return all reconstructions and the final ones return recons, latents_caped, all_recons, all_latents_caped def load_model(path): """ Load a saved sparse autoencoder model from a file. This function parses the filename to extract model configuration parameters and then loads the saved model weights. Args: path (str): Path to the saved model file (.pt) Returns: tuple: (model, data_mean_center, data_normalized, scaling_factor) - model: The loaded Autoencoder model - mean_center: Boolean indicating if data was mean-centered - target_norm: Target normalization factor for the data """ # Extract configuration from filename path_head = path.split("/")[-1] path_name = path_head[:path_head.find(".pt")] path_name_spited = path_name.split("_") n_latents = int(path_name_spited.pop(0)) n_inputs = int(path_name_spited.pop(0)) activation = path_name_spited.pop(0) if "TopK" in activation: activation += "_" + path_name_spited.pop(0) elif "ReLU" == activation: path_name_spited.pop(0) if "UW" in path_name_spited[0] or "RW" in path_name_spited[0]: path_name_spited.pop(0) tied = False if path_name_spited.pop(0) == "False" else True normalize = False if path_name_spited.pop(0) == "False" else True latent_soft_cap = float(path_name_spited.pop(0)) # Create and load the model model = Autoencoder(n_latents, n_inputs, activation, tied=tied, normalize=normalize, latent_soft_cap=latent_soft_cap) model_state_dict = torch.load(path, map_location='cuda' if torch.cuda.is_available() else 'cpu') model.load_state_dict(model_state_dict['model']) mean_center = model_state_dict['mean_center'] scaling_factor = model_state_dict['scaling_factor'] target_norm = model_state_dict['target_norm'] return model, mean_center, scaling_factor, target_norm class SAE(nn.Module): def __init__(self, path: str) -> None: """ Initialize the Sparse Autoencoder (SAE) model. Args: path (str): Path to the saved model file (.pt) """ super().__init__() self.model, mean, scaling_factor, _ = load_model(path) self.register_buffer("mean", mean.clone().detach() if isinstance(mean, torch.Tensor) else torch.tensor(mean)) self.register_buffer("scaling_factor", torch.tensor(scaling_factor)) @property def input_dim(self) -> int: """Return input dimension of the model.""" return self.model.n_inputs @property def latent_dim(self) -> int: """Return latent dimension of the model.""" return self.model.n_latents def preprocess(self, x: torch.Tensor) -> torch.Tensor: """ Preprocess input data (mean-centering and scaling). Args: x: Input tensor Returns: Preprocessed tensor """ # Mean-center and scale the input x = (x - self.mean) * self.scaling_factor return x def postprocess(self, x: torch.Tensor) -> torch.Tensor: """ Post-process output data (denormalization). Args: x: Output tensor Returns: Denormalized tensor """ # Rescale and mean-center the output x = x / self.scaling_factor + self.mean return x def encode(self, x: torch.Tensor, topk: int = -1) -> tuple[torch.Tensor, torch.Tensor]: """ Encode input data to latent representation. Args: x: Input tensor topk (int, optional): Number of top-k activations to keep. Defaults to -1 (no sparsity). Returns: Encoded latents and full latents """ # Preprocess input x = self.preprocess(x) # Validate topk constrain if topk > 0 and topk < self.model.n_latents: topk_number = topk else: topk_number = None # Encode using the model latents, full_latents, _ = self.model.encode(x, topk_number=topk_number) return latents, full_latents def decode(self, latents: torch.Tensor) -> torch.Tensor: """ Decode latent representation to input space. Args: latents: Latent tensor Returns: Reconstructed input tensor """ # Decode using the model reconstructed = self.model.decode(latents) # Postprocess output reconstructed = self.postprocess(reconstructed) return reconstructed def forward(self, x: torch.Tensor, topk: int = -1) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Forward pass through the SAE. Args: x: Input tensor topk (int, optional): Number of top-k activations to keep. Defaults to -1 (no sparsity). Returns: - Post-processed reconstructed tensor - Reconstructed tensor - Full latent activations """ # Encode to latent space _, full_latents = self.encode(x, topk=topk) # Decode back to input space reconstructed = self.model.decode(full_latents) # Postprocess output post_reconstructed = self.postprocess(reconstructed) # Return reconstructed, post_reconstructed, full_latents return post_reconstructed, reconstructed, full_latents