Spaces:
Running
Running
fg-mindee
commited on
Commit
·
82a65d2
1
Parent(s):
ea7721e
feat: Added option to retrieve CAMs from multiple layers
Browse files
app.py
CHANGED
|
@@ -11,7 +11,8 @@ from io import BytesIO
|
|
| 11 |
from torchvision import models
|
| 12 |
from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image
|
| 13 |
|
| 14 |
-
from torchcam import
|
|
|
|
| 15 |
from torchcam.utils import overlay_mask
|
| 16 |
|
| 17 |
|
|
@@ -59,14 +60,14 @@ def main():
|
|
| 59 |
if tv_model is not None:
|
| 60 |
with st.spinner('Loading model...'):
|
| 61 |
model = models.__dict__[tv_model](pretrained=True).eval()
|
| 62 |
-
default_layer =
|
| 63 |
|
| 64 |
target_layer = st.sidebar.text_input("Target layer", default_layer)
|
| 65 |
cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
|
| 66 |
if cam_method is not None:
|
| 67 |
-
cam_extractor =
|
| 68 |
model,
|
| 69 |
-
target_layer=target_layer if len(target_layer) > 0 else None
|
| 70 |
)
|
| 71 |
|
| 72 |
class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
|
|
@@ -94,16 +95,18 @@ def main():
|
|
| 94 |
else:
|
| 95 |
class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
|
| 96 |
# Retrieve the CAM
|
| 97 |
-
|
|
|
|
|
|
|
| 98 |
# Plot the raw heatmap
|
| 99 |
fig, ax = plt.subplots()
|
| 100 |
-
ax.imshow(
|
| 101 |
ax.axis('off')
|
| 102 |
cols[1].pyplot(fig)
|
| 103 |
|
| 104 |
# Overlayed CAM
|
| 105 |
fig, ax = plt.subplots()
|
| 106 |
-
result = overlay_mask(img, to_pil_image(
|
| 107 |
ax.imshow(result)
|
| 108 |
ax.axis('off')
|
| 109 |
cols[-1].pyplot(fig)
|
|
|
|
| 11 |
from torchvision import models
|
| 12 |
from torchvision.transforms.functional import resize, to_tensor, normalize, to_pil_image
|
| 13 |
|
| 14 |
+
from torchcam import methods
|
| 15 |
+
from torchcam.methods._utils import locate_candidate_layer
|
| 16 |
from torchcam.utils import overlay_mask
|
| 17 |
|
| 18 |
|
|
|
|
| 60 |
if tv_model is not None:
|
| 61 |
with st.spinner('Loading model...'):
|
| 62 |
model = models.__dict__[tv_model](pretrained=True).eval()
|
| 63 |
+
default_layer = locate_candidate_layer(model, (3, 224, 224))
|
| 64 |
|
| 65 |
target_layer = st.sidebar.text_input("Target layer", default_layer)
|
| 66 |
cam_method = st.sidebar.selectbox("CAM method", CAM_METHODS)
|
| 67 |
if cam_method is not None:
|
| 68 |
+
cam_extractor = methods.__dict__[cam_method](
|
| 69 |
model,
|
| 70 |
+
target_layer=target_layer.split("+") if len(target_layer) > 0 else None
|
| 71 |
)
|
| 72 |
|
| 73 |
class_choices = [f"{idx + 1} - {class_name}" for idx, class_name in enumerate(LABEL_MAP)]
|
|
|
|
| 95 |
else:
|
| 96 |
class_idx = LABEL_MAP.index(class_selection.rpartition(" - ")[-1])
|
| 97 |
# Retrieve the CAM
|
| 98 |
+
cams = cam_extractor(class_idx, out)
|
| 99 |
+
# Fuse the CAMs if there are several
|
| 100 |
+
cam = cams[0] if len(cams) == 1 else cam_extractor.fuse_cams(cams)
|
| 101 |
# Plot the raw heatmap
|
| 102 |
fig, ax = plt.subplots()
|
| 103 |
+
ax.imshow(cam.numpy())
|
| 104 |
ax.axis('off')
|
| 105 |
cols[1].pyplot(fig)
|
| 106 |
|
| 107 |
# Overlayed CAM
|
| 108 |
fig, ax = plt.subplots()
|
| 109 |
+
result = overlay_mask(img, to_pil_image(cam, mode='F'), alpha=0.5)
|
| 110 |
ax.imshow(result)
|
| 111 |
ax.axis('off')
|
| 112 |
cols[-1].pyplot(fig)
|