曹子晗

个人博客

欢迎来到我的个人博客


分组卷积详解

目录

在看ACMix这篇文章的时候涉及到了分组卷积,由于网上的图只给了简单的情况,没有给出更复杂的情况,这里特地记录一下

有分组卷积,配置如下

import torch
import torch.nn as nn

dim = 128
nhead = 8
groups = dim // nhead

group_conv = nn.Conv2d(9*dim//nhead, dim, 3, padding=1, groups=groups, bias=False)  # 不失一般性,这里不设置bias

x=torch.randn(1, 9*dim//nhead, 224, 224)
y=group_conv(x)
y.shape

输出y的shape为

torch.Size([1, 128, 224, 224])

下面我们只判断第一组

groups = dim // nhead
x_g0 = x[:, :9, ...]
kernel_g0 = group_conv.weight.data[:nhead]
y_g0=nn.functional.conv2d(x_g0, kernel_g0, padding=1)

(np.around(y[:, :nhead].detach().numpy(), 5) == np.around(y_g0.detach().numpy(), 5)).all()  # 由于精度问题需要round一下

输出为

True

以此类推,可以判断第二组到最后一组的相等情况。

图解

image-20220903133948876