DeepLabV3Plus Core Code Explained

  1. Atrous Spatial Pyramid Pooling
    1.1 ASPP Conv
    1.2 ASPP Pooling
    1.3 ASPP
  2. Atrous ResNet
  3. IntermediateLayerGetter
  4. DeepLabV3Plus
  5. AtrousSeparableConvolution

These are my notes on the DeepLabV3+ paper (Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation), walking through its core code.

1. Atrous Spatial Pyramid Pooling

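ASPP (Atrous Spatial Pyramid Pooling) can loosely be thought of as a souped-up pooling layer. Like ordinary pooling, its job is to aggregate features, but it probes the same feature map at several scales at once, using parallel atrous convolutions with different dilation rates plus an image-level pooling branch, so that multi-scale context is captured.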

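ASPP consists of five parallel branches: one 1×1 convolution, three 3×3 atrous convolutions (ASPP Conv), and one global pooling branch (ASPP Pooling). All five branches output feature maps with the same spatial size as the input, so their outputs are concatenated along the channel dimension and passed through another 1×1 convolution that reduces the channel count, producing the output of ASPP.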

1.1 ASPP Conv

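The difference between an atrous (dilated) convolution and a regular convolution lies in the dilation rate; in the code below the rate sets both the padding and the dilation of the convolution. ASPP applies several atrous convolutions with different dilation rates in parallel. Because the padding grows together with the dilation, each branch sees a receptive field of a different size without changing the output size, which lets ASPP extract multi-scale information. Note that the kernel size stays 3×3 throughout: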

```python
import torch
from torch import nn
from torch.nn import functional as F


class ASPPConv(nn.Sequential):
    """3x3 atrous conv + BN + ReLU; padding == dilation keeps the spatial size."""
    def __init__(self, in_channels, out_channels, dilation):
        modules = [
            nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        ]
        super(ASPPConv, self).__init__(*modules)
```
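Why does padding=dilation keep the size unchanged? The output-size rule documented for PyTorch's Conv2d is H_out = floor((H + 2p - d(k - 1) - 1) / s) + 1. With k = 3, s = 1 and p = d this simplifies to H_out = H for every d, while the region the kernel spans grows to k + (k - 1)(d - 1) = 2d + 1 pixels. The small pure-Python sketch below (the helper names are illustrative, not part of the original code) evaluates both quantities for the rates used in this post:

```python
def conv_out_size(size, kernel=3, stride=1, padding=0, dilation=1):
    # Output-size formula from the torch.nn.Conv2d documentation
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def effective_kernel(kernel=3, dilation=1):
    # Extent covered by a dilated kernel: k + (k - 1)(d - 1)
    return kernel + (kernel - 1) * (dilation - 1)


for rate in (1, 6, 12, 18):
    out = conv_out_size(32, kernel=3, padding=rate, dilation=rate)
    print(f"rate={rate:2d}  output size={out}  effective kernel={effective_kernel(3, rate)}")
```

Every rate prints an output size of 32, while the effective kernel grows from 3 to 37 pixels.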

1.2 ASPP Pooling

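ASPP Pooling starts with an AdaptiveAvgPool2d layer. Adaptive average pooling is "adaptive" in that you do not specify a kernel size or stride, only the output size (here 1×1). Compressing each channel's feature map to 1×1 (its spatial average) gives one summary value per channel, i.e. a global, image-level feature. A 1×1 convolution then processes these features further and reduces the number of channels. Note that the layers of ASPP Pooling only extract features; its forward method, besides running those layers in order, also upsamples the feature map from 1×1 back to the original size: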

```python
class ASPPPooling(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(ASPPPooling, self).__init__(
            nn.AdaptiveAvgPool2d(1),  # (n, c, h, w) -> (n, c, 1, 1)
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            # in training mode this BN needs batch size > 1: a 1x1 map has one value per channel per sample
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True))

    def forward(self, x):
        size = x.shape[-2:]  # remember the input size
        x = super(ASPPPooling, self).forward(x)  # spatial size becomes (1, 1)
        return F.interpolate(x, size=size, mode='bilinear', align_corners=False)  # upsample back to the input size
```
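Because the pooled map is 1×1, bilinear upsampling back to (h, w) simply copies each channel's single value to every position, so this branch contributes one constant map per channel. The pure-Python sketch below (pool_and_broadcast is a hypothetical helper, and it deliberately omits the 1×1 conv, BN and ReLU) shows that pool-then-broadcast behaviour on a tiny nested-list feature map:

```python
def pool_and_broadcast(feature):
    """feature: list of channels, each an h x w list of lists.

    Returns a map of the same shape in which every position of a channel holds
    that channel's global average, i.e. what AdaptiveAvgPool2d(1) followed by
    upsampling the 1x1 map back to (h, w) yields (without conv/BN/ReLU).
    """
    out = []
    for channel in feature:
        h, w = len(channel), len(channel[0])
        mean = sum(sum(row) for row in channel) / (h * w)
        out.append([[mean] * w for _ in range(h)])
    return out


feature = [
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # channel 0, mean 3.5
    [[0.0, 0.0, 0.0], [0.0, 0.0, 6.0]],  # channel 1, mean 1.0
]
pooled = pool_and_broadcast(feature)
print(pooled[0])  # [[3.5, 3.5, 3.5], [3.5, 3.5, 3.5]]
print(pooled[1])  # [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]
```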

1.3 ASPP

Combining the modules above gives the complete ASPP module:

```python
class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates):
        super(ASPP, self).__init__()
        out_channels = 256
        modules = []  # all five branches keep the spatial size of the input

        # (n, 2048, h, w) -> (n, 256, h, w)
        modules.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)))

        rate1, rate2, rate3 = tuple(atrous_rates)  # dilation rates
        modules.append(ASPPConv(in_channels, out_channels, rate1))  # (n, 2048, h, w) -> (n, 256, h, w)
        modules.append(ASPPConv(in_channels, out_channels, rate2))  # (n, 2048, h, w) -> (n, 256, h, w)
        modules.append(ASPPConv(in_channels, out_channels, rate3))  # (n, 2048, h, w) -> (n, 256, h, w)
        modules.append(ASPPPooling(in_channels, out_channels))  # (n, 2048, h, w) -> (n, 256, h, w)

        self.convs = nn.ModuleList(modules)

        # fuse the concatenated branches: 5 * 256 -> 256 channels
        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1))

    # Assume the ASPP input is the layer4 output of ResNet101, i.e. x.shape: (n, 2048, h, w)
    def forward(self, x):
        res = []
        for conv in self.convs:
            res.append(conv(x))
        res = torch.cat(res, dim=1)  # (n, 256*5, h, w)
        return self.project(res)  # (n, 256, h, w)


if __name__ == '__main__':
    aspp = ASPP(in_channels=2048, atrous_rates=[6, 12, 18])

    x = torch.randn(8, 2048, 18, 32)
    print(aspp(x).shape)  # torch.Size([8, 256, 18, 32])
```
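To get a feel for where the cost of ASPP lies, the sketch below counts the learnable parameters of the configuration above (in_channels=2048, out_channels=256, bias-free convolutions, each followed by a BatchNorm with one weight and one bias per channel). It is plain arithmetic written for this walkthrough, not a library utility; with the real module you could cross-check it via sum(p.numel() for p in aspp.parameters()).

```python
def conv_params(c_in, c_out, k):
    # Bias-free k x k convolution (every ASPP conv uses bias=False)
    return c_in * c_out * k * k


def bn_params(c):
    # BatchNorm2d: one learnable weight and one bias per channel
    return 2 * c


def aspp_params(c_in=2048, c_out=256, num_atrous=3):
    branch_1x1 = conv_params(c_in, c_out, 1) + bn_params(c_out)
    branch_atrous = num_atrous * (conv_params(c_in, c_out, 3) + bn_params(c_out))
    branch_pool = conv_params(c_in, c_out, 1) + bn_params(c_out)
    # Concatenating all branches gives (num_atrous + 2) * c_out channels
    concat_channels = (num_atrous + 2) * c_out
    project = conv_params(concat_channels, c_out, 1) + bn_params(c_out)
    return {
        "1x1": branch_1x1,
        "atrous (3 branches)": branch_atrous,
        "pooling": branch_pool,
        "project": project,
        "total": branch_1x1 + branch_atrous + branch_pool + project,
    }


for name, count in aspp_params().items():
    print(f"{name:>20}: {count:,}")
```

The three 3×3 atrous branches account for about 91% of the roughly 15.5M parameters, which is part of why the separable convolution in Section 5 pays off.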

2. Atrous ResNet

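This post uses ResNet50, 101 and 152 as examples of ResNets with atrous convolution. A ResNet has four main stages, which we call layer1, layer2, layer3 and layer4. By default the output feature map is 32 times smaller than the input image in height and width: the stem (a stride-2 7×7 convolution followed by stride-2 max pooling) downsamples by 4, and layer2 to layer4 each halve the size again (4 × 2 × 2 × 2 = 32). Atrous convolution can replace the striding in layers 2 to 4. If it is used only in the last stage, the output feature map is 16 times smaller than the input; using it in both layer3 and layer4 gives a factor of 8: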

```python
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url


model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution; padding=dilation keeps the size when stride=1."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Bottleneck(nn.Module):
    expansion = 4  # output channels = planes * 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        width = int(planes * (base_width / 64.)) * groups

        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)  # the only conv that strides or dilates
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:  # match the shortcut's shape to the main branch
            identity = self.downsample(x)

        out += identity

        return self.relu(out)


class ResNet(nn.Module):
    def __init__(self, block, layers, replace_stride_with_dilation=None,
                 num_classes=1000, groups=1, width_per_group=64, norm_layer=None):
        super(ResNet, self).__init__()

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group

        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]  # whether layer2/3/4 use atrous conv instead of stride 2
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(f"replace_stride_with_dilation should be None or a 3-element tuple, got {replace_stride_with_dilation}")

        # Stem: stride-2 conv + stride-2 max pooling -> 1/4 resolution
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2])

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:  # with atrous conv, do not downsample; enlarge the dilation instead
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        # the first block keeps the previous dilation; the remaining blocks use the new one
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion

        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups, base_width=self.base_width,
                                dilation=self.dilation, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        dict_out = {}

        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        dict_out['layer4'] = x  # keep the layer4 feature map for inspection

        x = self.avgpool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x)

        return x, dict_out


def get_resnet(arch, block, layers, pretrained, progress, replace_stride_with_dilation):
    model = ResNet(block, layers, replace_stride_with_dilation)
    if pretrained:  # dilation does not change any weight shape, so the ImageNet weights still fit
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


def resnet50(pretrained=False, progress=True, replace_stride_with_dilation=None):
    return get_resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, replace_stride_with_dilation)


def resnet101(pretrained=False, progress=True, replace_stride_with_dilation=None):
    return get_resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, replace_stride_with_dilation)


def resnet152(pretrained=False, progress=True, replace_stride_with_dilation=None):
    return get_resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, replace_stride_with_dilation)


if __name__ == '__main__':
    net1 = resnet101(pretrained=True, replace_stride_with_dilation=[False, False, True])  # atrous conv in layer4
    net2 = resnet101(pretrained=True, replace_stride_with_dilation=[False, True, True])  # atrous conv in layer3 and layer4

    x = torch.randn(8, 3, 224, 224)
    out1, dict_out1 = net1(x)
    out2, dict_out2 = net2(x)
    print(dict_out1['layer4'].shape)  # torch.Size([8, 2048, 14, 14]): one dilated stage, output is 1/16 of the input
    print(dict_out2['layer4'].shape)  # torch.Size([8, 2048, 28, 28]): two dilated stages, output is 1/8 of the input
```
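The bookkeeping in _make_layer is easy to misread: when a stage is dilated, its first block still uses the previous dilation and only the remaining blocks use the doubled one. The dependency-free sketch below (dilation_schedule is a hypothetical helper written for this walkthrough, assuming the stride of 2 per stage used above) replays that logic and reports each stage's output stride and dilations:

```python
def dilation_schedule(replace_stride_with_dilation=(False, False, False)):
    """Mirror the stride/dilation bookkeeping of ResNet._make_layer above.

    Returns {stage: (output_stride, first_block_dilation, other_blocks_dilation)}.
    """
    output_stride = 4  # stem: conv1 (stride 2) followed by maxpool (stride 2)
    dilation = 1
    schedule = {"layer1": (output_stride, 1, 1)}  # layer1 always uses stride 1
    for name, dilate in zip(("layer2", "layer3", "layer4"), replace_stride_with_dilation):
        previous_dilation = dilation
        if dilate:
            dilation *= 2  # keep the resolution, enlarge the dilation instead
        else:
            output_stride *= 2  # ordinary stride-2 downsampling
        # The first block keeps the previous dilation; the rest use the current one
        schedule[name] = (output_stride, previous_dilation, dilation)
    return schedule


for config in ([False, False, False], [False, False, True], [False, True, True]):
    print(config, dilation_schedule(config))
```

For output stride 16 only layer4 is dilated (rate 2, with its first block still at rate 1); for output stride 8, layer3 runs at rate 2 and layer4 at rate 4.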

3. IntermediateLayerGetter

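We need two kinds of features from the ResNet backbone: shallow features (the output of layer1) and deep features (the output of layer4), whose channel counts we denote low_level_channels and in_channels respectively. Deeper features carry richer semantic information, but object locations become blurrier and the features of small objects may be lost entirely. Since the decoder needs both, we keep these intermediate outputs with an IntermediateLayerGetter, which runs the backbone's child modules in order and returns the outputs of the selected layers: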

```python
import torch
import torch.nn as nn
from collections import OrderedDict
from ResNet import resnet101


class IntermediateLayerGetter(nn.ModuleDict):
    """Runs the model's children in order and returns selected intermediate outputs.

    Assumes the children are applied sequentially in the model's forward
    (true for ResNet up to layer4).
    """
    def __init__(self, model, return_layers):
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")

        remaining = dict(return_layers)  # delete from a copy so the caller's dict is left untouched
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in remaining:
                del remaining[name]
            if not remaining:  # stop after layer4, so avgpool and fc are not included
                break

        super(IntermediateLayerGetter, self).__init__(layers)
        self.return_layers = dict(return_layers)  # e.g. {'layer4': 'out', 'layer1': 'low_level'}

    def forward(self, x):
        out = OrderedDict()  # outputs are returned as a dict
        for name, module in self.named_children():
            x = module(x)

            if name in self.return_layers:  # keep the requested output under its custom name
                out_name = self.return_layers[name]
                out[out_name] = x
        return out


if __name__ == '__main__':
    model = resnet101(pretrained=True, replace_stride_with_dilation=[False, False, True])
    model = IntermediateLayerGetter(model, {'layer4': 'out', 'layer1': 'low_level'})

    x = torch.randn(8, 3, 224, 224)
    output = model(x)
    print(output['low_level'].shape, output['out'].shape)  # torch.Size([8, 256, 56, 56]) torch.Size([8, 2048, 14, 14])
```
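The truncation rule in __init__ can be checked without building any model: walk the child names in order and stop once every requested layer has been seen. The sketch below applies that rule to the child names of the ResNet class above, listed by hand in the order its __init__ registers them (RESNET_CHILDREN and kept_children are hypothetical names for this illustration):

```python
RESNET_CHILDREN = ["conv1", "bn1", "relu", "maxpool",
                   "layer1", "layer2", "layer3", "layer4", "avgpool", "fc"]


def kept_children(child_names, return_layers):
    """Return the child names an IntermediateLayerGetter-style wrapper keeps."""
    if not set(return_layers).issubset(child_names):
        raise ValueError("return_layers are not present in model")
    remaining = set(return_layers)
    kept = []
    for name in child_names:
        kept.append(name)
        remaining.discard(name)
        if not remaining:  # every requested layer seen: drop everything after it
            break
    return kept


print(kept_children(RESNET_CHILDREN, {"layer4": "out", "layer1": "low_level"}))
# ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4']
print(kept_children(RESNET_CHILDREN, {"layer2": "mid"}))
# ['conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2']
```

This also shows why the wrapper only works when the forward pass is a plain sequence of children: the flatten between avgpool and fc in ResNet.forward would be skipped, which is harmless here only because both layers are dropped.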

4. DeepLabV3Plus

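ResNet and ASPP together make up the encoder, which is essentially DeepLabV3. Next we look at the decoder, DeepLabHeadV3Plus (note that in the code below the ASPP module is constructed inside this head class):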

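As before, the channel counts of the shallow and deep ResNet features are denoted low_level_channels and in_channels. The shallow features are first projected by a 1×1 convolution that reduces them to 48 channels; call the result low_level_feature. The deep features pass through ASPP and are upsampled to the same spatial size as low_level_feature; call this output_feature. Finally, the two are concatenated along the channel dimension (48 + 256 = 304 channels), and a classification head maps the channels to the number of classes: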

```python
import torch
from torch import nn
from torch.nn import functional as F
import ResNet
from IntermediateLayerGetter import IntermediateLayerGetter
from ASPP import ASPP


class DeepLabHeadV3Plus(nn.Module):
    def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=(6, 12, 18)):
        super(DeepLabHeadV3Plus, self).__init__()

        # projection of the low-level features: low_level_channels -> 48
        self.project = nn.Sequential(
            nn.Conv2d(low_level_channels, 48, 1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True))

        # ASPP on the high-level features
        self.aspp = ASPP(in_channels, aspp_dilate)  # e.g. aspp_dilate = (6, 12, 18)

        # classification head; 304 = 48 (low-level) + 256 (ASPP) input channels
        self.classifier = nn.Sequential(
            nn.Conv2d(304, 256, 3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )

        self._init_weight()

    def forward(self, feature):
        # feature['low_level']: (n, 256, h/4, w/4), feature['out']: (n, 2048, h/16, w/16)
        low_level_feature = self.project(feature['low_level'])  # (n, 256, h/4, w/4) -> (n, 48, h/4, w/4)
        size = low_level_feature.shape[2:]  # (h/4, w/4)
        output_feature = self.aspp(feature['out'])  # (n, 2048, h/16, w/16) -> (n, 256, h/16, w/16)
        output_feature = F.interpolate(output_feature, size=size, mode='bilinear', align_corners=False)  # upsample to the size of low_level_feature
        return self.classifier(torch.cat([low_level_feature, output_feature], dim=1))  # (n, 48+256, h/4, w/4) -> (n, num_classes, h/4, w/4)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
```
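The shape comments in forward can be traced end to end with plain arithmetic. The sketch below (head_shapes is a hypothetical helper) assumes the input height and width are divisible by the output stride; for other sizes the real ResNet rounds up at each stride-2 step, so the numbers would differ slightly:

```python
def head_shapes(n, h, w, num_classes, output_stride=16,
                low_level_channels=256, in_channels=2048,
                aspp_channels=256, projected_channels=48):
    """Trace (n, c, h, w) shapes through DeepLabHeadV3Plus for an input image of size (h, w)."""
    low_h, low_w = h // 4, w // 4                            # layer1 runs at stride 4
    deep_h, deep_w = h // output_stride, w // output_stride  # layer4 runs at output_stride
    return {
        "low_level": (n, low_level_channels, low_h, low_w),
        "projected": (n, projected_channels, low_h, low_w),               # 1x1 conv: 256 -> 48
        "deep": (n, in_channels, deep_h, deep_w),
        "aspp_out": (n, aspp_channels, deep_h, deep_w),                   # ASPP keeps the spatial size
        "upsampled": (n, aspp_channels, low_h, low_w),                    # bilinear, to the low-level size
        "concat": (n, projected_channels + aspp_channels, low_h, low_w),  # 48 + 256 = 304 channels
        "logits": (n, num_classes, low_h, low_w),                         # before the final upsampling
    }


for name, shape in head_shapes(8, 224, 224, num_classes=4).items():
    print(f"{name:>10}: {shape}")
```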

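A spatial pyramid pooling module encodes multi-scale contextual information by probing the incoming features with filters (convolutions) or pooling operations at multiple rates and multiple effective fields-of-view. An encoder-decoder module, in contrast, captures sharper object boundaries by gradually recovering spatial information. DeepLabV3Plus combines the strengths of both: it extends DeepLabv3 with a simple yet effective decoder module that refines the segmentation results, especially along object boundaries.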

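Depthwise separable convolution can additionally be applied to the ASPP and decoder modules, yielding a faster and stronger encoder-decoder network; that code is covered in the next section.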

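Finally, combining all the modules gives the complete DeepLabV3Plus. Note that the output of the decoder's classification head, which is at 1/4 of the input resolution, must be upsampled back to the original input size: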

```python
class DeepLabV3Plus(nn.Module):
    def __init__(self, backbone, classifier):
        super(DeepLabV3Plus, self).__init__()
        self.backbone = backbone
        self.classifier = classifier

    def forward(self, x):  # (n, c, h, w)
        input_shape = x.shape[-2:]  # (h, w)
        features = self.backbone(x)  # features['low_level']: (n, 256, h/4, w/4), features['out']: (n, 2048, h/16, w/16)
        x = self.classifier(features)  # (n, num_classes, h/4, w/4)
        x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)  # (n, num_classes, h, w)
        return x


def get_deeplabv3plus_model(args):
    if args['output_stride'] == 8:
        replace_stride_with_dilation = [False, True, True]
        aspp_dilate = [12, 24, 36]
    else:  # output_stride == 16
        replace_stride_with_dilation = [False, False, True]
        aspp_dilate = [6, 12, 18]

    return_layers = {'layer4': 'out', 'layer1': 'low_level'}
    # look up resnet50/resnet101/resnet152 by name and load the ImageNet-pretrained weights
    backbone = getattr(ResNet, args['backbone'])(pretrained=True, progress=True,
                                                 replace_stride_with_dilation=replace_stride_with_dilation)
    backbone = IntermediateLayerGetter(backbone, return_layers)

    in_channels = 2048  # layer4 output channels
    low_level_channels = 256  # layer1 output channels
    classifier = DeepLabHeadV3Plus(in_channels, low_level_channels, args['num_classes'], aspp_dilate)

    model = DeepLabV3Plus(backbone, classifier)

    return model


if __name__ == '__main__':
    args = {
        'num_classes': 4,
        'output_stride': 16,
        'backbone': 'resnet101',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    }

    model = get_deeplabv3plus_model(args)
    model.to(args['device'])

    x = torch.randn(8, 3, 224, 224, device=args['device'])
    output = model(x)
    print(output.shape)  # torch.Size([8, 4, 224, 224])

    pred = output.detach().argmax(dim=1).cpu().numpy()  # per-pixel class index
    print(pred.shape)  # (8, 224, 224)
    print(pred[0, 0, :10])  # e.g. [0 0 0 3 3 3 2 1 1 1]; values vary (untrained head, random input)
```
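Why are the ASPP rates doubled when output_stride drops from 16 to 8? At output stride 8 the layer4 map has twice the resolution, so one feature-map pixel covers half as many input pixels. Doubling the rates keeps the spacing between neighbouring kernel taps, measured in input-image pixels (rate × output_stride), identical in both presets. This invariance is my reading of the two branches of get_deeplabv3plus_model; the sketch below (PRESETS and tap_spacing_in_input_pixels are hypothetical names) just checks the arithmetic:

```python
# The two configurations hard-coded in get_deeplabv3plus_model above
PRESETS = {
    16: {"replace_stride_with_dilation": [False, False, True], "aspp_dilate": [6, 12, 18]},
    8: {"replace_stride_with_dilation": [False, True, True], "aspp_dilate": [12, 24, 36]},
}


def tap_spacing_in_input_pixels(output_stride):
    # Distance between neighbouring taps of each 3x3 atrous kernel, in input pixels
    return [rate * output_stride for rate in PRESETS[output_stride]["aspp_dilate"]]


for stride in (16, 8):
    print(stride, tap_spacing_in_input_pixels(stride))
# 16 [96, 192, 288]
# 8 [96, 192, 288]
```

The trade-off is cost: for a 224×224 input the layer4 map is 28×28 at output stride 8 instead of 14×14, i.e. four times as many positions for layer4 and ASPP, and layer3 also runs at the higher resolution.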

5. AtrousSeparableConvolution

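Depthwise separable convolution factorizes a standard convolution into a depth-wise convolution (one k×k filter per input channel, i.e. groups=in_channels) followed by a point-wise 1×1 convolution that mixes the channels, which greatly reduces computation. Giving the depth-wise convolution a dilation rate yields the atrous separable convolution: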

```python
import torch
from torch import nn

# Assumes this code sits in the same file as get_deeplabv3plus_model from Section 4
# (otherwise import it from wherever that code was saved).


class AtrousSeparableConvolution(nn.Module):
    """Depth-wise (optionally dilated) k x k conv followed by a point-wise 1x1 conv."""
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=True):
        super(AtrousSeparableConvolution, self).__init__()
        self.sepconv = nn.Sequential(
            # depth-wise: groups=in_channels gives one filter per input channel
            nn.Conv2d(in_channels, in_channels, kernel_size, stride, padding, dilation, groups=in_channels, bias=bias),
            # point-wise: 1x1 conv mixes channels and sets the output width
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
        )

        self._init_weight()

    def forward(self, x):
        return self.sepconv(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


def convert_to_separable_conv(module):
    """Recursively replace every Conv2d with kernel size > 1 by an AtrousSeparableConvolution.

    The new convolutions are freshly initialized; the original weights are not copied.
    """
    new_module = module
    if isinstance(module, nn.Conv2d) and module.kernel_size[0] > 1:
        new_module = AtrousSeparableConvolution(module.in_channels,
                                                module.out_channels,
                                                module.kernel_size,
                                                module.stride,
                                                module.padding,
                                                module.dilation,
                                                module.bias is not None)  # pass a bool, not the bias tensor
    for name, child in module.named_children():  # recursively convert every submodule
        new_module.add_module(name, convert_to_separable_conv(child))

    return new_module


if __name__ == '__main__':
    args = {
        'num_classes': 4,
        'output_stride': 16,
        'backbone': 'resnet101',
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'separable_conv': True,
    }

    model = get_deeplabv3plus_model(args)
    if args['separable_conv']:
        model.classifier = convert_to_separable_conv(model.classifier)  # only the head (including ASPP) is converted
    # print(model)
    model.to(args['device'])

    x = torch.randn(8, 3, 224, 224, device=args['device'])
    output = model(x)
    print(output.shape)  # torch.Size([8, 4, 224, 224])

    pred = output.detach().argmax(dim=1).cpu().numpy()  # per-pixel class index
    print(pred.shape)  # (8, 224, 224)
    print(pred[0, 0, :10])  # e.g. [0 0 0 3 3 3 2 1 1 1]; values vary (untrained head, random input)
```
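To put "greatly reduces computation" in numbers: a standard k×k convolution has k²·C_in·C_out weights, while the separable version has k²·C_in (depth-wise) plus C_in·C_out (point-wise), a ratio of k²·C_out / (k² + C_out), about 8.7 for k = 3 and C_out = 256. The helpers below are plain arithmetic written for this walkthrough, not a library API; the channel pairs are the decoder's 304-to-256 3×3 conv and one 2048-to-256 ASPP atrous branch, both bias-free. At stride 1 each weight is used once per output position, so the same ratio also approximates the saving in multiply-accumulates.

```python
def standard_conv_params(c_in, c_out, k):
    # One k x k filter over all input channels for every output channel
    return c_in * c_out * k * k


def separable_conv_params(c_in, c_out, k):
    depthwise = c_in * k * k   # one k x k filter per input channel (groups=c_in)
    pointwise = c_in * c_out   # 1x1 conv that mixes channels
    return depthwise + pointwise


# Channel pairs from the decoder head and ASPP above (bias ignored: those convs use bias=False)
for c_in, c_out in [(304, 256), (2048, 256)]:
    std = standard_conv_params(c_in, c_out, 3)
    sep = separable_conv_params(c_in, c_out, 3)
    print(f"{c_in}->{c_out}, 3x3: standard={std:,}  separable={sep:,}  ratio={std / sep:.1f}x")
```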