Pytorch图像去噪实战(十一):Diffusion扩散模型去噪入门,从噪声预测理解生成式图像恢复
Pytorch图像去噪实战十一Diffusion扩散模型去噪入门从噪声预测理解生成式图像恢复一、问题场景传统去噪模型能用但上限开始明显前面我们已经做了 DnCNN、UNet、ResUNet、Attention UNet、FFDNet、CBDNet、Noise2Noise、Noise2Void、SwinIR、Restormer。这些模型有一个共同特点大多数都是直接学习 noisy - clean或者 noisy - noise。在普通图像去噪任务里这样已经够用。但当我处理一些复杂图像时问题开始变明显高噪声图像细节恢复差老照片去噪后纹理不自然真实噪声图像容易残留伪影强去噪后图像发糊模型对未知噪声泛化不足这时就会接触到一个更强的方向Diffusion Model 扩散模型。扩散模型不是简单做一次映射而是学习一个逐步去噪过程。二、Diffusion去噪和普通去噪有什么区别普通去噪模型noisy_image - clean_imageDiffusion模型clean_image - 不断加噪 - pure noise pure noise - 逐步去噪 - clean_image在训练阶段它学习的是给定某一步的带噪图像预测其中的噪声。也就是x_t - noise这和 DnCNN 的“预测噪声”思想有相似之处但 Diffusion 更进一步把噪声过程拆成了很多步。三、核心思想前向加噪与反向去噪1. 前向过程从干净图像 x0 开始逐步加入噪声x0 - x1 - x2 - ... - xT最后 xT 接近纯噪声。2. 反向过程模型学习从 xT 一步步恢复xT - xT-1 - ... - x0训练目标通常是预测噪声 epsilon。四、工程目录结构diffusion_denoise/ ├── data/ │ └── train/ ├── models/ │ └── simple_unet.py ├── diffusion.py ├── dataset.py ├── train.py ├── sample.py └── utils.py五、数据集准备这里先做灰度图像去噪方便理解扩散模型流程。importosfromPILimportImagefromtorch.utils.dataimportDatasetimporttorchvision.transformsastransformsclassImageDataset(Dataset):def__init__(self,root_dir,image_size64):self.paths[os.path.join(root_dir,name)fornameinos.listdir(root_dir)ifname.lower().endswith((.jpg,.png,.jpeg))]self.transformtransforms.Compose([transforms.Resize((image_size,image_size)),transforms.ToTensor()])def__len__(self):returnlen(self.paths)def__getitem__(self,idx):imgImage.open(self.paths[idx]).convert(L)returnself.transform(img)六、扩散过程实现diffusion.pyimporttorchclassGaussianDiffusion:def__init__(self,timesteps1000,beta_start1e-4,beta_end0.02,devicecuda):self.timestepstimesteps self.devicedevice self.betastorch.linspace(beta_start,beta_end,timesteps).to(device)self.alphas1.0-self.betas self.alpha_barstorch.cumprod(self.alphas,dim0)defadd_noise(self,x0,t):noisetorch.randn_like(x0)alpha_barself.alpha_bars[t].view(-1,1,1,1)noisytorch.sqrt(alpha_bar)*x0torch.sqrt(1.0-alpha_bar)*noisereturnnoisy,noise七、时间步编码Diffusion模型必须知道当前是第几步噪声。importtorchimporttorch.nnasnnimportmathclassTimeEmbedding(nn.Module):def__init__(self,dim):super().__init__()self.dimdim self.mlpnn.Sequential(nn.Linear(dim,dim*4),nn.SiLU(),nn.Linear(dim*4,dim))defforward(self,t):half_dimself.dim//2embmath.log(10000)/(half_dim-1)embtorch.exp(torch.arange(half_dim,devicet.device)*-emb)embt[:,None]*emb[None,:]embtorch.cat([torch.sin(emb),torch.cos(emb)],dim1)returnself.mlp(emb)八、简化版UNet噪声预测网络models/simple_unet.pyimporttorchimporttorch.nnasnnclassTimeEmbedding(nn.Module):def__init__(self,dim):super().__init__()self.netnn.Sequential(nn.Linear(1,dim),nn.SiLU(),nn.Linear(dim,dim))defforward(self,t):tt.float().view(-1,1)/1000.0returnself.net(t)classSimpleDenoiseUNet(nn.Module):def__init__(self,channels1,base64,time_dim128):super().__init__()self.time_mlpTimeEmbedding(time_dim)self.conv1nn.Conv2d(channels,base,3,padding1)self.conv2nn.Conv2d(base,base,3,padding1)self.conv3nn.Conv2d(base,channels,3,padding1)self.time_projnn.Linear(time_dim,base)self.actnn.SiLU()defforward(self,x,t):time_embself.time_mlp(t)time_embself.time_proj(time_emb).view(x.size(0),-1,1,1)hself.act(self.conv1(x))hhtime_emb hself.act(self.conv2(h))returnself.conv3(h)九、训练代码train.pyimporttorchfromtorch.utils.dataimportDataLoaderfromdatasetimportImageDatasetfromdiffusionimportGaussianDiffusionfrommodels.simple_unetimportSimpleDenoiseUNetdeftrain():devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)datasetImageDataset(data/train,image_size64)loaderDataLoader(dataset,batch_size32,shuffleTrue,num_workers4)modelSimpleDenoiseUNet().to(device)diffusionGaussianDiffusion(timesteps1000,devicedevice)optimizertorch.optim.AdamW(model.parameters(),lr2e-4)criteriontorch.nn.MSELoss()forepochinrange(1,101):model.train()total_loss0forx0inloader:x0x0.to(device)ttorch.randint(0,diffusion.timesteps,(x0.size(0),),devicedevice)xt,noisediffusion.add_noise(x0,t)pred_noisemodel(xt,t)losscriterion(pred_noise,noise)optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)optimizer.step()total_lossloss.item()print(fEpoch{epoch}, Loss:{total_loss/len(loader):.6f})ifepoch%100:torch.save(model.state_dict(),fdiffusion_epoch_{epoch}.pth)if__name____main__:train()十、为什么Diffusion训练预测noise而不是预测clean这是很多人第一次学扩散模型时最容易疑惑的地方。如果直接预测 cleanmodel(x_t, t) - x0模型在高噪声阶段很难恢复完整图像。而预测 noisemodel(x_t, t) - epsilon训练目标更稳定也更符合扩散模型的数学推导。工程上看预测噪声还有一个优点loss更稳定模型更容易收敛。十一、采样过程简化实现下面写一个简化版采样逻辑帮助理解反向去噪。importtorchimporttorchvision.utilsasvutilsfromdiffusionimportGaussianDiffusionfrommodels.simple_unetimportSimpleDenoiseUNettorch.no_grad()defsample():devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)modelSimpleDenoiseUNet().to(device)model.load_state_dict(torch.load(diffusion_epoch_100.pth,map_locationdevice))model.eval()diffusionGaussianDiffusion(timesteps1000,devicedevice)xtorch.randn(16,1,64,64).to(device)foriinreversed(range(diffusion.timesteps)):ttorch.full((x.size(0),),i,devicedevice,dtypetorch.long)pred_noisemodel(x,t)betadiffusion.betas[i]alphadiffusion.alphas[i]alpha_bardiffusion.alpha_bars[i]x(1/torch.sqrt(alpha))*(x-(beta/torch.sqrt(1-alpha_bar))*pred_noise)ifi0:noisetorch.randn_like(x)xxtorch.sqrt(beta)*noise xtorch.clamp(x,0.0,1.0)vutils.save_image(x.cpu(),diffusion_samples.png,nrow4)if__name____main__:sample()十二、踩坑记录坑1时间步没有输入模型Diffusion模型必须知道 t。如果只输入 x_t不输入 t模型不知道当前噪声强度训练会非常差。坑2学习率过大导致loss震荡扩散模型训练比普通UNet更敏感。建议lr2e-4如果不稳定降到lr1e-4坑3图像尺寸一开始不要太大Diffusion训练成本高。建议从64x64开始流程跑通后再放大。十三、适合收藏总结Diffusion去噪训练流程读取干净图像随机采样时间步 t根据 t 给图像加噪模型预测噪声用真实噪声监督训练推理时逐步反向去噪避坑清单必须输入时间步训练目标建议预测noise初期图像尺寸别太大学习率不要过高采样速度较慢是正常现象十四、优化建议可以继续升级更完整UNet结构加Attention模块使用DDIM加速采样支持条件去噪使用真实噪声数据微调结尾总结Diffusion模型的核心不是“一个更大的UNet”而是一套新的去噪建模方式把图像恢复拆成多个连续的小步骤让模型逐步从噪声中恢复结构。如果你已经理解 DnCNN 的残差噪声预测那么学习 Diffusion 会更容易因为它本质上也是在学噪声只是把这个过程做得更细。下一篇预告Pytorch图像去噪实战十二DDPM图像去噪完整训练流程构建可复现扩散模型工程