WolodjaZ commited on
Commit
b7bea99
·
verified ·
1 Parent(s): 47dc123

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -3
app.py CHANGED
@@ -57,7 +57,7 @@ def calculate_fvu(original_input, reconstruction):
57
  fvu_val = (recon_error_variance / (variance + 1e-8)).mean()
58
  return fvu_val.item()
59
 
60
- def predict(input_img, top_k, concept, neg_concept, max_strength):
61
  """
62
  Main function to process an image, identify top concepts, and visualize concept manipulation.
63
  """
@@ -138,7 +138,9 @@ def predict(input_img, top_k, concept, neg_concept, max_strength):
138
  modified_latents[:, concept_feature_id] = st
139
 
140
  # Decode the modified latents back into feature space
141
- modified_reconstructed = sae_model.decode(modified_latents) + error
 
 
142
 
143
  # Normalize for comparison
144
  modified_reconstructed_norm = modified_reconstructed / modified_reconstructed.norm(dim=-1, keepdim=True)
@@ -197,6 +199,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Matryoshka Sparse Autoencoder (MSA
197
  concept_input = gr.Textbox(label="Concept to Manipulate", value="hair", placeholder="e.g., hair")
198
  neg_concept_input = gr.Textbox(label="Negative Concept (Optional)", placeholder="e.g., a frown")
199
  max_strength_slider = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.5, label="Max Concept Strength")
 
200
  submit_btn = gr.Button("Analyze and Interpret", variant="primary")
201
 
202
  with gr.Column(scale=2):
@@ -210,7 +213,7 @@ with gr.Blocks(theme=gr.themes.Soft(), title="Matryoshka Sparse Autoencoder (MSA
210
  ["bird.jpg", 10, "birds", "", 10.0],
211
  ["statue.jpg", 10, "statue", "humans", 10.0],
212
  ],
213
- inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider],
214
  outputs=[output_image, output_bar_plot, output_line_plot],
215
  fn=predict,
216
  cache_examples=True
 
57
  fvu_val = (recon_error_variance / (variance + 1e-8)).mean()
58
  return fvu_val.item()
59
 
60
+ def predict(input_img, top_k, concept, neg_concept, max_strength, add_error):
61
  """
62
  Main function to process an image, identify top concepts, and visualize concept manipulation.
63
  """
 
138
  modified_latents[:, concept_feature_id] = st
139
 
140
  # Decode the modified latents back into feature space
141
+ modified_reconstructed = sae_model.decode(modified_latents)
142
+ if add_error:
143
+ modified_reconstructed = modified_reconstructed + error
144
 
145
  # Normalize for comparison
146
  modified_reconstructed_norm = modified_reconstructed / modified_reconstructed.norm(dim=-1, keepdim=True)
 
199
  concept_input = gr.Textbox(label="Concept to Manipulate", value="hair", placeholder="e.g., hair")
200
  neg_concept_input = gr.Textbox(label="Negative Concept (Optional)", placeholder="e.g., a frown")
201
  max_strength_slider = gr.Slider(minimum=1.0, maximum=20.0, value=10.0, step=0.5, label="Max Concept Strength")
202
+ add_error_checkbox = gr.Checkbox(value=True, label="Add error term to reconstruction")
203
  submit_btn = gr.Button("Analyze and Interpret", variant="primary")
204
 
205
  with gr.Column(scale=2):
 
213
  ["bird.jpg", 10, "birds", "", 10.0],
214
  ["statue.jpg", 10, "statue", "humans", 10.0],
215
  ],
216
+ inputs=[image_input, top_k_slider, concept_input, neg_concept_input, max_strength_slider, add_error_checkbox],
217
  outputs=[output_image, output_bar_plot, output_line_plot],
218
  fn=predict,
219
  cache_examples=True