Deep Residual Learning for Image Recognition 全精读:ResNet 残差网络开山之作
论文信息标题Deep Residual Learning for Image Recognition会议CVPR 2016单位Microsoft Research代码https://github.com/KaimingHe/deep-residual-networks论文https://arxiv.org/pdf/1512.03385.pdf前言深度学习一直有一个魔咒网络越深性能越容易退化。层数堆上去训练误差反而变大精度不升反降。大家都以为是过拟合直到 ResNet 出现才揭开真相不是过拟合而是深层网络太难优化。图 1. 在 CIFAR-10 数据集上20 层和 56 层“普通”网络的训练误差左和测试误差右。更深的网络训练误差更高因此测试误差也更高。在 ImageNet 上的类似现象见图 4。何恺明团队用一个极简的残差学习shortcut 跳跃连接让 152 层超深网络轻松训练ImageNet 错误率狂降到 3.57%拿下 ILSVRC 2015 冠军。从此CNN 正式进入“百层时代”成为检测、分割、超分所有视觉任务的基础骨架。一、核心痛点深层网络的退化问题1.1 什么是退化Degradation随着网络深度线性增加训练精度先上升→饱和→快速下降不是过拟合训练集误差也升高不是梯度消失BN 已解决1.2 问题本质深层网络希望学习恒等映射H ( x ) x H(x)xH(x)x但直接拟合极难。ResNet 换了思路不学H ( x ) H(x)H(x)学残差F ( x ) H ( x ) − x F(x)H(x)-xF(x)H(x)−x。二、核心创新残差学习与 shortcut 连接图 2. 残差学习一个构建模块。2.1 残差块公式全文最重要y F ( x , { W i } ) x y F(x, \{W_i\}) xyF(x,{Wi})xy yy残差块输出x xx块输入直接跳过卷积F ( x , { W i } ) F(x, \{W_i\})F(x,{Wi})卷积层学习的残差函数 元素级相加通道必须一致通俗解释让网络只学“输入和输出的差值”最优情况只需让F ( x ) 0 F(x)0F(x)0输出就是x xx比直接学恒等映射简单一万倍。2.2 维度不匹配时升维/下采样y F ( x , { W i } ) W s x y F(x, \{W_i\}) W_s xyF(x,{Wi})WsxW s W_sWs1×1 卷积对齐通道与尺寸仅在尺寸/通道变化时使用2.3 为什么能解决退化梯度可以直接通过 shortcut 回传不会逐层衰减深层网络轻松收敛训练误差不上升不增加额外计算量与参数三、网络结构Plain → ResNet3.1 设计规则特征图尺寸减半 → 通道数翻倍下采样通过 stride2 卷积实现全部 3×3 卷积遵循 VGG 设计3.2 两种残差块基础块ResNet-18/343×3 → BN → ReLU → 3×3 → BN → x → ReLU瓶颈块ResNet-50/101/1521×1降维→ 3×3 → 1×1升维速度更快参数更少适合极深层四、完整公式体系逐字母解释4.1 残差映射F ( x ) H ( x ) − x F(x) H(x) - xF(x)H(x)−xH ( x ) H(x)H(x)期望的总映射x xx输入F ( x ) F(x)F(x)网络需要学习的残差4.2 残差块前向y F ( x ) x H ( x ) y F(x) x H(x)yF(x)xH(x)4.3 梯度回传关键∂ L ∂ x ∂ L ∂ y ⋅ ∂ y ∂ x ∂ L ∂ y ⋅ ( 1 ∂ F ∂ x ) \frac{\partial L}{\partial x} \frac{\partial L}{\partial y} \cdot \frac{\partial y}{\partial x} \frac{\partial L}{\partial y} \cdot (1 \frac{\partial F}{\partial x})∂x∂L∂y∂L⋅∂x∂y∂y∂L⋅(1∂x∂F)∂ L ∂ x \frac{\partial L}{\partial x}∂x∂L输入梯度1 11来自 shortcut保证梯度不会消失∂ L ∂ y \frac{\partial L}{\partial y}∂y∂L输出梯度通俗解释梯度多了一条“高速公路”直接传回浅层再深也不会丢梯度。五、网络架构一览图 3. ImageNet 的示例网络架构。左作为参考的 VGG-19 模型 [41]196 亿次浮点运算。中具有 34 个参数层的普通网络36 亿次浮点运算。右具有 34 个参数层的残差网络36 亿次浮点运算。虚线捷径用于增加维度。表 1 展示了更多细节和其他变体。分析右侧每两层加一条 shortcut结构几乎不变性能质变。表格1 主流 ResNet 结构配置模型层数结构参数ResNet-1818基础块11.7MResNet-3434基础块21.8MResNet-5050瓶颈块25.6MResNet-101101瓶颈块44.5MResNet-152152瓶颈块60.2M表格1 出处ResNet 原文分析152 层比 VGG-19 更深复杂度却更低。六、实验结果全文验证6.1 ImageNet 分类效果表格2 ImageNet 单模型 Top-5 错误率模型错误率VGG-199.33%ResNet-347.76%ResNet-506.71%ResNet-1016.05%ResNet-1525.71%表格2 出处ResNet 原文分析越深越准彻底解决退化。6.2 退化问题验证图 4. 在 ImageNet 上的训练情况。细曲线表示训练误差粗曲线表示中心裁剪部分的验证误差。左图18 层和 34 层的普通网络。右图18 层和 34 层的 ResNet 网络。在该图中残差网络与普通网络相比没有额外的参数。分析Plain-34 训练误差高于 Plain-18退化ResNet-34 训练误差低于 ResNet-18无退化6.3 CIFAR-10 超深实验可训练1202 层超深网络训练误差 0.1%证明残差结构可无限加深七、迁移学习检测/分割屠榜表格3 COCO 检测 mAP 提升主干mAP[.5,.95]VGG-1621.2%ResNet-10127.2%表格3 出处ResNet 原文分析仅更换主干mAP 相对提升28%成为 Faster R-CNN、Mask R-CNN 标配骨架。八、核心代码PyTorch 完整版importtorchimporttorch.nnasnn# 基础残差块ResNet-18/34classBasicBlock(nn.Module):expansion1def__init__(self,inplanes,planes,stride1):super().__init__()self.conv1nn.Conv2d(inplanes,planes,3,stride,1,biasFalse)self.bn1nn.BatchNorm2d(planes)self.relunn.ReLU(True)self.conv2nn.Conv2d(planes,planes,3,1,1,biasFalse)self.bn2nn.BatchNorm2d(planes)# 下采样/维度对齐self.downsampleNoneifstride!1orinplanes!planes:self.downsamplenn.Sequential(nn.Conv2d(inplanes,planes,1,stride,biasFalse),nn.BatchNorm2d(planes))defforward(self,x):identityx outself.relu(self.bn1(self.conv1(x)))outself.bn2(self.conv2(out))ifself.downsampleisnotNone:identityself.downsample(x)outidentity outself.relu(out)returnout# 瓶颈残差块ResNet-50classBottleneck(nn.Module):expansion4def__init__(self,inplanes,planes,stride1):super().__init__()self.conv1nn.Conv2d(inplanes,planes,1,biasFalse)self.bn1nn.BatchNorm2d(planes)self.conv2nn.Conv2d(planes,planes,3,stride,1,biasFalse)self.bn2nn.BatchNorm2d(planes)self.conv3nn.Conv2d(planes,planes*4,1,biasFalse)self.bn3nn.BatchNorm2d(planes*4)self.relunn.ReLU(True)self.downsampleNoneifstride!1orinplanes!planes*4:self.downsamplenn.Sequential(nn.Conv2d(inplanes,planes*4,1,stride,biasFalse),nn.BatchNorm2d(planes*4))defforward(self,x):identityx outself.relu(self.bn1(self.conv1(x)))outself.relu(self.bn2(self.conv2(out)))outself.bn3(self.conv3(out))ifself.downsampleisnotNone:identityself.downsample(x)outidentity outself.relu(out)returnout# ResNet 主体classResNet(nn.Module):def__init__(self,block,layers,num_classes1000):super().__init__()self.inplanes64self.conv1nn.Conv2d(3,64,7,2,3,biasFalse)self.bn1nn.BatchNorm2d(64)self.relunn.ReLU(True)self.maxpoolnn.MaxPool2d(3,2,1)self.layer1self._make_layer(block,64,layers[0])self.layer2self._make_layer(block,128,layers[1],stride2)self.layer3self._make_layer(block,256,layers[2],stride2)self.layer4self._make_layer(block,512,layers[3],stride2)self.avgpoolnn.AdaptiveAvgPool2d((1,1))self.fcnn.Linear(512*block.expansion,num_classes)def_make_layer(self,block,planes,blocks,stride1):layers[]layers.append(block(self.inplanes,planes,stride))self.inplanesplanes*block.expansionfor_inrange(1,blocks):layers.append(block(self.inplanes,planes))returnnn.Sequential(*layers)defforward(self,x):xself.maxpool(self.relu(self.bn1(self.conv1(x))))xself.layer1(x)xself.layer2(x)xself.layer3(x)xself.layer4(x)xself.avgpool(x)xtorch.flatten(x,1)xself.fc(x)returnx# 构建模型defresnet18():returnResNet(BasicBlock,[2,2,2,2])defresnet34():returnResNet(BasicBlock,[3,4,6,3])defresnet50():returnResNet(Bottleneck,[3,4,6,3])defresnet101():returnResNet(Bottleneck,[3,4,23,3])defresnet152():returnResNet(Bottleneck,[3,8,36,3])九、全文总结核心突破用残差学习shortcut解决深层网络退化数学极简y F ( x ) x yF(x)xyF(x)x不增参、不增计算精度革命152 层 CNNImageNet 错误率 3.57%通用骨架统治图像分类、检测、分割、超分、关键点所有视觉任务历史地位CNN 从浅到深的分水岭现代视觉模型基石