Update app.py
Browse files
app.py
CHANGED
|
@@ -6,7 +6,7 @@ import requests
|
|
| 6 |
from io import BytesIO
|
| 7 |
import numpy as np
|
| 8 |
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
|
| 9 |
-
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
| 10 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
| 11 |
from timm.data import create_transform
|
| 12 |
|
|
@@ -51,7 +51,7 @@ def process_image(image_path, model):
|
|
| 51 |
return tensor
|
| 52 |
|
| 53 |
def get_cam_image(model, image, target_layer, cam_method):
|
| 54 |
-
cam = CAM_METHODS[cam_method](model=model, target_layers=[target_layer]
|
| 55 |
grayscale_cam = cam(input_tensor=image)
|
| 56 |
|
| 57 |
config = model.pretrained_cfg
|
|
@@ -69,14 +69,23 @@ def get_feature_info(model):
|
|
| 69 |
else:
|
| 70 |
return []
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
def explain_image(model_name, image_path, cam_method, feature_module):
|
| 73 |
model = load_model(model_name)
|
| 74 |
image = process_image(image_path, model)
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
else:
|
| 80 |
# Fallback to the last feature module or last convolutional layer
|
| 81 |
feature_info = get_feature_info(model)
|
| 82 |
if feature_info:
|
|
@@ -99,22 +108,29 @@ def explain_image(model_name, image_path, cam_method, feature_module):
|
|
| 99 |
def update_feature_modules(model_name):
|
| 100 |
model = load_model(model_name)
|
| 101 |
feature_modules = get_feature_info(model)
|
| 102 |
-
return gr.Dropdown
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
gr.
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
-
|
|
|
|
| 6 |
from io import BytesIO
|
| 7 |
import numpy as np
|
| 8 |
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
|
| 9 |
+
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
| 10 |
from pytorch_grad_cam.utils.image import show_cam_on_image
|
| 11 |
from timm.data import create_transform
|
| 12 |
|
|
|
|
| 51 |
return tensor
|
| 52 |
|
| 53 |
def get_cam_image(model, image, target_layer, cam_method):
|
| 54 |
+
cam = CAM_METHODS[cam_method](model=model, target_layers=[target_layer])
|
| 55 |
grayscale_cam = cam(input_tensor=image)
|
| 56 |
|
| 57 |
config = model.pretrained_cfg
|
|
|
|
| 69 |
else:
|
| 70 |
return []
|
| 71 |
|
| 72 |
+
def get_target_layer(model, target_layer_name):
|
| 73 |
+
if target_layer_name is None:
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
return model.get_submodule(target_layer_name)
|
| 78 |
+
except AttributeError:
|
| 79 |
+
print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
def explain_image(model_name, image_path, cam_method, feature_module):
|
| 83 |
model = load_model(model_name)
|
| 84 |
image = process_image(image_path, model)
|
| 85 |
|
| 86 |
+
target_layer = get_target_layer(model, feature_module)
|
| 87 |
+
|
| 88 |
+
if target_layer is None:
|
|
|
|
| 89 |
# Fallback to the last feature module or last convolutional layer
|
| 90 |
feature_info = get_feature_info(model)
|
| 91 |
if feature_info:
|
|
|
|
| 108 |
def update_feature_modules(model_name):
|
| 109 |
model = load_model(model_name)
|
| 110 |
feature_modules = get_feature_info(model)
|
| 111 |
+
return gr.Dropdown(choices=feature_modules, value=feature_modules[-1] if feature_modules else None)
|
| 112 |
|
| 113 |
+
with gr.Blocks() as demo:
|
| 114 |
+
gr.Markdown("# Explainable AI with timm models")
|
| 115 |
+
gr.Markdown("Upload an image, select a model, CAM method, and optionally a specific feature module to visualize the explanation.")
|
| 116 |
+
|
| 117 |
+
with gr.Row():
|
| 118 |
+
with gr.Column():
|
| 119 |
+
model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model")
|
| 120 |
+
image_input = gr.Image(type="filepath", label="Upload Image")
|
| 121 |
+
cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
|
| 122 |
+
feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
|
| 123 |
+
explain_button = gr.Button("Explain Image")
|
| 124 |
+
|
| 125 |
+
with gr.Column():
|
| 126 |
+
output_image = gr.Image(type="pil", label="Explained Image")
|
| 127 |
+
|
| 128 |
+
model_dropdown.change(fn=update_feature_modules, inputs=[model_dropdown], outputs=[feature_module_dropdown])
|
| 129 |
+
|
| 130 |
+
explain_button.click(
|
| 131 |
+
fn=explain_image,
|
| 132 |
+
inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown],
|
| 133 |
+
outputs=[output_image]
|
| 134 |
+
)
|
| 135 |
|
| 136 |
+
demo.launch()
|