diff --git a/README.md b/README.md index f0b77f51..aefa1866 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ This project is developed and maintained by the repo owner, but the implementati - [Grad-CAM++](https://arxiv.org/abs/1710.11063): improvement of GradCAM++ for more accurate pixel-level contribution to the activation. - [Smooth Grad-CAM++](https://arxiv.org/abs/1908.01224): SmoothGrad mechanism coupled with GradCAM. - [Score-CAM](https://arxiv.org/abs/1910.01279): score-weighting of class activation for better interpretability. +- [SS-CAM](https://arxiv.org/abs/2006.14255): SmoothGrad mechanism coupled with Score-CAM. diff --git a/docs/source/cams.rst b/docs/source/cams.rst index 6b7a569d..c2fd4b88 100644 --- a/docs/source/cams.rst +++ b/docs/source/cams.rst @@ -14,6 +14,8 @@ Methods related to activation-based class activation maps. .. autoclass:: ScoreCAM +.. autoclass:: SSCAM + Grad-CAM -------- diff --git a/scripts/cam_example.py b/scripts/cam_example.py index 25aebded..e15d34b1 100644 --- a/scripts/cam_example.py +++ b/scripts/cam_example.py @@ -15,7 +15,7 @@ from torchvision import models from torchvision.transforms.functional import normalize, resize, to_tensor, to_pil_image -from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM +from torchcam.cams import CAM, GradCAM, GradCAMpp, SmoothGradCAMpp, ScoreCAM, SSCAM from torchcam.utils import overlay_mask VGG_CONFIG = {_vgg: dict(input_layer='features', conv_layer='features') @@ -58,9 +58,9 @@ def main(args): # Hook the corresponding layer in the model cam_extractors = [CAM(model, conv_layer, fc_layer), GradCAM(model, conv_layer), GradCAMpp(model, conv_layer), SmoothGradCAMpp(model, conv_layer, input_layer), - ScoreCAM(model, conv_layer, input_layer)] + ScoreCAM(model, conv_layer, input_layer), SSCAM(model, conv_layer, input_layer)] - fig, axes = plt.subplots(1, len(cam_extractors)) + fig, axes = plt.subplots(1, len(cam_extractors), figsize=(7, 2)) for idx, extractor in enumerate(cam_extractors): model.zero_grad() scores = model(img_tensor.unsqueeze(0)) @@ -80,7 +80,7 @@ def main(args): axes[idx].imshow(result) axes[idx].axis('off') - axes[idx].set_title(extractor.__class__.__name__, size=10) + axes[idx].set_title(extractor.__class__.__name__, size=8) plt.tight_layout() if args.savefig: @@ -95,7 +95,7 @@ def main(args): parser.add_argument("--img", type=str, default='https://www.woopets.fr/assets/races/000/066/big-portrait/border-collie.jpg', help="The image to extract CAM from") - parser.add_argument("--class-idx", type=int, default=None, help='Index of the class to inspect') + parser.add_argument("--class-idx", type=int, default=232, help='Index of the class to inspect') parser.add_argument("--device", type=str, default=None, help='Default device to perform computation on') parser.add_argument("--savefig", type=str, default=None, help="Path to save figure") args = parser.parse_args() diff --git a/static/images/cam_example.png b/static/images/cam_example.png index f501c6ad..d5072ea1 100644 Binary files a/static/images/cam_example.png and b/static/images/cam_example.png differ diff --git a/test/test_cams.py b/test/test_cams.py index a47c8a14..577302f8 100644 --- a/test/test_cams.py +++ b/test/test_cams.py @@ -101,7 +101,7 @@ def test_smooth_gradcampp(self): self._test_extractor(extractor, model) -for cam_extractor in ['CAM', 'ScoreCAM']: +for cam_extractor in ['CAM', 'ScoreCAM', 'SSCAM']: def do_test(self, cam_extractor=cam_extractor): self._test_cam(cam_extractor) diff --git a/torchcam/cams/cam.py b/torchcam/cams/cam.py index 682297b3..fa13d696 100644 --- a/torchcam/cams/cam.py +++ b/torchcam/cams/cam.py @@ -9,7 +9,7 @@ import torch import torch.nn.functional as F -__all__ = ['CAM', 'ScoreCAM'] +__all__ = ['CAM', 'ScoreCAM', 'SSCAM'] class _CAM(object): @@ -128,7 +128,8 @@ class CAM(_CAM): where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at position :math:`(x, y)`, - and :math:`w_k^{(c)}` is the weight corresponding to class :math:`c` for unit :math:`k`. + and :math:`w_k^{(c)}` is the weight corresponding to class :math:`c` for unit :math:`k` in the fully + connected layer.. Example:: >>> from torchvision.models import resnet18 @@ -172,18 +173,18 @@ class ScoreCAM(_CAM): with the coefficient :math:`w_k^{(c)}` being defined as: .. math:: - w_k^{(c)} = softmax(Y^{(c)}(M) - Y^{(c)}(X_b)) + w_k^{(c)} = softmax(Y^{(c)}(M_k) - Y^{(c)}(X_b)) where :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax for input :math:`X`, :math:`X_b` is a baseline image, - and :math:`M` is defined as follows: + and :math:`M_k` is defined as follows: .. math:: - M = \\Big(\\frac{M^{(d)} - \\min M^{(d)}}{\\max M^{(d)} - \\min M^{(d)}} \\odot X \\Big)_{1 \\leq d \\leq D} + M_k = \\frac{U(A_k) - \\min\\limits_m U(A_m)}{\\max\\limits_m U(A_m) - \\min\\limits_m U(A_m)}) + \\odot X - where :math:`\\odot` refers to the element-wise multiplication, :math:`M^{(d)}` is the upsampled version of - :math:`A_d` on node :math:`d`, and :math:`D` is the number of channels on the target convolutional layer. + where :math:`\\odot` refers to the element-wise multiplication and :math:`U` is the upsampling operation. Example:: >>> from torchvision.models import resnet18 @@ -222,12 +223,12 @@ def _store_input(self, module, input): def _get_weights(self, class_idx, scores=None): """Computes the weight coefficients of the hooked activation maps""" - # Upsample activation to input_size - # 1 * O * M * N - upsampled_a = F.interpolate(self.hook_a, self._input.shape[-2:], mode='bilinear', align_corners=False) + # Normalize the activation + upsampled_a = self._normalize(self.hook_a) - # Normalize it - upsampled_a = self._normalize(upsampled_a) + # Upsample it to input_size + # 1 * O * M * N + upsampled_a = F.interpolate(upsampled_a, self._input.shape[-2:], mode='bilinear', align_corners=False) # Use it as a mask # O * I * H * W @@ -253,3 +254,102 @@ def _get_weights(self, class_idx, scores=None): def __repr__(self): return f"{self.__class__.__name__}(batch_size={self.bs})" + + +class SSCAM(ScoreCAM): + """Implements a class activation map extractor as described in `"SS-CAM: Smoothed Score-CAM for + Sharper Visual Feature Localization" `_. + + The localization map is computed as follows: + + .. math:: + L^{(c)}_{SS-CAM}(x, y) = ReLU\\Big(\\sum\\limits_k w_k^{(c)} A_k(x, y)\\Big) + + with the coefficient :math:`w_k^{(c)}` being defined as: + + .. math:: + w_k^{(c)} = \\frac{1}{N} \\sum\\limits_1^N softmax(Y^{(c)}(M_k) - Y^{(c)}(X_b)) + + where :math:`N` is the number of samples used to smooth the weights, + :math:`A_k(x, y)` is the activation of node :math:`k` in the last convolutional layer of the model at + position :math:`(x, y)`, :math:`Y^{(c)}(X)` is the model output score for class :math:`c` before softmax + for input :math:`X`, :math:`X_b` is a baseline image, + and :math:`M_k` is defined as follows: + + .. math:: + M_k = \\Bigg(\\frac{U(A_k) - \\min\\limits_m U(A_m)}{\\max\\limits_m U(A_m) - \\min\\limits_m U(A_m)} + + \\delta\\Bigg) \\odot X + + where :math:`\\odot` refers to the element-wise multiplication, :math:`U` is the upsampling operation, + :math:`\\delta \\sim \\mathcal{N}(0, \\sigma^2)` is the random noise that follows a 0-mean gaussian distribution + with a standard deviation of :math:`\\sigma`. + + Example:: + >>> from torchvision.models import resnet18 + >>> from torchcam.cams import SSCAM + >>> model = resnet18(pretrained=True).eval() + >>> cam = SSCAM(model, 'layer4', 'conv1') + >>> with torch.no_grad(): out = model(input_tensor) + >>> cam(class_idx=100) + + Args: + model (torch.nn.Module): input model + conv_layer (str): name of the last convolutional layer + input_layer (str): name of the first layer + batch_size (int, optional): batch size used to forward masked inputs + num_samples (int, optional): number of noisy samples used for weight computation + std (float, optional): standard deviation of the noise added to the normalized activation + """ + + hook_a = None + hook_handles = [] + + def __init__(self, model, conv_layer, input_layer, batch_size=32, num_samples=35, std=2.0): + + super().__init__(model, conv_layer, input_layer, batch_size) + + self.num_samples = num_samples + self.std = std + self._distrib = torch.distributions.normal.Normal(0, self.std) + + def _get_weights(self, class_idx, scores=None): + """Computes the weight coefficients of the hooked activation maps""" + + # Normalize the activation + upsampled_a = self._normalize(self.hook_a) + + # Upsample it to input_size + # 1 * O * M * N + upsampled_a = F.interpolate(upsampled_a, self._input.shape[-2:], mode='bilinear', align_corners=False) + + # Use it as a mask + # O * I * H * W + upsampled_a = upsampled_a.squeeze(0).unsqueeze(1) + + # Initialize weights + weights = torch.zeros(upsampled_a.shape[0], dtype=upsampled_a.dtype).to(device=upsampled_a.device) + + # Disable hook updates + self._hooks_enabled = False + + for _idx in range(self.num_samples): + noisy_m = self._input * (upsampled_a + + self._distrib.sample(self._input.size()).to(device=self._input.device)) + + # Process by chunk (GPU RAM limitation) + for idx in range(math.ceil(weights.shape[0] / self.bs)): + + selection_slice = slice(idx * self.bs, min((idx + 1) * self.bs, weights.shape[0])) + with torch.no_grad(): + # Get the softmax probabilities of the target class + weights[selection_slice] += F.softmax(self.model(noisy_m[selection_slice]), dim=1)[:, class_idx] + + weights /= self.num_samples + + # Reenable hook updates + self._hooks_enabled = True + + return weights + + def __repr__(self): + return f"{self.__class__.__name__}(batch_size={self.bs}, num_samples={self.num_samples}, std={self.std})"