208 lines
8.0 KiB
Markdown
208 lines
8.0 KiB
Markdown
# ResNet
|
||
|
||
残差神经网络(ResNet)是由微软研究院的何恺明大神团队提出的一个经典网络模型,一经现世就成为了沿用至今的超级 Backbone。
|
||
|
||
[知乎](https://zhuanlan.zhihu.com/p/101332297)
|
||
|
||
[论文](https://arxiv.org/pdf/1512.03385.pdf)
|
||
|
||
# WHY residual?
|
||
|
||
在 ResNet 提出之前,所有的神经网络都是通过卷积层和池化层的叠加组成的。
|
||
人们认为卷积层和池化层的层数越多,获取到的图片特征信息越全,学习效果也就越好。但是在实际的试验中发现,随着卷积层和池化层的叠加,不但没有出现学习效果越来越好的情况,反而出现两种问题:
|
||
|
||
- 梯度消失和梯度爆炸
|
||
|
||
梯度消失:若每一层的梯度误差小于 1,反向传播时,网络越深,梯度越趋近于 0
|
||
|
||
梯度爆炸:若每一层的梯度误差大于 1,反向传播时,网络越深,梯度越趋近于无穷大
|
||
|
||
- 退化现象
|
||
|
||
如图所示,随着层数越来越深,预测的效果反而越来越差(error 越大)
|
||
|
||

|
||
|
||
# 网络模型
|
||
|
||

|
||
|
||
我们可以看到,ResNet 的网络依旧非常深,这是因为研究团队不仅发现了退化现象,还采用出一个可以将网络继续加深的 trick:shortcut,亦即我们所说的 residual。
|
||
|
||
- 为了解决梯度消失或梯度爆炸问题,ResNet 论文提出通过数据的预处理以及在网络中使用 BN(Batch Normalization)层来解决。
|
||
- 为了解决深层网络中的退化问题,可以人为地让神经网络某些层跳过下一层神经元的连接,隔层相连,弱化每层之间的强联系。这种神经网络被称为 残差网络 (ResNets)。ResNet 论文提出了 residual 结构(残差结构)来减轻退化问题。
|
||
|
||
## residual 结构
|
||
|
||

|
||
|
||
# 网络代码
|
||
|
||
```python
|
||
import torch.nn as nn
|
||
import torch
|
||
|
||
|
||
# ResNet18/34的残差结构,用的是2个3x3的卷积
|
||
class BasicBlock(nn.Module):
|
||
expansion = 1 # 残差结构中,主分支的卷积核个数是否发生变化,不变则为1
|
||
|
||
def __init__(self, in_channel, out_channel, stride=1, downsample=None): # downsample对应虚线残差结构
|
||
super(BasicBlock, self).__init__()
|
||
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
|
||
kernel_size=3, stride=stride, padding=1, bias=False)
|
||
self.bn1 = nn.BatchNorm2d(out_channel)
|
||
self.relu = nn.ReLU()
|
||
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
|
||
kernel_size=3, stride=1, padding=1, bias=False)
|
||
self.bn2 = nn.BatchNorm2d(out_channel)
|
||
self.downsample = downsample
|
||
|
||
def forward(self, x):
|
||
identity = x
|
||
if self.downsample is not None: # 虚线残差结构,需要下采样
|
||
identity = self.downsample(x) # 捷径分支 short cut
|
||
|
||
out = self.conv1(x)
|
||
out = self.bn1(out)
|
||
out = self.relu(out)
|
||
|
||
out = self.conv2(out)
|
||
out = self.bn2(out)
|
||
|
||
out += identity
|
||
out = self.relu(out)
|
||
|
||
return out
|
||
|
||
# ResNet50/101/152的残差结构,用的是1x1+3x3+1x1的卷积
|
||
class Bottleneck(nn.Module):
|
||
expansion = 4 # 残差结构中第三层卷积核个数是第一/二层卷积核个数的4倍
|
||
|
||
def __init__(self, in_channel, out_channel, stride=1, downsample=None):
|
||
super(Bottleneck, self).__init__()
|
||
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
|
||
kernel_size=1, stride=1, bias=False) # squeeze channels
|
||
self.bn1 = nn.BatchNorm2d(out_channel)
|
||
# -----------------------------------------
|
||
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
|
||
kernel_size=3, stride=stride, bias=False, padding=1)
|
||
self.bn2 = nn.BatchNorm2d(out_channel)
|
||
# -----------------------------------------
|
||
self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel * self.expansion,
|
||
kernel_size=1, stride=1, bias=False) # unsqueeze channels
|
||
self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
|
||
self.relu = nn.ReLU(inplace=True)
|
||
self.downsample = downsample
|
||
|
||
def forward(self, x):
|
||
identity = x
|
||
if self.downsample is not None:
|
||
identity = self.downsample(x) # 捷径分支 short cut
|
||
|
||
out = self.conv1(x)
|
||
out = self.bn1(out)
|
||
out = self.relu(out)
|
||
|
||
out = self.conv2(out)
|
||
out = self.bn2(out)
|
||
out = self.relu(out)
|
||
|
||
out = self.conv3(out)
|
||
out = self.bn3(out)
|
||
|
||
out += identity
|
||
out = self.relu(out)
|
||
|
||
return out
|
||
|
||
|
||
class ResNet(nn.Module):
|
||
# block = BasicBlock or Bottleneck
|
||
# block_num为残差结构中conv2_x~conv5_x中残差块个数,是一个列表
|
||
def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
|
||
super(ResNet, self).__init__()
|
||
self.include_top = include_top
|
||
self.in_channel = 64
|
||
|
||
self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
|
||
padding=3, bias=False)
|
||
self.bn1 = nn.BatchNorm2d(self.in_channel)
|
||
self.relu = nn.ReLU(inplace=True)
|
||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||
self.layer1 = self._make_layer(block, 64, blocks_num[0]) # conv2_x
|
||
self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2) # conv3_x
|
||
self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2) # conv4_x
|
||
self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2) # conv5_x
|
||
if self.include_top:
|
||
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)
|
||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||
|
||
for m in self.modules():
|
||
if isinstance(m, nn.Conv2d):
|
||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||
|
||
# channel为残差结构中第一层卷积核个数
|
||
def _make_layer(self, block, channel, block_num, stride=1):
|
||
downsample = None
|
||
|
||
# ResNet50/101/152的残差结构,block.expansion=4
|
||
if stride != 1 or self.in_channel != channel * block.expansion:
|
||
downsample = nn.Sequential(
|
||
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
|
||
nn.BatchNorm2d(channel * block.expansion))
|
||
|
||
layers = []
|
||
layers.append(block(self.in_channel, channel, downsample=downsample, stride=stride))
|
||
self.in_channel = channel * block.expansion
|
||
|
||
for _ in range(1, block_num):
|
||
layers.append(block(self.in_channel, channel))
|
||
|
||
return nn.Sequential(*layers)
|
||
|
||
def forward(self, x):
|
||
x = self.conv1(x)
|
||
x = self.bn1(x)
|
||
x = self.relu(x)
|
||
x = self.maxpool(x)
|
||
|
||
x = self.layer1(x)
|
||
x = self.layer2(x)
|
||
x = self.layer3(x)
|
||
x = self.layer4(x)
|
||
|
||
if self.include_top:
|
||
x = self.avgpool(x)
|
||
x = torch.flatten(x, 1)
|
||
x = self.fc(x)
|
||
|
||
return x
|
||
|
||
|
||
def resnet34(num_classes=1000, include_top=True):
|
||
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
|
||
|
||
|
||
def resnet101(num_classes=1000, include_top=True):
|
||
return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
|
||
|
||
|
||
'''
|
||
我们希望你能够去将论文下载下来以后跟一些讲解视频尝试将论文与代码结合起来理解
|
||
看论文的源码是我们必须要做的一个中重要的工作
|
||
'''
|
||
```
|
||
|
||
# 视频
|
||
|
||
# 思考
|
||
|
||
## 思考 1
|
||
|
||
请你自行了解网络结构中的 BN(Batch Normalization)层,这是很重要的一个 normalization 操作,如果感兴趣还可以继续了解 LN (Layer Normalization)
|
||
|
||
## 思考 2
|
||
|
||
你觉得论文中提出用 residual 这一解决方法来解决网络的退化现象的依据是什么,如果可以,请你进一步尝试用数学角度思考这一问题
|