import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import clip
from sae import SAE
import os
# --- 1. Setup and Model Loading ---
# Use GPU if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Define file paths for clarity
SAE_MODEL_PATH = '6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768.pth'
VOCAB_SCORES_PATH = 'Concept_Interpreter_6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768_disect_ViT-L~14_-1_text_20000_768.npy'
VOCAB_NAMES_PATH = 'clip_disect_20k.txt'
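# Note: clip_disect_20k.txt is expected to contain one concept name per line, aligned with
# the rows of the vocabulary score matrix loaded below (an assumption based on how it is read).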
# Load models and data
try:
    # Load CLIP model
    model, preprocess = clip.load("ViT-L/14", device=device)

    # Load Sparse Autoencoder (SAE) model
    # Ensure the SAE class correctly handles moving the model to the specified device
    sae_model = SAE(SAE_MODEL_PATH).to(device).eval()

    # Load pre-computed vocabulary scores and names
    vocab_scores = np.load(VOCAB_SCORES_PATH)
    with open(VOCAB_NAMES_PATH, 'r') as f:
        vocab_names = [line.strip().lower() for line in f.readlines()]
except FileNotFoundError as e:
    print(f"ERROR: A required file was not found: {e.filename}")
    print("Please ensure all model and vocabulary files are present in the correct paths.")
    # Exit if essential files are missing
    exit()
# Pre-calculate mappings for faster lookup
# For a given feature index, what is the best concept name?
feature_to_concept_score = np.max(vocab_scores, axis=0)
feature_to_concept_name_idx = np.argmax(vocab_scores, axis=0)
# For a given concept name, what is the best feature index?
concept_to_feature_score = np.max(vocab_scores, axis=1)
concept_to_feature_idx = np.argmax(vocab_scores, axis=1)
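# Note: vocab_scores is assumed to have shape (num_concepts, num_features), i.e. rows index
# vocabulary entries and columns index SAE features, which is what the axis-0/axis-1
# reductions above rely on.
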
# --- 2. Helper and Core Logic Functions ---
def calculate_fvu(original_input, reconstruction):
"""Calculates the Fraction of Variance Unexplained (FVU)."""
    variance = (original_input - original_input.mean(dim=-1, keepdim=True)).var(dim=-1)
    recon_error_variance = (original_input - reconstruction).var(dim=-1)
    # Add a small epsilon to the denominator to avoid division by zero
    fvu_val = (recon_error_variance / (variance + 1e-8)).mean()
    return fvu_val.item()

def predict(input_img, top_k, concept, neg_concept, max_strength, add_error):
    """
    Main function to process an image, identify top concepts, and visualize concept manipulation.
    """
    if not input_img:
        raise gr.Error("Please provide an input image.")

    # --- Part A: Top Concepts Analysis ---
    # Preprocess the input image and move to the correct device
    image_input_processed = preprocess(input_img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        # Encode the image with CLIP
        image_features = model.encode_image(image_input_processed).to(torch.float32)
        # Get SAE reconstruction and latent activations
        reconstructed_features, _, full_latents = sae_model(image_features)
        fvu_score = calculate_fvu(image_features, reconstructed_features)
        error = image_features - reconstructed_features
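        # 'error' is the SAE residual; it can optionally be added back to the manipulated
        # reconstructions below ("Add error term" checkbox) so they stay closer to the
        # original CLIP embedding.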

    # Get the top K activating SAE features for the image
    full_latents = full_latents.cpu().flatten()
    top_sae_values, top_sae_indices = full_latents.topk(k=top_k)

    # Create the bar plot for top concepts
    fig_bar, ax_bar = plt.subplots(figsize=(10, 6))
    concept_labels = [
        f"{vocab_names[feature_to_concept_name_idx[i]]} ({feature_to_concept_score[i]:.2f})"
        for i in top_sae_indices
    ]
    ax_bar.barh(range(top_k), top_sae_values.numpy(), color='skyblue')
    ax_bar.set_yticks(range(top_k))
    ax_bar.set_yticklabels(concept_labels)
    ax_bar.invert_yaxis()  # Display the top concept at the top
    ax_bar.set_xlabel("SAE Feature Activation")
    ax_bar.set_title(f"Top {top_k} Concepts (Concept Name (Concept Similarity Score)) with FVU: {fvu_score:.2f}")
    plt.tight_layout()

    # --- Part B: Concept Manipulation ---
    # Validate the user-provided concept
    concept = concept.lower().strip()
    if concept not in vocab_names:
        raise gr.Error(f"Concept '{concept}' not found in vocabulary. Please choose a valid concept.")

    # Get the feature index corresponding to the chosen concept
    concept_feature_id = concept_to_feature_idx[vocab_names.index(concept)]
    concept_assign_score = concept_to_feature_score[vocab_names.index(concept)]
    # Get the original activation strength of this concept in the input image
    original_strength = full_latents[concept_feature_id].item()

    # Create positive and negative text prompts
    if not neg_concept:
        neg_concept_prompt = f"a photo without {concept}"
    else:
        neg_concept_prompt = f"a photo with {neg_concept.lower().strip()}"
    pos_concept_prompt = f"a photo with {concept}"

    # Tokenize prompts and encode with CLIP
    text_labels = clip.tokenize([pos_concept_prompt, neg_concept_prompt]).to(device)
    with torch.no_grad():
        # Cast to float32 so the text features match the dtype of the SAE reconstructions
        text_features = model.encode_text(text_labels).to(torch.float32)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Define the range of strengths to test
    strengths = torch.linspace(0.0, max_strength, 11).to(device)
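    # For each strength, re-decode the modified latents and measure how the CLIP-space
    # representation shifts relative to the two text prompts and to the original reconstruction.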
    pos_concept_probs, neg_concept_probs, cos_sims = [], [], []
    original_reconstructed_norm = reconstructed_features / reconstructed_features.norm(dim=-1, keepdim=True)

    for st in strengths:
        with torch.no_grad():
            # Create a copy of latents and modify the target concept feature
            modified_latents = full_latents.clone().to(device).reshape(1, -1)
            modified_latents[:, concept_feature_id] = st
            # Decode the modified latents back into feature space
            modified_reconstructed = sae_model.decode(modified_latents)
            if add_error:
                modified_reconstructed = modified_reconstructed + error
            # Normalize for comparison
            modified_reconstructed_norm = modified_reconstructed / modified_reconstructed.norm(dim=-1, keepdim=True)

            # Calculate similarity to the text prompts (probabilities)
            probs = (100.0 * modified_reconstructed_norm @ text_features.T).softmax(dim=-1)
            pos_concept_probs.append(probs[0, 0].item())
            neg_concept_probs.append(probs[0, 1].item())

            # Calculate cosine similarity to the original reconstructed image
            cos_sims.append(
                torch.nn.functional.cosine_similarity(modified_reconstructed_norm, original_reconstructed_norm).item()
            )
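
    # If the selected SAE feature truly encodes the concept, the positive-prompt probability
    # should rise with strength while similarity to the original reconstruction drops.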
    # Create the line plot for concept manipulation
    fig_line, ax_line = plt.subplots(figsize=(10, 6))
    strengths_cpu = strengths.cpu().numpy()
    ax_line.plot(strengths_cpu, pos_concept_probs, 'o-', label=f'"{pos_concept_prompt}"')
    ax_line.plot(strengths_cpu, neg_concept_probs, 'o-', label=f'"{neg_concept_prompt}"')
    # Add a vertical line indicating the original strength of the concept
    ax_line.axvline(x=original_strength, color='purple', linestyle='--', label=f'Original Strength ({original_strength:.2f})')

    # Add cosine similarity on a secondary y-axis
    ax2 = ax_line.twinx()
    ax2.plot(strengths_cpu, cos_sims, 'x-', color='green', label='Similarity to Original')
    ax2.set_ylabel('Cosine Similarity', color='green')
    ax2.tick_params(axis='y', labelcolor='green')

    ax_line.set_xlabel("Magnitude of the Concept SAE Feature")
    ax_line.set_ylabel("CLIP Probability")
    ax_line.set_title(f"Effect of Modifying Concept '{concept}' (Assignment Score: {concept_assign_score:.2f})")
    fig_line.legend(loc="upper right", bbox_to_anchor=(1, 1), bbox_transform=ax_line.transAxes)
    plt.tight_layout()

    # Close the figures so pyplot does not keep them in memory; Gradio renders the returned Figure objects directly
    plt.close(fig_bar)
    plt.close(fig_line)

    return input_img, fig_bar, fig_line

# --- 3. Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="Matryoshka Sparse Autoencoder (MSAE) Example") as demo:
    gr.Markdown(
        "Based on the paper [Interpreting CLIP with Hierarchical Sparse Autoencoders](https://openreview.net/forum?id=5MQQsenQBm), with code on [GitHub](https://github.com/WolodjaZ/MSAE). "
        "Upload an image to see its top activating concepts from the sparse autoencoder. Then choose a concept (from `clip_disect_20k.txt`) and visualize how changing the magnitude of its corresponding SAE feature affects the image representation."
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Input Image", sources=['upload', 'webcam'], type="pil")
            gr.Markdown("### Analysis & Manipulation Controls")
            top_k_slider = gr.Slider(minimum=3, maximum=20, value=10, step=1, label="Number of Top K Concepts to Visualize")
            concept_input = gr.Textbox(label="Concept to Manipulate", value="hair", placeholder="e.g., hair")
            neg_concept_input = gr.Textbox(label="Negative Concept (Optional)", placeholder="e.g., a frown")
            max_strength_slider = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.5, label="Max Concept Strength")
            add_error_checkbox = gr.Checkbox(label="Add error term to reconstruction")
            submit_btn = gr.Button("Analyze and Interpret", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### Results")
            output_image = gr.Image(label="Original Image", interactive=False)
            output_bar_plot = gr.Plot()
            output_line_plot = gr.Plot()
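
    # The example rows below are precomputed at startup (cache_examples=True), so clicking
    # an example returns cached outputs instead of re-running the models.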
    gr.Examples(
        examples=[
            ["bird.jpg", 10, "birds", "", 10.0, True],
            ["statue.jpg", 10, "statue", "humans", 10.0, True],
        ],
        inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider, add_error_checkbox],
        outputs=[output_image, output_bar_plot, output_line_plot],
        fn=predict,
        cache_examples=True
    )

    # Wire up the button to the function
    submit_btn.click(
        fn=predict,
        # predict() expects all six inputs, including the error-term checkbox
        inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider, add_error_checkbox],
        outputs=[output_image, output_bar_plot, output_line_plot]
    )

if __name__ == "__main__":
    demo.launch(debug=True)