在看ACMix这篇文章的时候涉及到了分组卷积,由于网上的图只给了简单的情况,没有给出更复杂的情况,这里特地记录一下
有分组卷积,配置如下
import torch
import torch.nn as nn
dim = 128
nhead = 8
groups = dim // nhead
group_conv = nn.Conv2d(9*dim//nhead, dim, 3, padding=1, groups=groups, bias=False) # 不失一般性,这里不设置bias
x=torch.randn(1, 9*dim//nhead, 224, 224)
y=group_conv(x)
y.shape
输出y的shape为
torch.Size([1, 128, 224, 224])
下面我们只判断第一组
groups = dim // nhead
x_g0 = x[:, :9, ...]
kernel_g0 = group_conv.weight.data[:nhead]
y_g0=nn.functional.conv2d(x_g0, kernel_g0, padding=1)
(np.around(y[:, :nhead].detach().numpy(), 5) == np.around(y_g0.detach().numpy(), 5)).all() # 由于精度问题需要round一下
输出为
True
以此类推,可以判断第二组到最后一组的相等情况。