From SRResNet to SRGAN: A Hands-On Reproduction of the Classic Paper, Using PyTorch for "Indistinguishable-from-Real" Image Super-Resolution
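# From SRResNet to SRGAN: Image Super-Resolution Reconstruction in Practice with PyTorch

## 1. Environment Setup and Data Loading

Before building SRGAN, set up a suitable development environment. Python 3.8 and PyTorch 1.10, together with CUDA 11.3 for GPU acceleration, are recommended. The basic setup steps are:

```bash
conda create -n srgan python=3.8
conda activate srgan
pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
pip install opencv-python pillow matplotlib tqdm
```

For the dataset, DIV2K or a subset of ImageNet is a good choice. Here is a general-purpose data loader:

```python
import os

import torch
from torch.utils.data import Dataset
from PIL import Image


class SRDataset(Dataset):
    """Paired dataset: the i-th file in lr_dir is the low-res version of the i-th file in hr_dir."""

    def __init__(self, hr_dir, lr_dir, transform=None):
        # Sort both listings so LR and HR files pair up by name order.
        self.hr_images = [os.path.join(hr_dir, x) for x in sorted(os.listdir(hr_dir))]
        self.lr_images = [os.path.join(lr_dir, x) for x in sorted(os.listdir(lr_dir))]
        self.transform = transform

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        lr_img = Image.open(self.lr_images[idx]).convert("RGB")
        hr_img = Image.open(self.hr_images[idx]).convert("RGB")
        if self.transform:
            # The transform runs on each image separately, so it should be
            # deterministic (e.g. ToTensor) to keep the pair aligned.
            lr_img = self.transform(lr_img)
            hr_img = self.transform(hr_img)
        return lr_img, hr_img
```

Tip: In practice, use data augmentation such as random cropping, rotation, and flipping to increase the diversity of the training samples.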
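Random cropping for super-resolution has one catch: the LR and HR crops must cover the same region, so the random offset has to be drawn once and scaled. The following is a minimal sketch of that idea in plain Python; the helper name `paired_crop_boxes` and its parameters are hypothetical and not part of the original loader, and it assumes the HR image is exactly `scale` times the size of the LR image.

```python
import random


def paired_crop_boxes(lr_size, hr_crop, scale, rng=random):
    """Pick aligned crop boxes for one LR/HR pair (hypothetical helper).

    lr_size: (width, height) of the LR image
    hr_crop: side length of the square HR crop, a multiple of `scale`
    Returns (lr_box, hr_box) as (left, upper, right, lower) tuples.
    """
    if hr_crop % scale != 0:
        raise ValueError("hr_crop must be a multiple of scale")
    lr_crop = hr_crop // scale
    lr_w, lr_h = lr_size
    if lr_crop > lr_w or lr_crop > lr_h:
        raise ValueError("crop is larger than the LR image")
    # Sample the corner on the LR grid so the HR box lands on whole pixels.
    left = rng.randint(0, lr_w - lr_crop)
    upper = rng.randint(0, lr_h - lr_crop)
    lr_box = (left, upper, left + lr_crop, upper + lr_crop)
    # The HR box is the same region, scaled up by the upsampling factor.
    hr_box = tuple(v * scale for v in lr_box)
    return lr_box, hr_box


# Example: a 120x80 LR image, 96x96 HR crops, 4x upscaling.
lr_box, hr_box = paired_crop_boxes(lr_size=(120, 80), hr_crop=96, scale=4,
                                   rng=random.Random(0))
print(lr_box, hr_box)
```

The boxes use the `(left, upper, right, lower)` layout that `PIL.Image.Image.crop` expects, so inside `__getitem__` they could be applied as `lr_img.crop(lr_box)` and `hr_img.crop(hr_box)` before converting to tensors. Flips and rotations need the same treatment: draw the random decision once and apply it to both images.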
## 2. SRResNet Generator Implementation

SRResNet, the generator at the core of SRGAN, is built from a stack of residual blocks. The key components in PyTorch are:

```python
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.prelu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + residual  # skip connection
        return out


class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        # Expand channels by up_scale**2 so PixelShuffle can trade them for resolution.
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x
```

The complete SRResNet generator structure is shown in the table below:

| Layer | Configuration | Output size |
|---|---|---|
| Input convolution | Conv2d(3, 64, k=9, p=4) | 64×H×W |
| PReLU activation | - | 64×H×W |
| Residual block ×16 | Conv2d(64, 64, k=3, p=1) → BN → PReLU → Conv2d(64, 64, k=3, p=1) → BN, plus skip connection | 64×H×W |
| Post-residual convolution | Conv2d(64, 64, k=3, p=1) → BN | 64×H×W |
| Upsample block ×2 | Conv2d(64, 256, k=3, p=1) → PixelShuffle(2) → PReLU | 64×2H×2W after the first block, 64×4H×4W after the second |
| Output convolution | Conv2d(64, 3, k=9, p=4) | 3×4H×4W |

## 3. Discriminator Design and Adversarial Training

The discriminator uses a VGG-like architecture to distinguish generated images from real high-resolution images:

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            # Similar structure repeated several times...
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.classifier = nn.Sequential(
            # Expects a 512×6×6 feature map, which fixes the input image size.
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
```

The key to adversarial training is keeping the generator and the discriminator learning at a balanced pace. The core of the training loop is:

```python
def train(generator, discriminator, train_loader, epochs):
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
    adversarial_loss = nn.BCELoss()
    content_loss = nn.MSELoss()

    for epoch in range(epochs):
        for i, (lr_imgs, hr_imgs) in enumerate(train_loader):
            # --- Train the discriminator ---
            # Create the labels on the same device as the batch.
            real_labels = torch.ones(lr_imgs.size(0), 1, device=hr_imgs.device)
            fake_labels = torch.zeros(lr_imgs.size(0), 1, device=hr_imgs.device)

            # Loss on real images
            real_outputs = discriminator(hr_imgs)
            d_loss_real = adversarial_loss(real_outputs, real_labels)

            # Loss on generated images; detach so no gradient reaches the generator
            fake_imgs = generator(lr_imgs)
            fake_outputs = discriminator(fake_imgs.detach())
            d_loss_fake = adversarial_loss(fake_outputs, fake_labels)

            # Total discriminator loss
            d_loss = d_loss_real + d_loss_fake
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # --- Train the generator ---
            g_loss_content = content_loss(fake_imgs, hr_imgs)
            g_loss_adv = adversarial_loss(discriminator(fake_imgs), real_labels)
            g_loss = g_loss_content + 1e-3 * g_loss_adv
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()
```
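The `nn.Linear(512*6*6, 1024)` layer in the discriminator ties the network to one input size, and it is worth checking which one. The sketch below works through the arithmetic with the standard Conv2d output-size formula. The article elides the middle layers, so the stride pattern used here (eight 3×3 convolutions with strides alternating 1 and 2, as in the SRGAN paper) is an assumption, and the helper names are hypothetical.

```python
def conv_out_size(size, kernel_size=3, stride=1, padding=1):
    """Spatial output size of a Conv2d layer with dilation 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


def feature_map_size(input_size, strides):
    """Apply a chain of 3x3, padding-1 convolutions with the given strides."""
    size = input_size
    for stride in strides:
        size = conv_out_size(size, stride=stride)
    return size


# Assumed layout: strides alternate 1, 2 across eight conv layers.
# The elided middle of the discriminator may differ from this.
ASSUMED_STRIDES = [1, 2, 1, 2, 1, 2, 1, 2]

for hr_size in (96, 128):
    side = feature_map_size(hr_size, ASSUMED_STRIDES)
    print(f"{hr_size}x{hr_size} input -> 512x{side}x{side} -> {512 * side * side} features")
```

Under this assumption, four stride-2 layers shrink each side by a factor of 16, so `512*6*6` corresponds to 96×96 HR crops. Training with a different crop size would require changing the first `nn.Linear` dimension, or placing something like `nn.AdaptiveAvgPool2d` in front of the classifier.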
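## 4. Perceptual Loss and Model Optimization

The core innovation of SRGAN is its perceptual loss function, which combines a content loss with an adversarial loss:

```python
import torch.nn as nn
import torchvision


class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        # torchvision >= 0.13 replaces `pretrained=True` with the `weights` argument.
        vgg = torchvision.models.vgg19(pretrained=True)
        # Frozen VGG19 feature extractor, kept in eval mode.
        loss_network = nn.Sequential(*list(vgg.features)[:35]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()

    def forward(self, fake, real):
        # Compare the images in VGG feature space rather than pixel space.
        fake_features = self.loss_network(fake)
        real_features = self.loss_network(real)
        return self.mse_loss(fake_features, real_features)
```

A few key points to watch during training:

- Learning rate schedule: use a higher learning rate (1e-4) in the initial phase, then lower it to 1e-5 later on.
- Pretraining strategy: pretrain SRResNet with the MSE loss first, then fine-tune it as SRGAN.
- Batch normalization: use BN layers heavily in the generator and sparingly in the discriminator.
- Gradient clipping: prevent exploding gradients by setting the maximum norm to 0.1.

```python
# Learning rate schedule example
# step_size is measured in calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100000, gamma=0.1)
```

## 5. Evaluation and Visualization

When evaluating super-resolution results, pay attention to visual quality in addition to the traditional PSNR and SSIM metrics. Here is sample code for comparing the generated results:

```python
import matplotlib.pyplot as plt
import torch


def compare_results(model, test_loader, device):
    model.eval()
    with torch.no_grad():
        lr, hr = next(iter(test_loader))
        lr, hr = lr.to(device), hr.to(device)
        sr = model(lr).clamp(0, 1)  # keep outputs in the displayable [0, 1] range

    # Compute the metric (PSNR over the batch, assuming pixel values in [0, 1])
    psnr = 10 * torch.log10(1 / torch.mean((sr - hr) ** 2))

    # Visualize
    plt.figure(figsize=(15, 5))
    plt.subplot(1, 3, 1); plt.imshow(lr[0].cpu().permute(1, 2, 0))
    plt.title("Low Resolution"); plt.axis("off")
    plt.subplot(1, 3, 2); plt.imshow(sr[0].cpu().permute(1, 2, 0))
    plt.title(f"Super Resolved (PSNR: {psnr.item():.2f}dB)"); plt.axis("off")
    plt.subplot(1, 3, 3); plt.imshow(hr[0].cpu().permute(1, 2, 0))
    plt.title("Original High Resolution"); plt.axis("off")
    plt.show()
```

In real projects, I have found that the following techniques noticeably improve model performance:

- Use a self-attention mechanism to strengthen the generator's ability to capture global structure.
- Adopt a progressive training strategy: train a low-resolution version first, then raise the resolution step by step.
- Introduce spectral normalization to stabilize discriminator training.
- Use multi-scale discriminators to capture image features at different levels.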
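The PSNR expression in `compare_results` hard-codes a peak value of 1, which is only correct when pixel values lie in [0, 1]. As a small worked example of that dependency, the sketch below computes PSNR in plain Python with an explicit peak value. The helper names `mse` and `psnr_from_mse` and the sample pixel values are illustrative only.

```python
import math


def mse(a, b):
    """Mean squared error between two equal-length pixel sequences."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def psnr_from_mse(mse_value, max_value=1.0):
    """PSNR in dB for a given mean squared error and peak pixel value."""
    if mse_value == 0:
        return math.inf  # identical images
    return 10 * math.log10(max_value ** 2 / mse_value)


# The same images expressed on a [0, 1] scale and on a [0, 255] scale.
hr = [0.0, 0.5, 1.0, 0.25]
sr = [0.1, 0.4, 0.9, 0.35]
psnr_unit = psnr_from_mse(mse(hr, sr), max_value=1.0)
psnr_8bit = psnr_from_mse(
    mse([v * 255 for v in hr], [v * 255 for v in sr]), max_value=255.0
)
print(f"{psnr_unit:.2f} dB, {psnr_8bit:.2f} dB")
```

Every pixel here is off by 0.1, so the MSE is about 0.01 and both calls give roughly 20 dB. The two results agree only because the peak value matches the data range in each call. If the tensors are normalized to another range, such as [-1, 1], they need to be mapped back to [0, 1] (or the peak value adjusted) before the formula in `compare_results` gives a meaningful number.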