# MSAE_ICML2025 / app.py
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import clip
from sae import SAE
import os
# --- 1. Setup and Model Loading ---
# Use GPU if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Define file paths for clarity
SAE_MODEL_PATH = '6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768.pth'
VOCAB_SCORES_PATH = 'Concept_Interpreter_6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768_disect_ViT-L~14_-1_text_20000_768.npy'
VOCAB_NAMES_PATH = 'clip_disect_20k.txt'
# Load models and data
try:
    # Load CLIP model
    model, preprocess = clip.load("ViT-L/14", device=device)
    # Load Sparse Autoencoder (SAE) model
    # Ensure the SAE class correctly handles moving the model to the specified device
    sae_model = SAE(SAE_MODEL_PATH).to(device).eval()
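    # NOTE (assumption, inferred from how the model is used in predict() below):
    # the SAE wrapper is expected to expose sae_model(x) -> (reconstruction, ..., latents)
    # and sae_model.model.decode(latents) -> reconstruction in CLIP feature space.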
    # Load pre-computed vocabulary scores and names
    vocab_scores = np.load(VOCAB_SCORES_PATH)
    with open(VOCAB_NAMES_PATH, 'r') as f:
        vocab_names = [line.strip().lower() for line in f]
except FileNotFoundError as e:
    print(f"ERROR: A required file was not found: {e.filename}")
    print("Please ensure all model and vocabulary files are present in the correct paths.")
    # Exit if essential files are missing
    exit()
# Pre-calculate mappings for faster lookup
# For a given feature index, what is the best concept name?
feature_to_concept_score = np.max(vocab_scores, axis=0)
feature_to_concept_name_idx = np.argmax(vocab_scores, axis=0)
# For a given concept name, what is the best feature index?
concept_to_feature_score = np.max(vocab_scores, axis=1)
concept_to_feature_idx = np.argmax(vocab_scores, axis=1)
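# Assumption (consistent with the axis usage above): vocab_scores has shape
# (num_concepts, num_features), i.e. rows index vocabulary concepts and
# columns index SAE features.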
# --- 2. Helper and Core Logic Functions ---
def calculate_fvu(original_input, reconstruction):
    """Calculates the Fraction of Variance Unexplained (FVU)."""
    variance = (original_input - original_input.mean(dim=-1, keepdim=True)).var(dim=-1)
    recon_error_variance = (original_input - reconstruction).var(dim=-1)
    # Add a small epsilon to avoid division by zero
    fvu_val = (recon_error_variance / (variance + 1e-8)).mean()
    return fvu_val.item()
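# Interpretation: an FVU near 0 means the SAE reconstruction is nearly perfect,
# while an FVU of 1 means it does no better than predicting the input's mean value.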
def predict(input_img, top_k, concept, neg_concept, max_strength):
    """
    Main function to process an image, identify top concepts, and visualize concept manipulation.
    """
    if input_img is None:
        raise gr.Error("Please provide an input image.")

    # --- Part A: Top Concepts Analysis ---
    # Preprocess the input image and move it to the correct device
    image_input_processed = preprocess(input_img.convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        # Encode the image with CLIP
        image_features = model.encode_image(image_input_processed).to(torch.float32)
        # Get SAE reconstruction and latent activations
        reconstructed_features, _, full_latents = sae_model(image_features)
        fvu_score = calculate_fvu(image_features, reconstructed_features)

    # Get the top K activating SAE features for the image
    top_k = int(top_k)  # the slider value may arrive as a float
    full_latents = full_latents.cpu().flatten()
    top_sae_values, top_sae_indices = full_latents.topk(k=top_k)

    # Create the bar plot for top concepts
    fig_bar, ax_bar = plt.subplots(figsize=(10, 6))
    concept_labels = [
        f"{vocab_names[feature_to_concept_name_idx[i]]} ({feature_to_concept_score[i]:.2f})"
        for i in top_sae_indices
    ]
    ax_bar.barh(range(top_k), top_sae_values.numpy(), color='skyblue')
    ax_bar.set_yticks(range(top_k))
    ax_bar.set_yticklabels(concept_labels)
    ax_bar.invert_yaxis()  # Display the top concept at the top
    ax_bar.set_xlabel("SAE Feature Activation")
    ax_bar.set_title(f"Top {top_k} Concepts (Concept Name (Concept Similarity Score)) with FVU: {fvu_score:.2f}")
    plt.tight_layout()
    # --- Part B: Concept Manipulation ---
    # Validate the user-provided concept
    concept = concept.lower().strip()
    if concept not in vocab_names:
        raise gr.Error(f"Concept '{concept}' not found in vocabulary. Please choose a valid concept.")

    # Get the feature index and assignment score corresponding to the chosen concept
    concept_vocab_idx = vocab_names.index(concept)
    concept_feature_id = concept_to_feature_idx[concept_vocab_idx]
    concept_assign_score = concept_to_feature_score[concept_vocab_idx]

    # Get the original activation strength of this concept in the input image
    original_strength = full_latents[concept_feature_id].item()

    # Create positive and negative text prompts
    if not neg_concept:
        neg_concept_prompt = f"a photo without {concept}"
    else:
        neg_concept_prompt = f"a photo with {neg_concept.lower().strip()}"
    pos_concept_prompt = f"a photo with {concept}"

    # Tokenize prompts and encode with CLIP (cast to float32 to match the SAE outputs)
    text_labels = clip.tokenize([pos_concept_prompt, neg_concept_prompt]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text_labels).to(torch.float32)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Define the range of strengths to test: 11 evenly spaced values from 0 to max_strength
    strengths = torch.linspace(0.0, max_strength, 11).to(device)
    pos_concept_probs, neg_concept_probs, cos_sims = [], [], []
    original_reconstructed_norm = reconstructed_features / reconstructed_features.norm(dim=-1, keepdim=True)

    for st in strengths:
        with torch.no_grad():
            # Copy the latents and overwrite the target concept feature with the test
            # strength (a strength of 0.0 ablates the concept entirely)
            modified_latents = full_latents.clone().to(device).reshape(1, -1)
            modified_latents[:, concept_feature_id] = st
            # Decode the modified latents back into CLIP feature space
            modified_reconstructed = sae_model.model.decode(modified_latents)
            # Normalize for comparison
            modified_reconstructed_norm = modified_reconstructed / modified_reconstructed.norm(dim=-1, keepdim=True)
            # Calculate similarity to the text prompts (probabilities)
            probs = (100.0 * modified_reconstructed_norm @ text_features.T).softmax(dim=-1)
            pos_concept_probs.append(probs[0, 0].item())
            neg_concept_probs.append(probs[0, 1].item())
            # Calculate cosine similarity to the original reconstructed features
            cos_sims.append(
                torch.nn.functional.cosine_similarity(modified_reconstructed_norm, original_reconstructed_norm).item()
            )
    # Create the line plot for concept manipulation
    fig_line, ax_line = plt.subplots(figsize=(10, 6))
    strengths_cpu = strengths.cpu().numpy()
    ax_line.plot(strengths_cpu, pos_concept_probs, 'o-', label=f'"{pos_concept_prompt}"')
    ax_line.plot(strengths_cpu, neg_concept_probs, 'o-', label=f'"{neg_concept_prompt}"')
    # Add a vertical line indicating the original strength of the concept
    ax_line.axvline(x=original_strength, color='purple', linestyle='--', label=f'Original Strength ({original_strength:.2f})')

    # Add cosine similarity on a secondary y-axis
    ax2 = ax_line.twinx()
    ax2.plot(strengths_cpu, cos_sims, 'x-', color='green', label='Similarity to Original')
    ax2.set_ylabel('Cosine Similarity', color='green')
    ax2.tick_params(axis='y', labelcolor='green')

    ax_line.set_xlabel("Magnitude of the Concept SAE Feature")
    ax_line.set_ylabel("CLIP Probability")
    ax_line.set_title(f"Effect of Modifying Concept '{concept}' (Assignment Score: {concept_assign_score:.2f})")
    fig_line.legend(loc="upper right", bbox_to_anchor=(1, 1), bbox_transform=ax_line.transAxes)
    plt.tight_layout()

    # Close the figures so pyplot stops tracking them (Gradio can still render them)
    plt.close(fig_bar)
    plt.close(fig_line)

    return input_img, fig_bar, fig_line
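
# Example (hypothetical, assuming a local image such as "bird.jpg" from the
# gr.Examples list below is available) of calling predict() directly,
# e.g. for quick debugging outside the Gradio UI:
#   from PIL import Image
#   img, bar_fig, line_fig = predict(Image.open("bird.jpg"), 10, "birds", "", 10.0)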
# --- 3. Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="Matryoshka Sparse Autoencoder (MSAE) Example") as demo:
    gr.Markdown(
        "Based on the paper: [Interpreting CLIP with Hierarchical Sparse Autoencoders](https://openreview.net/forum?id=5MQQsenQBm) with [GitHub code](https://github.com/WolodjaZ/MSAE). "
        "Upload an image to see its top activating concepts from a sparse autoencoder. Then, choose a concept (from `clip_disect_20k.txt`) to visualize how manipulating its corresponding concept magnitude affects the image representation."
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Input controls
            image_input = gr.Image(label="Input Image", sources=['upload', 'webcam'], type="pil")
            gr.Markdown("### Analysis & Manipulation Controls")
            top_k_slider = gr.Slider(minimum=3, maximum=20, value=10, step=1, label="Number of Top-K Concepts to Visualize")
            concept_input = gr.Textbox(label="Concept to Manipulate", value="hair", placeholder="e.g., hair")
            neg_concept_input = gr.Textbox(label="Negative Concept (Optional)", placeholder="e.g., a frown")
            max_strength_slider = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.5, label="Max Concept Strength")
            submit_btn = gr.Button("Analyze and Interpret", variant="primary")
        with gr.Column(scale=2):
            # Output displays
            gr.Markdown("### Results")
            output_image = gr.Image(label="Original Image", interactive=False)
            output_bar_plot = gr.Plot(label="Top Activating Concepts")
            output_line_plot = gr.Plot(label="Concept Manipulation Analysis")

    # NOTE: You might need to create placeholder images 'bird.jpg' and 'statue.jpg'
    # in your directory for the examples to load correctly.
    gr.Examples(
        examples=[
            ["./bird.jpg", 10, "birds", "", 10.0],
            ["./statue.jpg", 10, "statue", "humans", 10.0],
        ],
        inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider],
        outputs=[output_image, output_bar_plot, output_line_plot],
        fn=predict,
        cache_examples=True  # Set to True for faster loading on HF Spaces
    )

    # Wire up the button to the function
    submit_btn.click(
        fn=predict,
        inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider],
        outputs=[output_image, output_bar_plot, output_line_plot]
    )

if __name__ == "__main__":
    demo.launch(debug=True, share=True)