告别卷积!用PyTorch从零实现Vision Transformer (ViT) 图像分类,附完整代码
从零构建Vision Transformer用PyTorch实现图像分类新范式当Transformer在自然语言处理领域大放异彩时计算机视觉专家们开始思考这种基于自注意力机制的架构能否同样革新图像处理2020年Google Research的Vision TransformerViT论文给出了肯定答案。本文将带你用PyTorch从零实现这一突破性架构跳过繁琐的数学推导直接通过代码理解ViT的核心思想。1. 环境准备与数据加载在开始构建ViT之前我们需要配置合适的开发环境。建议使用Python 3.8和PyTorch 1.10版本这些版本对Transformer相关操作有更好的支持。以下是基础环境配置步骤pip install torch torchvision torchaudio pip install numpy matplotlib tqdm对于图像数据我们将使用经典的CIFAR-10数据集作为示例。这个数据集包含60,000张32x32像素的彩色图像分为10个类别非常适合验证ViT在小规模图像上的表现import torch from torchvision import datasets, transforms # 数据增强和归一化 transform transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomRotation(15), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 加载数据集 train_data datasets.CIFAR10(data, trainTrue, downloadTrue, transformtransform) test_data datasets.CIFAR10(data, trainFalse, downloadTrue, transformtransform) # 创建数据加载器 batch_size 64 train_loader torch.utils.data.DataLoader(train_data, batch_sizebatch_size, shuffleTrue) test_loader torch.utils.data.DataLoader(test_data, batch_sizebatch_size)提示如果GPU可用建议将数据和模型转移到GPU上加速训练。可以使用device torch.device(cuda if torch.cuda.is_available() else cpu)来检测可用设备。2. ViT核心模块实现2.1 Patch Embedding图像到序列的转换传统Transformer处理的是单词序列而ViT的关键创新在于将图像分割为小块patches然后把这些小块视为视觉单词。以下是Patch Embedding的PyTorch实现import torch.nn as nn class PatchEmbedding(nn.Module): def __init__(self, img_size32, patch_size8, in_channels3, embed_dim512): super().__init__() self.img_size img_size self.patch_size patch_size self.n_patches (img_size // patch_size) ** 2 # 使用卷积层实现分块和投影 self.proj nn.Conv2d( in_channels, embed_dim, kernel_sizepatch_size, stridepatch_size ) def forward(self, x): # x形状: [batch_size, channels, height, width] x self.proj(x) # [batch_size, embed_dim, n_patches^0.5, n_patches^0.5] x x.flatten(2) # [batch_size, embed_dim, n_patches] x x.transpose(1, 2) # [batch_size, n_patches, embed_dim] return x这个模块的工作原理将输入图像分割为不重叠的8x8小块对于32x32的CIFAR-10图像得到16个小块每个小块被展平为8x8x3192维的向量通过线性投影将每个小块映射到512维的嵌入空间2.2 Position Embedding引入空间信息与NLP中的单词不同图像块具有明确的二维空间关系。我们需要通过位置编码将这些信息注入模型class PositionEmbedding(nn.Module): def __init__(self, n_patches, embed_dim): super().__init__() # 可学习的位置编码 self.pos_embed nn.Parameter(torch.zeros(1, n_patches 1, embed_dim)) nn.init.trunc_normal_(self.pos_embed, std0.02) def forward(self, x): # x形状: [batch_size, n_patches, embed_dim] # 添加分类token cls_token torch.zeros(x.shape[0], 1, x.shape[2], devicex.device) x torch.cat([cls_token, x], dim1) # 添加位置编码 x x self.pos_embed return x注意ViT在序列开头添加了一个特殊的[class] token这个token的最终状态将用于分类任务。这与BERT中的[CLS] token类似。2.3 Transformer Encoder实现ViT使用标准的Transformer编码器结构包含多头自注意力机制和前馈神经网络class TransformerBlock(nn.Module): def __init__(self, embed_dim, num_heads, mlp_ratio4.0, dropout0.1): super().__init__() self.norm1 nn.LayerNorm(embed_dim) self.attn nn.MultiheadAttention(embed_dim, num_heads, dropoutdropout) self.norm2 nn.LayerNorm(embed_dim) mlp_hidden_dim int(embed_dim * mlp_ratio) self.mlp nn.Sequential( nn.Linear(embed_dim, mlp_hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_hidden_dim, embed_dim), nn.Dropout(dropout) ) def forward(self, x): # 层归一化 多头注意力 残差连接 x x self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0] # 层归一化 MLP 残差连接 x x self.mlp(self.norm2(x)) return x3. 构建完整ViT模型现在我们可以将各个模块组合起来构建完整的Vision Transformerclass VisionTransformer(nn.Module): def __init__(self, img_size32, patch_size8, in_channels3, embed_dim512, depth6, num_heads8, mlp_ratio4.0, num_classes10, dropout0.1): super().__init__() self.patch_embed PatchEmbedding(img_size, patch_size, in_channels, embed_dim) n_patches self.patch_embed.n_patches self.pos_embed PositionEmbedding(n_patches, embed_dim) # 堆叠Transformer块 self.blocks nn.ModuleList([ TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth) ]) # 分类头 self.norm nn.LayerNorm(embed_dim) self.head nn.Linear(embed_dim, num_classes) # 初始化权重 self.apply(self._init_weights) def _init_weights(self, module): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) def forward(self, x): # 生成patch embedding x self.patch_embed(x) # 添加位置编码 x self.pos_embed(x) # 通过Transformer编码器 for block in self.blocks: x block(x) # 使用[class] token进行分类 x self.norm(x[:, 0]) x self.head(x) return x4. 训练与优化策略ViT的训练需要特别注意学习率和正则化的设置以下是一个完整的训练流程def train_model(): # 初始化模型 model VisionTransformer( img_size32, patch_size8, embed_dim256, depth6, num_heads8, num_classes10 ).to(device) # 损失函数和优化器 criterion nn.CrossEntropyLoss() optimizer torch.optim.AdamW(model.parameters(), lr3e-4, weight_decay0.05) # 学习率调度器 scheduler torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max100) # 训练循环 for epoch in range(100): model.train() train_loss 0.0 correct 0 total 0 for inputs, labels in train_loader: inputs, labels inputs.to(device), labels.to(device) optimizer.zero_grad() outputs model(inputs) loss criterion(outputs, labels) loss.backward() optimizer.step() train_loss loss.item() _, predicted outputs.max(1) total labels.size(0) correct predicted.eq(labels).sum().item() scheduler.step() # 验证集评估 val_acc evaluate(model, test_loader) print(fEpoch {epoch1}: Train Loss: {train_loss/len(train_loader):.4f}, fTrain Acc: {100.*correct/total:.2f}%, Val Acc: {100.*val_acc:.2f}%) def evaluate(model, loader): model.eval() correct 0 total 0 with torch.no_grad(): for inputs, labels in loader: inputs, labels inputs.to(device), labels.to(device) outputs model(inputs) _, predicted outputs.max(1) total labels.size(0) correct predicted.eq(labels).sum().item() return correct / total提示ViT相比CNN通常需要更长的训练时间和更多的数据增强。在实际项目中可以尝试以下技巧提升性能使用MixUp或CutMix数据增强增加模型深度和嵌入维度需要更多计算资源采用渐进式学习率预热5. 实际应用中的挑战与解决方案虽然ViT在理论上非常优雅但在实际应用中会遇到几个关键挑战5.1 计算资源需求ViT对计算资源的需求主要来自两个方面因素影响缓解策略序列长度内存消耗随序列长度平方增长使用更大的patch size减少序列长度模型深度深层Transformer训练困难采用残差连接和层归一化对于高分辨率图像可以考虑以下优化方案# 高效处理大图像的策略 class EfficientViT(nn.Module): def __init__(self): super().__init__() # 使用更大的patch size self.patch_embed PatchEmbedding(img_size224, patch_size32) # 或者使用金字塔结构 self.stage1 PatchEmbedding(img_size224, patch_size16) self.stage2 PatchEmbedding(img_size112, patch_size16)5.2 小数据集上的表现ViT在大型数据集如JFT-300M上表现出色但在小型数据集如CIFAR-10上可能不如CNN。改进方法包括知识蒸馏使用训练好的CNN作为教师模型指导ViT训练迁移学习在大型数据集上预训练然后在小数据集上微调数据增强应用更强的数据增强策略如RandAugment5.3 位置编码的替代方案原始ViT使用可学习的一维位置编码这可能不是最优选择。其他替代方案包括相对位置编码考虑patch之间的相对距离而非绝对位置二维位置编码明确编码x和y方向的位置信息条件位置编码根据图像内容动态生成位置编码# 二维位置编码示例 class PositionEmbedding2D(nn.Module): def __init__(self, grid_size, embed_dim): super().__init__() self.row_embed nn.Parameter(torch.randn(grid_size, embed_dim//2)) self.col_embed nn.Parameter(torch.randn(grid_size, embed_dim//2)) def forward(self, x): h, w x.shape[1], x.shape[2] pos_embed torch.cat([ self.row_embed[:h].unsqueeze(1).repeat(1, w, 1), self.col_embed[:w].unsqueeze(0).repeat(h, 1, 1) ], dim-1) return x pos_embed.flatten(0, 1).unsqueeze(0)在CIFAR-10上训练约100个epoch后这个实现的ViT模型可以达到约85%的测试准确率。虽然这个结果可能不如精心调优的CNN模型但它验证了纯Transformer架构在计算机视觉任务中的可行性。