
v0.3.0: Support of Layer-CAM & multi-layer CAM computation

@frgfm released this 31 Oct 15:24 (commit d8d722d)

This release extends the CAM methods with Layer-CAM and greatly improves the core features (CAM computation for multiple layers at once, CAM fusion, support of torch.nn.Module as target layer), while making the library more accessible to new users.

Note: TorchCAM 0.3.0 requires PyTorch 1.5.1 or higher.

Highlights

Enter Layer-CAM

The previous release introduced the Score-CAM variants; this one introduces Layer-CAM, which is meant to be considerably faster than those while offering very competitive localization cues!

Just like any other CAM method, you can now use it as follows:

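from torchvision.models import resnet18
from torchcam.cams import LayerCAM

model = resnet18(pretrained=True).eval()
# Hook the model (the target layer is resolved automatically when not specified)
cam_extractor = LayerCAM(model)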

Consequently, the illustration of visual outputs for all CAM methods has been updated so that you can better choose the option that suits you:

[Figure: visual outputs of all CAM methods]
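For intuition, here is a minimal pure-Python sketch of the per-layer Layer-CAM computation as described in the Layer-CAM paper: each activation value is weighted by its own positive gradient (instead of one averaged weight per channel, as in Grad-CAM), the weighted maps are summed over channels, and a final ReLU is applied. The toy nested lists (shape channels × height × width) and the `layer_cam` helper are illustrative assumptions, not TorchCAM's implementation, which operates on PyTorch tensors captured by hooks.

```python
from typing import List

Map = List[List[float]]  # one height x width map


def layer_cam(activations: List[Map], gradients: List[Map]) -> Map:
    """Layer-CAM for one layer: ReLU(sum_k ReLU(grad_k) * act_k), computed per pixel."""
    height, width = len(activations[0]), len(activations[0][0])
    cam = [[0.0] * width for _ in range(height)]
    for act, grad in zip(activations, gradients):  # loop over channels k
        for i in range(height):
            for j in range(width):
                # Pixel-wise weight: only positive gradients contribute
                weight = max(grad[i][j], 0.0)
                cam[i][j] += weight * act[i][j]
    # Final ReLU keeps only positive evidence for the target class
    return [[max(value, 0.0) for value in row] for row in cam]


# Toy example: 2 channels of 2x2 activations and their gradients
acts = [[[1.0, 2.0], [3.0, 4.0]], [[2.0, 1.0], [0.0, 3.0]]]
grads = [[[0.5, -1.0], [1.0, 0.0]], [[1.0, 1.0], [-1.0, -0.5]]]
print(layer_cam(acts, grads))  # [[2.5, 1.0], [3.0, 0.0]]
```

Because only one backward pass is needed to obtain the gradients, this kind of method avoids the many extra forward passes that Score-CAM variants rely on.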

Computing CAMs for multiple layers & CAM fusion

A class activation map is specific to a given layer in a model. To fully capture the influence of visual traits on your classification output, you might want to explore the CAMs for multiple layers.

For instance, here are the CAMs on the layers "layer2", "layer3" and "layer4" of a ResNet-18:

import matplotlib.pyplot as plt
from torchvision.io.image import read_image
from torchvision.models import resnet18
from torchvision.transforms.functional import normalize, resize, to_pil_image

from torchcam.cams import LayerCAM
from torchcam.utils import overlay_mask

# Download an image (Jupyter/Colab shell command; run `wget <url>` in a terminal otherwise)
!wget https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg
# Set this to your image path if you wish to run it on your own data
img_path = "border-collie.jpg"

# Get your input
img = read_image(img_path)
# Preprocess it for your chosen model: resize, scale to [0, 1], then apply ImageNet normalization
input_tensor = normalize(resize(img, (224, 224)) / 255., [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
# Get your model
model = resnet18(pretrained=True).eval()
# Hook the model on three target layers
cam_extractor = LayerCAM(model, ["layer2", "layer3", "layer4"])

out = model(input_tensor.unsqueeze(0))
# Compute one CAM per target layer for the top predicted class
cams = cam_extractor(out.squeeze(0).argmax().item(), out)
# Plot the CAMs side by side
_, axes = plt.subplots(1, len(cam_extractor.target_names))
for ax, name, cam in zip(axes, cam_extractor.target_names, cams):
    ax.imshow(cam.numpy())
    ax.axis('off')
    ax.set_title(name)
plt.show()

[Figure: CAMs of layer2, layer3 and layer4 of ResNet-18]
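The three maps do not share the same resolution, because each ResNet stage further downsamples its input. Assuming the standard torchvision ResNet-18 layout (a stride-2 convolution and a stride-2 max-pool in the stem, then one stride-2 stage at each of layer2, layer3 and layer4), the side length of each CAM for a 224×224 input can be worked out as below. The `STAGE_STRIDES` table and `cam_resolution` helper are illustrative, and the plain floor division only holds for input sizes divisible by 32.

```python
# Cumulative downsampling factor of each torchvision ResNet-18 stage (assumed standard layout)
STAGE_STRIDES = {"layer1": 4, "layer2": 8, "layer3": 16, "layer4": 32}


def cam_resolution(input_size: int, layer_name: str) -> int:
    """Side length of the square CAM produced at a given ResNet-18 stage."""
    return input_size // STAGE_STRIDES[layer_name]


for name in ("layer2", "layer3", "layer4"):
    side = cam_resolution(224, name)
    print(f"{name}: {side}x{side}")
# layer2: 28x28
# layer3: 14x14
# layer4: 7x7
```

This is why the maps have to be brought to a common size before they can be fused.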

How you combine them is up to you. Most approaches use an element-wise maximum by default, but LayerCAM has its own fusion method:

# Let's fuse them
fused_cam = cam_extractor.fuse_cams(cams)
# Plot the raw version
plt.imshow(fused_cam.numpy())
plt.axis('off')
plt.title(" + ".join(cam_extractor.target_names))
plt.show()

[Figure: fused CAM of layer2 + layer3 + layer4]

# Overlay it on the image
result = overlay_mask(to_pil_image(img), to_pil_image(fused_cam, mode='F'), alpha=0.5)
# Plot the result
plt.imshow(result)
plt.axis('off')
plt.title(" + ".join(cam_extractor.target_names))
plt.show()

[Figure: fused CAM overlaid on the input image]
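The element-wise maximum that most approaches use by default can be sketched in pure Python: every map is first resized to the largest resolution, then each pixel keeps its highest value across layers. Nearest-neighbour resizing and the `upsample_nearest` / `fuse_max` helpers are assumptions made to keep the sketch dependency-free; this is neither TorchCAM's own resizing nor LayerCAM's `fuse_cams` method.

```python
from typing import List

Map = List[List[float]]  # square side x side map


def upsample_nearest(cam: Map, size: int) -> Map:
    """Nearest-neighbour resize of a square map to size x size."""
    src = len(cam)
    return [[cam[i * src // size][j * src // size] for j in range(size)] for i in range(size)]


def fuse_max(cams: List[Map]) -> Map:
    """Resize all maps to the largest resolution, then take the element-wise maximum."""
    size = max(len(cam) for cam in cams)
    resized = [upsample_nearest(cam, size) for cam in cams]
    return [[max(r[i][j] for r in resized) for j in range(size)] for i in range(size)]


fine = [[0.2] * 4 for _ in range(4)]  # e.g. a 4x4 map from a shallower layer
coarse = [[0.9, 0.0], [0.0, 0.5]]     # e.g. a 2x2 map from a deeper layer
for row in fuse_max([fine, coarse]):
    print(row)
# [0.9, 0.9, 0.2, 0.2]
# [0.9, 0.9, 0.2, 0.2]
# [0.2, 0.2, 0.5, 0.5]
# [0.2, 0.2, 0.5, 0.5]
```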

Support of torch.nn.Module as target_layer

As part of making the API more robust, CAM constructors now also accept a torch.nn.Module as target_layer. Previously, you had to pass the name of the layer as a string; now you can pass the object reference directly if you prefer:

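from torchvision.models import resnet18
from torchcam.cams import LayerCAM

model = resnet18(pretrained=True).eval()
# Hook the model, passing the layer object instead of its name ("layer4")
cam_extractor = LayerCAM(model, model.layer4)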

⚡ Latency benchmark ⚡

Since CAMs can be used in localization or production pipelines, latency matters along with the visual quality of the output. For this reason, this release includes a latency evaluation script along with a full benchmark table.

To get latency metrics on your own hardware, you can run the script yourself:

python scripts/eval_latency.py SmoothGradCAMpp --size 224
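If you would rather time a single pipeline from Python, the sketch below shows the usual measurement pattern: a few warm-up calls, then repeated timed runs summarised by mean, median and standard deviation. The `measure_latency` helper is hypothetical and is not the content of `scripts/eval_latency.py`. When timing PyTorch code on a GPU, call `torch.cuda.synchronize()` inside the timed callable, since CUDA kernels run asynchronously.

```python
import statistics
import time
from typing import Callable, Dict


def measure_latency(fn: Callable[[], object], warmup: int = 10, runs: int = 100) -> Dict[str, float]:
    """Time repeated calls of `fn` and report latency statistics in milliseconds."""
    # Warm-up calls are excluded from the statistics (caches, lazy initialization, ...)
    for _ in range(warmup):
        fn()
    timings_ms = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        timings_ms.append((time.perf_counter() - start) * 1000)
    return {
        "mean_ms": statistics.mean(timings_ms),
        "median_ms": statistics.median(timings_ms),
        "stdev_ms": statistics.stdev(timings_ms) if runs > 1 else 0.0,
    }


print(measure_latency(lambda: sum(range(10_000)), warmup=5, runs=50))
```

With TorchCAM, the callable would typically wrap both the forward pass and the CAM extraction, e.g. `lambda: cam_extractor(class_idx, model(input_tensor.unsqueeze(0)))`, where `class_idx` is a placeholder for your target class.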

Notebooks ⏯️

Do you prefer to only run code rather than write it? Perhaps you only want to tweak a few things?
Then enjoy the brand new Jupyter notebooks that you can run either locally or on Google Colab!

🤗 Live demo 🤗

The ML community was recently blessed by Hugging Face with the beta of Spaces, which lets you host your ML demos free of charge!

Previously, you could run the demo locally or deploy it yourself, but now you can also enjoy the live demo of TorchCAM 🎨

Breaking changes

Multiple CAM output

Since CAM extractors can now compute maps for multiple layers at a time, the return type of all CAM methods has changed from torch.Tensor to List[torch.Tensor] with N elements, where N is the number of target layers.

0.2.0:
>>> from torchcam.cams import SmoothGradCAMpp
>>> cam_extractor = SmoothGradCAMpp(model)
>>> out = model(input_tensor.unsqueeze(0))
>>> print(type(cam_extractor(out.squeeze(0).argmax().item(), out)))
<class 'torch.Tensor'>

0.3.0:
>>> from torchcam.cams import SmoothGradCAMpp
>>> cam_extractor = SmoothGradCAMpp(model)
>>> out = model(input_tensor.unsqueeze(0))
>>> print(type(cam_extractor(out.squeeze(0).argmax().item(), out)))
<class 'list'>
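When a single target layer is hooked, migrating 0.2.0 code usually just means taking the first element of the returned list. The hypothetical `single_cam` helper below (not part of TorchCAM) accepts either return type and fails loudly if several layers were hooked, so existing call sites keep working during a transition.

```python
from typing import Sequence, TypeVar, Union

T = TypeVar("T")


def single_cam(result: Union[T, Sequence[T]]) -> T:
    """Return the only CAM of an extractor call, whether it uses the 0.2.0 or the 0.3.0 return type."""
    if isinstance(result, (list, tuple)):
        # 0.3.0: one CAM per target layer
        if len(result) != 1:
            raise ValueError(f"expected exactly one target layer, got {len(result)} CAMs")
        return result[0]
    # 0.2.0: a single tensor was returned directly
    return result
```

Usage would then look like `activation_map = single_cam(cam_extractor(out.squeeze(0).argmax().item(), out))`.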

New features

CAMs

Implementations of CAM methods

  • Added support of conv1x1 as FC candidate in base CAM #69 (@frgfm)
  • Added support of LayerCAM #77 (@frgfm)
  • Added support of torch.nn.Module as target_layer or fc_layer #83 (@frgfm)
  • Added support of multiple target layers for all CAM methods #89 #92 (@frgfm)
  • Added layer-specific CAM fusion method #93 (@frgfm)

Scripts

Side scripts to make the most out of TorchCAM

  • Added latency evaluation script #95 (@frgfm)

Test

Verification of the package's well-being before release

  • Added unittests to verify that conv1x1 can be used as FC in base CAM #69 (@frgfm)
  • Added unittest for LayerCAM #77 (@frgfm)
  • Added unittest for gradient-based CAM methods for models with in-place ops #80 (@frgfm)
  • Added unittest to check support of torch.nn.Module as target_layer in CAM constructor #83 #88 (@frgfm)
  • Added unittest for CAM fusion #93 (@frgfm)

Documentation

Online resources for potential users

  • Added LayerCAM ref in the README and in the documentation #77 (@frgfm)
  • Added CODE_OF_CONDUCT #86 (@frgfm)
  • Added changelog to the documentation #91 (@frgfm)
  • Added latency benchmark & GIF illustration of CAM on a video in README #95 (@frgfm)
  • Added documentation of .fuse_cams method #93 (@frgfm)
  • Added ref to HF Space demo in README and documentation #96 (@frgfm)
  • Added tutorial notebooks and reference page in the documentation #99 #100 #101 #102 (@frgfm)

Others

Other tools and implementations

  • Added class_idx & target_layer selection in the demo #67 (@frgfm)
  • Added CI jobs to build on different OS & Python versions, and to validate the demo and the example script #73 #74 (@frgfm)
  • Added LayerCAM to the demo #77 (@frgfm)
  • Added an environment collection script #78 (@frgfm)
  • Added CI check for the latency evaluation script #95 (@frgfm)

Bug fixes

CAMs

  • Fixed backward hook mechanism for in-place operations #80 (@frgfm)

Documentation

  • Fixed docutils version constraint for documentation building #98 (@frgfm)


Improvements

CAMs

  • Improved weight broadcasting for all CAMs #77 (@frgfm)
  • Refactored hook enabling #80 (@frgfm)
  • Improved the warning message for automatic target resolution #87 #92 (@frgfm)
  • Improved arg type checking for CAM constructor #88 (@frgfm)

Scripts

  • Improved the layout option of the example script #66 (@frgfm)
  • Refactored example script #80 #94 (@frgfm)
  • Updated all scripts for support of multiple target layers #89 (@frgfm)

Test

  • Updated unittests for multiple target layer support #89 (@frgfm)

Documentation

  • Added latest release doc version & updated README badge #63 (@frgfm)
  • Added demo screenshot in the README #67 (@frgfm)
  • Updated instructions in README #89 (@frgfm)
  • Improved documentation landing page #91 (@frgfm)
  • Updated contribution guidelines #94 (@frgfm)
  • Updated documentation requirements #99 (@frgfm)

Others

  • Updated package version and fixed CI jobs to validate release publish #63 (@frgfm)
  • Updated license from MIT to Apache 2.0 #70 (@frgfm)
  • Refactored CI jobs #73 (@frgfm)
  • Improved bug report template #78 (@frgfm)
  • Updated streamlit syntax in demo #94 (@frgfm)
  • Added isort config and CI job #97 (@frgfm)
  • Added CI job for sanity check of the documentation build #98 (@frgfm)