File size: 10,356 Bytes
fce0aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e9a667
fce0aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7bea99
fce0aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47dc123
fce0aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a113ff8
fce0aa3
 
 
 
 
4e9a667
fce0aa3
 
 
 
 
 
 
a113ff8
 
 
fce0aa3
 
 
 
4e9a667
fce0aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7bea99
 
 
fce0aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a113ff8
 
 
fce0aa3
 
 
 
 
 
a113ff8
fce0aa3
a113ff8
fce0aa3
 
 
 
 
 
 
 
 
 
 
 
 
 
a113ff8
fce0aa3
 
 
 
 
 
 
a113ff8
fce0aa3
 
 
0f8256c
fce0aa3
 
 
 
 
6e27a97
 
76c7e15
 
 
0f8256c
 
76c7e15
b7bea99
76c7e15
 
 
 
fce0aa3
 
 
 
 
 
 
 
 
 
0f8256c
fce0aa3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
import clip
from sae import SAE
import os

# --- 1. Setup and Model Loading ---

# Use GPU if available, otherwise CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Define file paths for clarity
SAE_MODEL_PATH = '6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768.pth'
VOCAB_SCORES_PATH = 'Concept_Interpreter_6144_768_TopKReLU_64_RW_False_False_0.0_cc3m_ViT-L~14_train_image_2905936_768_disect_ViT-L~14_-1_text_20000_768.npy'
VOCAB_NAMES_PATH = 'clip_disect_20k.txt'

# Load models and data
try:
    # Load CLIP model
    model, preprocess = clip.load("ViT-L/14", device=device)

    # Load Sparse Autoencoder (SAE) model
    # Ensure the SAE class correctly handles moving the model to the specified device
    sae_model = SAE(SAE_MODEL_PATH).to(device).eval()

    # Load pre-computed vocabulary scores and names
    vocab_scores = np.load(VOCAB_SCORES_PATH)
    with open(VOCAB_NAMES_PATH, 'r', encoding='utf-8') as f:
        # Keep every line (including blanks) so line i stays aligned with
        # row i of vocab_scores, which is how the lookups below index it.
        vocab_names = [line.strip().lower() for line in f]

except FileNotFoundError as e:
    print(f"ERROR: A required file was not found: {e.filename}")
    print("Please ensure all model and vocabulary files are present in the correct paths.")
    # Exit with a non-zero status so launchers/supervisors see the failure;
    # the bare site-provided exit() reports success (status 0).
    raise SystemExit(1)

# Pre-calculate mappings for faster lookup.
# vocab_scores axis 0 indexes vocabulary entries (aligned with vocab_names),
# axis 1 indexes SAE latent features.

# For a given feature index, what is the best concept name?
feature_to_concept_score = np.max(vocab_scores, axis=0)
feature_to_concept_name_idx = np.argmax(vocab_scores, axis=0)

# For a given concept name, what is the best feature index?
concept_to_feature_score = np.max(vocab_scores, axis=1)
concept_to_feature_idx = np.argmax(vocab_scores, axis=1)


# --- 2. Helper and Core Logic Functions ---

def calculate_fvu(original_input, reconstruction):
    """Return the mean Fraction of Variance Unexplained (FVU) as a Python float.

    For every vector along the last dimension, the variance of the residual
    ``original_input - reconstruction`` is divided by the variance of
    ``original_input``. The per-vector ratios are then averaged. An epsilon
    of 1e-8 is added to the denominator to avoid dividing by zero.
    """
    centered_input = original_input - original_input.mean(dim=-1, keepdim=True)
    residual = original_input - reconstruction
    per_row_ratio = residual.var(dim=-1) / (centered_input.var(dim=-1) + 1e-8)
    return per_row_ratio.mean().item()

