Split-Attention Module in PyTorch #66

Open
zhanghang1989 opened this issue May 24, 2020 · 5 comments

@zhanghang1989
Owner

In case someone wants to use the Split-Attention module on its own, it is provided here:

import torch
import torch.nn as nn
from torch.nn import functional as F

class rSoftMax(nn.Module):
    """Radix-wise softmax over the radix splits (sigmoid gate when radix == 1)."""
    def __init__(self, radix, cardinality):
        super().__init__()
        assert radix > 0
        self.radix = radix
        self.cardinality = cardinality

    def forward(self, x):
        batch = x.size(0)
        if self.radix > 1:
            # (batch, cardinality, radix, -1) -> softmax across the radix dimension
            x = x.view(batch, self.cardinality, self.radix, -1).transpose(1, 2)
            x = F.softmax(x, dim=1)
            x = x.reshape(batch, -1)
        else:
            x = torch.sigmoid(x)
        return x

class Splat(nn.Module):
    def __init__(self, channels, radix, cardinality, reduction_factor=4):
        super(Splat, self).__init__()
        self.radix = radix
        self.cardinality = cardinality
        self.channels = channels  # total input channels; each split has channels // radix
        inter_channels = max(channels*radix//reduction_factor, 32)
        # squeeze: per-split channel count -> bottleneck
        self.fc1 = nn.Conv2d(channels//radix, inter_channels, 1, groups=cardinality)
        self.bn1 = nn.BatchNorm2d(inter_channels)
        self.relu = nn.ReLU(inplace=True)
        # excite: bottleneck -> one attention logit per input channel,
        # i.e. radix * (channels // radix) = channels outputs
        self.fc2 = nn.Conv2d(inter_channels, channels, 1, groups=cardinality)
        self.rsoftmax = rSoftMax(radix, cardinality)

    def forward(self, x):
        batch, rchannel = x.shape[:2]
        if self.radix > 1:
            # split the feature map into radix groups and fuse them by summation
            splited = torch.split(x, rchannel//self.radix, dim=1)
            gap = sum(splited)
        else:
            gap = x
        # global average pooling -> channel-wise statistics
        gap = F.adaptive_avg_pool2d(gap, 1)
        gap = self.fc1(gap)
        gap = self.bn1(gap)
        gap = self.relu(gap)

        # attention logits, normalized across the radix splits
        atten = self.fc2(gap)
        atten = self.rsoftmax(atten).view(batch, -1, 1, 1)

        if self.radix > 1:
            # weight each split by its attention map and sum the splits
            attens = torch.split(atten, rchannel//self.radix, dim=1)
            out = sum([att*split for (att, split) in zip(attens, splited)])
        else:
            out = atten * x
        return out.contiguous()
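
A minimal usage sketch for a quick shape check (the sizes below are only an example; with radix > 1 the output has channels // radix channels):

import torch

# toy check: 64 input channels, radix=2, cardinality=1
splat = Splat(channels=64, radix=2, cardinality=1)
x = torch.randn(4, 64, 32, 32)   # (batch, channels, H, W)
out = splat(x)
print(out.shape)                 # torch.Size([4, 32, 32, 32]), i.e. channels // radix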
zhanghang1989 pinned this issue May 24, 2020
@abhigoku10

@zhanghang1989 thanks for sharing the split-attention module implementation. Can we integrate this with an FPN module instead of the ResNet backbone?

@zhanghang1989
Owner Author

@zhanghang1989 thanks for sharing the split-attention module implementation. Can we integrate this with an FPN module instead of the ResNet backbone?

Yes, that should work.
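
For illustration only, one way to wire it into an FPN-style lateral branch (a made-up sketch using the Splat module above, not an official example): expand with a 1x1 conv to out_channels * radix channels, then let Splat reduce them back to out_channels.

import torch
import torch.nn as nn

class SplatLateral(nn.Module):
    # hypothetical FPN-style lateral branch with split-attention
    def __init__(self, in_channels, out_channels, radix=2, cardinality=1):
        super().__init__()
        # 1x1 lateral conv expands to out_channels * radix channels ...
        self.lateral = nn.Conv2d(in_channels, out_channels * radix, 1)
        # ... and Splat reduces them back to out_channels
        self.splat = Splat(out_channels * radix, radix, cardinality)

    def forward(self, x):
        return self.splat(self.lateral(x))

# e.g. a C4 feature map with 1024 channels mapped to a 256-channel pyramid level
layer = SplatLateral(1024, 256)
p4 = layer(torch.randn(2, 1024, 16, 16))   # -> (2, 256, 16, 16)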

@abhigoku10

@zhanghang1989 are there any example references you can point to?

@zhanghang1989
Owner Author

@Jerryzcn, you have done some similar things, right?

@Hi-Jingzhi

Hi! I have two questions about the fc1 layer in the Splat class. First, what is the meaning of the channels parameter? Second, should the input channel count of fc1 be channels//radix, or channels?
