Skip to content

Commit

Permalink
feat: Added implementation of SS-CAM (#11)
Browse files Browse the repository at this point in the history
* feat: Added implementation of Smoothed Score-CAM

* test: Updated unittests

* docs: Updated documentation

* feat: Updated example script

* docs: Updated readme

* refactor: Refactored SS-CAM
  • Loading branch information
frgfm authored Aug 3, 2020
1 parent fb3be81 commit a95d680
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 18 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ This project is developed and maintained by the repo owner, but the implementati
- [Grad-CAM++](https://arxiv.org/abs/1710.11063): improvement of GradCAM for more accurate pixel-level contribution to the activation.
- [Smooth Grad-CAM++](https://arxiv.org/abs/1908.01224): SmoothGrad mechanism coupled with GradCAM.
- [Score-CAM](https://arxiv.org/abs/1910.01279): score-weighting of class activation for better interpretability.
- [SS-CAM](https://arxiv.org/abs/2006.14255): SmoothGrad mechanism coupled with Score-CAM.



Expand Down
2 changes: 2 additions & 0 deletions docs/source/cams.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Methods related to activation-based class activation maps.

.. autoclass:: ScoreCAM

.. autoclass:: SSCAM


Grad-CAM
--------
Expand Down
10 changes: 5 additions & 5 deletions scripts/cam_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torchvision import models
from torchvision.transforms.functional import normalize, resize, to_tensor, to_pil_image

from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM
from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM, SSCAM
from torchcam.utils import overlay_mask

VGG_CONFIG = {_vgg: dict(input_layer='features', conv_layer='features')
Expand Down Expand Up @@ -58,9 +58,9 @@ def main(args):
# Hook the corresponding layer in the model
cam_extractors = [CAM(model, conv_layer, fc_layer), GradCAM(model, conv_layer),
GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer, input_layer),
ScoreCAM(model, conv_layer, input_layer)]
ScoreCAM(model, conv_layer, input_layer), SSCAM(model, conv_layer, input_layer)]

fig, axes = plt.subplots(1, len(cam_extractors))
fig, axes = plt.subplots(1, len(cam_extractors), figsize=(7, 2))
for idx, extractor in enumerate(cam_extractors):
model.zero_grad()
scores = model(img_tensor.unsqueeze(0))
Expand All @@ -80,7 +80,7 @@ def main(args):

axes[idx].imshow(result)
axes[idx].axis('off')
axes[idx].set_title(extractor.__class__.__name__, size=10)
axes[idx].set_title(extractor.__class__.__name__, size=8)

plt.tight_layout()
if args.savefig:
Expand All @@ -95,7 +95,7 @@ def main(args):
parser.add_argument("--img", type=str,
default='https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg',
help="The image to extract CAM from")
parser.add_argument("--class-idx", type=int, default=None, help='Index of the class to inspect')
parser.add_argument("--class-idx", type=int, default=232, help='Index of the class to inspect')
parser.add_argument("--device", type=str, default=None, help='Default device to perform computation on')
parser.add_argument("--savefig", type=str, default=None, help="Path to save figure")
args = parser.parse_args()
Expand Down
Binary file modified static/images/cam_example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion test/test_cams.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_smooth_gradcampp(self):
self._test_extractor(extractor, model)


for cam_extractor in ['CAM', 'ScoreCAM']:
for cam_extractor in ['CAM', 'ScoreCAM', 'SSCAM']:
def do_test(self, cam_extractor=cam_extractor):
self._test_cam(cam_extractor)

Expand Down
124 changes: 112 additions & 12 deletions torchcam/cams/cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.nn.functional as F

__all__ = ['CAM', 'ScoreCAM']
__all__ = ['CAM', 'ScoreCAM', 'SSCAM']


class _CAM(object):
Expand Down Expand Up @@ -128,7 +128,8 @@ class CAM(_CAM):
where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at
position :math:`(x, y)`,
and :math:`w_k^{(c)}` is the weight corresponding to class :math:`c` for unit :math:`k`.
and :math:`w_k^{(c)}` is the weight corresponding to class :math:`c` for unit :math:`k` in the fully
connected layer.
Example::
>>> from torchvision.models import resnet18
Expand Down Expand Up @@ -172,18 +173,18 @@ class ScoreCAM(_CAM):
with the coefficient :math:`w_k^{(c)}` being defined as:
.. math::
w_k^{(c)} = softmax(Y^{(c)}(M) - Y^{(c)}(X_b))
w_k^{(c)} = softmax(Y^{(c)}(M_k) - Y^{(c)}(X_b))
where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at
position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax
for input :math:`X`, :math:`X_b` is a baseline image,
and :math:`M` is defined as follows:
and :math:`M_k` is defined as follows:
.. math::
M = \\Big(\\frac{M^{(d)} - \\min M^{(d)}}{\\max M^{(d)} - \\min M^{(d)}} \\odot X \\Big)_{1 \\leq d \\leq D}
M_k = \\frac{U(A_k) - \\min\\limits_m U(A_m)}{\\max\\limits_m U(A_m) - \\min\\limits_m U(A_m)}
\\odot X
where :math:`\\odot` refers to the element-wise multiplication, :math:`M^{(d)}` is the upsampled version of
:math:`A_d` on node :math:`d`, and :math:`D` is the number of channels on the target convolutional layer.
where :math:`\\odot` refers to the element-wise multiplication and :math:`U` is the upsampling operation.
Example::
>>> from torchvision.models import resnet18
Expand Down Expand Up @@ -222,12 +223,12 @@ def _store_input(self, module, input):
def _get_weights(self, class_idx, scores=None):
"""Computes the weight coefficients of the hooked activation maps"""

# Upsample activation to input_size
# 1 * O * M * N
upsampled_a = F.interpolate(self.hook_a, self._input.shape[-2:], mode='bilinear', align_corners=False)
# Normalize the activation
upsampled_a = self._normalize(self.hook_a)

# Normalize it
upsampled_a = self._normalize(upsampled_a)
# Upsample it to input_size
# 1 * O * M * N
upsampled_a = F.interpolate(upsampled_a, self._input.shape[-2:], mode='bilinear', align_corners=False)

# Use it as a mask
# O * I * H * W
Expand All @@ -253,3 +254,102 @@ def _get_weights(self, class_idx, scores=None):

def __repr__(self):
return f"{self.__class__.__name__}(batch_size={self.bs})"


class SSCAM(ScoreCAM):
    """Implements a class activation map extractor as described in `"SS-CAM: Smoothed Score-CAM for
    Sharper Visual Feature Localization" <https://arxiv.org/pdf/2006.14255.pdf>`_.

    The localization map is computed as follows:

    .. math::
        L^{(c)}_{SS-CAM}(x, y) = ReLU\\Big(\\sum\\limits_k w_k^{(c)} A_k(x, y)\\Big)

    with the coefficient :math:`w_k^{(c)}` being defined as:

    .. math::
        w_k^{(c)} = \\frac{1}{N} \\sum\\limits_{i=1}^N softmax(Y^{(c)}(M_k) - Y^{(c)}(X_b))

    where :math:`N` is the number of samples used to smooth the weights,
    :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at
    position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax
    for input :math:`X`, :math:`X_b` is a baseline image,
    and :math:`M_k` is defined as follows:

    .. math::
        M_k = \\Bigg(\\frac{U(A_k) - \\min\\limits_m U(A_m)}{\\max\\limits_m U(A_m) - \\min\\limits_m U(A_m)} +
        \\delta\\Bigg) \\odot X

    where :math:`\\odot` refers to the element-wise multiplication, :math:`U` is the upsampling operation,
    :math:`\\delta \\sim \\mathcal{N}(0, \\sigma^2)` is the random noise that follows a 0-mean gaussian distribution
    with a standard deviation of :math:`\\sigma`.

    Example::
        >>> from torchvision.models import resnet18
        >>> from torchcam.cams import SSCAM
        >>> model = resnet18(pretrained=True).eval()
        >>> cam = SSCAM(model, 'layer4', 'conv1')
        >>> with torch.no_grad(): out = model(input_tensor)
        >>> cam(class_idx=100)

    Args:
        model (torch.nn.Module): input model
        conv_layer (str): name of the last convolutional layer
        input_layer (str): name of the first layer
        batch_size (int, optional): batch size used to forward masked inputs
        num_samples (int, optional): number of noisy samples used for weight computation
        std (float, optional): standard deviation of the noise added to the normalized activation

    Raises:
        ValueError: if `num_samples` is lower than 1
    """

    # NOTE(review): class-level attributes redeclared here exactly as in the original block;
    # how the parent classes populate them is not visible from this block.
    hook_a = None
    hook_handles = []

    def __init__(self, model, conv_layer, input_layer, batch_size=32, num_samples=35, std=2.0):

        # The weights are averaged over `num_samples`: anything below 1 would silently yield NaN
        if num_samples < 1:
            raise ValueError(f"num_samples is expected to be a positive integer, got {num_samples}")

        super().__init__(model, conv_layer, input_layer, batch_size)

        self.num_samples = num_samples
        self.std = std
        # Zero-mean gaussian used to perturb the normalized activation masks
        self._distrib = torch.distributions.normal.Normal(0, self.std)

    def _get_weights(self, class_idx, scores=None):
        """Computes the weight coefficients of the hooked activation maps

        Each weight is the softmax probability of `class_idx` for the input masked by the corresponding
        noisy activation map, averaged over `num_samples` noise draws.

        Args:
            class_idx (int): index of the class whose probability is used as weight
            scores (optional): not used by this implementation

        Returns:
            torch.Tensor: one weight per channel of the hooked activation
        """

        # Normalize the activation
        upsampled_a = self._normalize(self.hook_a)

        # Upsample it to input_size
        # 1 * O * M * N
        upsampled_a = F.interpolate(upsampled_a, self._input.shape[-2:], mode='bilinear', align_corners=False)

        # Use it as a mask
        # O * I * H * W
        upsampled_a = upsampled_a.squeeze(0).unsqueeze(1)
        num_maps = upsampled_a.shape[0]

        # Initialize weights directly on the target device
        weights = torch.zeros(num_maps, dtype=upsampled_a.dtype, device=upsampled_a.device)

        # Disable hook updates while the masked inputs are forwarded
        self._hooks_enabled = False
        try:
            # No gradient is needed for any of the masked forward passes
            with torch.no_grad():
                for _ in range(self.num_samples):
                    # A single noise draw of the input size is broadcast over all activation masks
                    noise = self._distrib.sample(self._input.size()).to(device=self._input.device)
                    noisy_m = self._input * (upsampled_a + noise)

                    # Process by chunk (GPU RAM limitation)
                    for start in range(0, num_maps, self.bs):
                        selection_slice = slice(start, min(start + self.bs, num_maps))
                        # Get the softmax probabilities of the target class
                        probs = F.softmax(self.model(noisy_m[selection_slice]), dim=1)
                        weights[selection_slice] += probs[:, class_idx]
        finally:
            # Reenable hook updates even if a forward pass failed
            self._hooks_enabled = True

        weights /= self.num_samples

        return weights

    def __repr__(self):
        return f"{self.__class__.__name__}(batch_size={self.bs}, num_samples={self.num_samples}, std={self.std})"

0 comments on commit a95d680

Please sign in to comment.