def predict(input_img, top_k, concept, neg_concept, max_strength, add_error=False):
    """
    Analyse an image with the SAE and visualise the effect of steering one concept.

    Part A encodes ``input_img`` with CLIP and runs the embedding through the SAE.
    It then plots the ``top_k`` most active SAE latents, each labelled with its
    best-matching vocabulary name and score, along with the reconstruction FVU.

    Part B overwrites the latent assigned to ``concept`` with 11 evenly spaced
    values in ``[0, max_strength]`` and decodes each modified latent vector.
    For each value it plots:
      (a) the CLIP softmax probability of a positive vs. a negative text prompt;
      (b) cosine similarity to the unmodified SAE reconstruction.

    Args:
        input_img: PIL image from the Gradio image component, or None.
        top_k: Number of top latents to plot. It is coerced to int because
            slider values may arrive as floats.
        concept: Vocabulary entry to manipulate (case- and
            whitespace-insensitive).
        neg_concept: Optional contrast concept. When blank, the negative
            prompt is "a photo without <concept>".
        max_strength: Upper end of the latent-value sweep.
        add_error: If True, add the SAE reconstruction error (image embedding
            minus reconstruction) back onto every modified reconstruction.
            Defaults to False so callers that omit it still work.

    Returns:
        Tuple of (input_img, top-concepts bar figure, manipulation line figure).

    Raises:
        gr.Error: If no image is given or ``concept`` is not in the vocabulary.
    """
    if input_img is None:
        raise gr.Error("Please provide an input image.")

    top_k = int(top_k)

    # Validate the concept before any figure is created. Raising after
    # plt.subplots() would leave that figure open in pyplot on every bad request.
    concept = (concept or "").lower().strip()
    try:
        concept_row = vocab_names.index(concept)
    except ValueError:
        raise gr.Error(f"Concept '{concept}' not found in vocabulary. Please choose a valid concept.") from None

    # Feature index assigned to the chosen concept, and how well it matches.
    concept_feature_id = int(concept_to_feature_idx[concept_row])
    concept_assign_score = concept_to_feature_score[concept_row]

    # --- Part A: Top Concepts Analysis ---

    # Preprocess the input image and move to the correct device
    image_input_processed = preprocess(input_img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        # Encode with CLIP. Cast to float32, since CLIP may run in fp16 on CUDA
        # while the SAE path here works in float32.
        image_features = model.encode_image(image_input_processed).to(torch.float32)

        # Get SAE reconstruction and latent activations
        reconstructed_features, _, full_latents = sae_model(image_features)

    fvu_score = calculate_fvu(image_features, reconstructed_features)
    error = image_features - reconstructed_features

    # Get the top K activating SAE features for the image
    full_latents = full_latents.cpu().flatten()
    top_sae_values, top_sae_indices = full_latents.topk(k=top_k)

    # Create the bar plot for top concepts
    fig_bar, ax_bar = plt.subplots(figsize=(10, 6))
    concept_labels = [
        f"{vocab_names[feature_to_concept_name_idx[i]]} ({feature_to_concept_score[i]:.2f})"
        for i in top_sae_indices.tolist()
    ]
    ax_bar.barh(range(top_k), top_sae_values.numpy(), color='skyblue')
    ax_bar.set_yticks(range(top_k))
    ax_bar.set_yticklabels(concept_labels)
    ax_bar.invert_yaxis()  # Display top concept at the top
    ax_bar.set_xlabel("SAE Feature Activation")
    ax_bar.set_title(f"Top {top_k} Concepts (Concept Name (Concept Similarity Score)) with FVU: {fvu_score:.2f}")
    fig_bar.tight_layout()

    # --- Part B: Concept Manipulation ---

    # Original activation strength of this concept in the input image
    original_strength = full_latents[concept_feature_id].item()

    # Create positive and negative text prompts. Strip first, so a
    # whitespace-only negative concept falls back to the "without" prompt.
    neg_concept = (neg_concept or "").lower().strip()
    if neg_concept:
        neg_concept_prompt = f"a photo with {neg_concept}"
    else:
        neg_concept_prompt = f"a photo without {concept}"
    pos_concept_prompt = f"a photo with {concept}"

    # Tokenize prompts and encode with CLIP
    text_labels = clip.tokenize([pos_concept_prompt, neg_concept_prompt]).to(device)
    with torch.no_grad():
        # Cast to float32 to match the image-side tensors; a half-precision
        # text embedding would make the matmul below fail with a dtype mismatch.
        text_features = model.encode_text(text_labels).to(torch.float32)
        text_features /= text_features.norm(dim=-1, keepdim=True)

    # Define the range of strengths to test
    strengths = torch.linspace(0.0, max_strength, 11).to(device)
    pos_concept_probs, neg_concept_probs, cos_sims = [], [], []

    original_reconstructed_norm = reconstructed_features / reconstructed_features.norm(dim=-1, keepdim=True)
    # Device transfer is loop-invariant; each iteration clones this copy.
    base_latents = full_latents.to(device).reshape(1, -1)

    with torch.no_grad():
        for st in strengths:
            # Copy the latents and overwrite the target concept feature
            modified_latents = base_latents.clone()
            modified_latents[:, concept_feature_id] = st

            # Decode the modified latents back into feature space
            modified_reconstructed = sae_model.decode(modified_latents)
            if add_error:
                modified_reconstructed = modified_reconstructed + error

            # Normalize for comparison
            modified_reconstructed_norm = modified_reconstructed / modified_reconstructed.norm(dim=-1, keepdim=True)

            # Similarity to the text prompts, as softmax probabilities
            probs = (100.0 * modified_reconstructed_norm @ text_features.T).softmax(dim=-1)
            pos_concept_probs.append(probs[0, 0].item())
            neg_concept_probs.append(probs[0, 1].item())

            # Cosine similarity to the original (unmodified) reconstruction
            cos_sims.append(
                torch.nn.functional.cosine_similarity(modified_reconstructed_norm, original_reconstructed_norm).item()
            )

    # Create the line plot for concept manipulation
    fig_line, ax_line = plt.subplots(figsize=(10, 6))
    strengths_cpu = strengths.cpu().numpy()
    ax_line.plot(strengths_cpu, pos_concept_probs, 'o-', label=f'"{pos_concept_prompt}"')
    ax_line.plot(strengths_cpu, neg_concept_probs, 'o-', label=f'"{neg_concept_prompt}"')

    # Vertical line marking the concept's original strength in this image
    ax_line.axvline(x=original_strength, color='purple', linestyle='--', label=f'Original Strength ({original_strength:.2f})')

    # Add cosine similarity on a secondary y-axis
    ax2 = ax_line.twinx()
    ax2.plot(strengths_cpu, cos_sims, 'x-', color='green', label='Similarity to Original')
    ax2.set_ylabel('Cosine Similarity', color='green')
    ax2.tick_params(axis='y', labelcolor='green')

    ax_line.set_xlabel("Magnitude of the Concept SAE Feature")
    ax_line.set_ylabel("CLIP Probability")
    ax_line.set_title(f"Effect of Modifying Concept '{concept}' (Assignment Score: {concept_assign_score:.2f})")
    fig_line.legend(loc="upper right", bbox_to_anchor=(1, 1), bbox_transform=ax_line.transAxes)
    fig_line.tight_layout()

    # Unregister the figures from pyplot to free memory. The Figure objects
    # are still returned for rendering.
    plt.close(fig_bar)
    plt.close(fig_line)

    return input_img, fig_bar, fig_line


# --- 3. Gradio Interface ---

with gr.Blocks(theme=gr.themes.Soft(), title="Matryoshka Sparse Autoencoder (MSAE) Example") as demo:
    gr.Markdown(
        "Based on the paper: [Interpreting CLIP with Hierarchical Sparse Autoencoders](https://openreview.net/forum?id=5MQQsenQBm) with [github code](https://github.com/WolodjaZ/MSAE). "
        "Upload an image to see its top activating concepts from a sparse autoencoder. Then, choose a concept (from `clip_disect_20k.txt`) to visualize how manipulating its corresponding concept magnitude affects the image representation."
    )

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(label="Input Image", sources=['upload', 'webcam'], type="pil")
            gr.Markdown("### Analysis & Manipulation Controls")
            top_k_slider = gr.Slider(minimum=3, maximum=20, value=10, step=1, label="Number of Top K Concepts to visualize")
            concept_input = gr.Textbox(label="Concept to Manipulate", value="hair", placeholder="e.g., hair")
            neg_concept_input = gr.Textbox(label="Negative Concept (Optional)", placeholder="e.g., a frown")
            max_strength_slider = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.5, label="Max Concept Strength")
            add_error_checkbox = gr.Checkbox(label="Add error term to reconstruction")
            submit_btn = gr.Button("Analyze and Interpret", variant="primary")

        with gr.Column(scale=2):
            gr.Markdown("### Results")
            output_image = gr.Image(label="Original Image", interactive=False)
            output_bar_plot = gr.Plot()
            output_line_plot = gr.Plot()

    # One wiring list shared by the examples and the button, in the same order
    # as predict's parameters, so both always pass every argument.
    predict_inputs = [
        image_input,
        top_k_slider,
        concept_input,
        neg_concept_input,
        max_strength_slider,
        add_error_checkbox,
    ]
    predict_outputs = [output_image, output_bar_plot, output_line_plot]

    gr.Examples(
        examples=[
            ["bird.jpg", 10, "birds", "", 10.0, True],
            ["statue.jpg", 10, "statue", "humans", 10.0, True],
        ],
        inputs=predict_inputs,
        outputs=predict_outputs,
        fn=predict,
        cache_examples=True
    )

    # Wire up the button to the function
    submit_btn.click(
        fn=predict,
        inputs=predict_inputs,
        outputs=predict_outputs
    )


if __name__ == "__main__":
    demo.launch(debug=True)