Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -57,7 +57,7 @@ def calculate_fvu(original_input, reconstruction):
|
|
| 57 |
fvu_val = (recon_error_variance / (variance + 1e-8)).mean()
|
| 58 |
return fvu_val.item()
|
| 59 |
|
| 60 |
-
def predict(input_img, top_k, concept, neg_concept, max_strength):
|
| 61 |
"""
|
| 62 |
Main function to process an image, identify top concepts, and visualize concept manipulation.
|
| 63 |
"""
|
|
@@ -138,7 +138,9 @@ def predict(input_img, top_k, concept, neg_concept, max_strength):
|
|
| 138 |
modified_latents[:, concept_feature_id] = st
|
| 139 |
|
| 140 |
# Decode the modified latents back into feature space
|
| 141 |
-
modified_reconstructed = sae_model.decode(modified_latents)
|
|
|
|
|
|
|
| 142 |
|
| 143 |
# Normalize for comparison
|
| 144 |
modified_reconstructed_norm = modified_reconstructed / modified_reconstructed.norm(dim=-1, keepdim=True)
|
|
@@ -197,6 +199,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Matryoshka Sparse Autoencoder (MSA
|
|
| 197 |
concept_input = gr.Textbox(label="Concept to Manipulate", value="hair", placeholder="e.g., hair")
|
| 198 |
neg_concept_input = gr.Textbox(label="Negative Concept (Optional)", placeholder="e.g., a frown")
|
| 199 |
max_strength_slider = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.5, label="Max Concept Strength")
|
|
|
|
| 200 |
submit_btn = gr.Button("Analyze and Interpret", variant="primary")
|
| 201 |
|
| 202 |
with gr.Column(scale=2):
|
|
@@ -210,7 +213,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Matryoshka Sparse Autoencoder (MSA
|
|
| 210 |
["bird.jpg", 10, "birds", "", 10.0],
|
| 211 |
["statue.jpg", 10, "statue", "humans", 10.0],
|
| 212 |
],
|
| 213 |
-
inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider],
|
| 214 |
outputs=[output_image, output_bar_plot, output_line_plot],
|
| 215 |
fn=predict,
|
| 216 |
cache_examples=True
|
|
|
|
| 57 |
fvu_val = (recon_error_variance / (variance + 1e-8)).mean()
|
| 58 |
return fvu_val.item()
|
| 59 |
|
| 60 |
+
def predict(input_img, top_k, concept, neg_concept, max_strength, add_error):
|
| 61 |
"""
|
| 62 |
Main function to process an image, identify top concepts, and visualize concept manipulation.
|
| 63 |
"""
|
|
|
|
| 138 |
modified_latents[:, concept_feature_id] = st
|
| 139 |
|
| 140 |
# Decode the modified latents back into feature space
|
| 141 |
+
modified_reconstructed = sae_model.decode(modified_latents)
|
| 142 |
+
if add_error:
|
| 143 |
+
modified_reconstructed = modified_reconstructed + error
|
| 144 |
|
| 145 |
# Normalize for comparison
|
| 146 |
modified_reconstructed_norm = modified_reconstructed / modified_reconstructed.norm(dim=-1, keepdim=True)
|
|
|
|
| 199 |
concept_input = gr.Textbox(label="Concept to Manipulate", value="hair", placeholder="e.g., hair")
|
| 200 |
neg_concept_input = gr.Textbox(label="Negative Concept (Optional)", placeholder="e.g., a frown")
|
| 201 |
max_strength_slider = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.5, label="Max Concept Strength")
|
| 202 |
+
add_error_checkbox = gr.Checkbox(value=True, label="Add error term to reconstruction")
|
| 203 |
submit_btn = gr.Button("Analyze and Interpret", variant="primary")
|
| 204 |
|
| 205 |
with gr.Column(scale=2):
|
|
|
|
| 213 |
["bird.jpg", 10, "birds", "", 10.0],
|
| 214 |
["statue.jpg", 10, "statue", "humans", 10.0],
|
| 215 |
],
|
| 216 |
+
inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider, add_error_checkbox],
|
| 217 |
outputs=[output_image, output_bar_plot, output_line_plot],
|
| 218 |
fn=predict,
|
| 219 |
cache_examples=True
|