# PyTorch Image Denoising in Practice (12): A Complete DDPM Training Pipeline for a Reproducible Diffusion Model Project

## 1. Problem Scenario: The Diffusion Model Runs, but the Engineering Code Gets Messy Fast

In the previous post we worked through the core logic of diffusion with a minimal implementation. But as soon as that code goes into a real project, problems show up quickly:

- the beta schedule is hard-coded in the training script and is awkward to change later;
- sampling logic and training logic are mixed together;
- model saving and resuming are not standardized;
- training parameters are not reproducible;
- there is no clean path to later extensions such as DDIM, conditional denoising, or color images.

Many people can write a working demo while learning diffusion models, but struggle to turn it into a project. This post focuses on one thing: organizing the DDPM image denoising pipeline into a reproducible, extensible project structure.

## 2. The Core DDPM Training Objective

The DDPM training objective is still noise prediction:

ε_θ(x_t, t) ≈ ε

During training we:

1. take a clean image x0 from the dataset;
2. randomly sample a timestep t;
3. add noise to x0 according to t to obtain x_t, using the closed form x_t = √(ᾱ_t) · x0 + √(1 − ᾱ_t) · ε;
4. feed x_t and t into the model;
5. let the model predict the noise;
6. train with MSELoss.

## 3. Recommended Project Structure

```
ddpm_denoise/
├── configs/
│   └── train_config.py
├── data/
│   └── train/
├── models/
│   └── unet.py
├── diffusion/
│   └── ddpm.py
├── dataset.py
├── train.py
├── sample.py
└── utils.py
```

Compared with a simple demo, this layout has several advantages:

- the model is independent;
- the diffusion process is independent;
- the configuration is independent;
- training and sampling are separated;
- later extensions are straightforward.

## 4. Configuration File: configs/train_config.py

```python
class TrainConfig:
    image_size = 64
    channels = 1          # 1 = grayscale, 3 = RGB
    batch_size = 32
    num_workers = 4
    epochs = 100
    lr = 2e-4
    timesteps = 1000
    beta_start = 1e-4
    beta_end = 0.02
    save_interval = 10
    data_dir = "data/train"
    save_dir = "checkpoints"
```

The biggest benefit of pulling the configuration out on its own is that experiment parameters no longer end up scattered through the code. This matters a lot when you need to reproduce an experiment later.

## 5. Dataset Code: dataset.py

```python
import os

from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class ImageFolderDataset(Dataset):
    def __init__(self, root_dir, image_size=64, channels=1):
        # Collect all image files directly under root_dir.
        self.paths = [
            os.path.join(root_dir, name)
            for name in os.listdir(root_dir)
            if name.lower().endswith((".jpg", ".jpeg", ".png"))
        ]
        if channels == 1:
            self.mode = "L"
        else:
            self.mode = "RGB"
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        img = Image.open(self.paths[index]).convert(self.mode)
        return self.transform(img)
```

## 6. The DDPM Diffusion Class: diffusion/ddpm.py

```python
import torch


class DDPM:
    def __init__(self, timesteps=1000, beta_start=1e-4, beta_end=0.02, device="cuda"):
        self.timesteps = timesteps
        self.device = device
        # Linear beta schedule and the derived quantities used by q_sample / p_sample.
        self.betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1.0 - self.alpha_bars)

    def q_sample(self, x0, t, noise=None):
        # Forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_alpha_bar = self.sqrt_alpha_bars[t].view(-1, 1, 1, 1)
        sqrt_one_minus = self.sqrt_one_minus_alpha_bars[t].view(-1, 1, 1, 1)
        xt = sqrt_alpha_bar * x0 + sqrt_one_minus * noise
        return xt, noise

    @torch.no_grad()
    def p_sample(self, model, x, t):
        # Reverse process: one denoising step from x_t to x_{t-1}.
        beta = self.betas[t]
        alpha = self.alphas[t]
        alpha_bar = self.alpha_bars[t]
        batch_t = torch.full((x.size(0),), t, device=x.device, dtype=torch.long)
        pred_noise = model(x, batch_t)
        mean = (1 / torch.sqrt(alpha)) * (x - (beta / torch.sqrt(1.0 - alpha_bar)) * pred_noise)
        if t > 0:
            noise = torch.randn_like(x)
            return mean + torch.sqrt(beta) * noise
        return mean
```

## 7. The UNet Noise Predictor: models/unet.py

```python
import torch
import torch.nn as nn


class TimeEmbedding(nn.Module):
    """Maps the integer timestep t to a learned embedding vector."""

    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, t):
        t = t.float().view(-1, 1) / 1000.0
        return self.net(t)


class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, time_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.time_proj = nn.Linear(time_dim, out_channels)
        self.shortcut = nn.Identity()
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, 1)
        self.act = nn.SiLU()

    def forward(self, x, t_emb):
        h = self.act(self.conv1(x))
        # Inject the time embedding as a per-channel bias.
        time = self.time_proj(t_emb).view(x.size(0), -1, 1, 1)
        h = h + time
        h = self.conv2(self.act(h))
        return h + self.shortcut(x)


class DDPMUNet(nn.Module):
    def __init__(self, channels=1, base=64, time_dim=128):
        super().__init__()
        self.time_mlp = TimeEmbedding(time_dim)
        self.down1 = ResidualBlock(channels, base, time_dim)
        self.down2 = ResidualBlock(base, base * 2, time_dim)
        self.pool = nn.MaxPool2d(2)
        self.mid = ResidualBlock(base * 2, base * 2, time_dim)
        self.up = nn.ConvTranspose2d(base * 2, base, 2, 2)
        self.up_block = ResidualBlock(base * 2, base, time_dim)
        self.out = nn.Conv2d(base, channels, 3, padding=1)

    def forward(self, x, t):
        t_emb = self.time_mlp(t)
        d1 = self.down1(x, t_emb)
        d2 = self.down2(self.pool(d1), t_emb)
        mid = self.mid(d2, t_emb)
        u = self.up(mid)
        u = torch.cat([u, d1], dim=1)   # skip connection from the encoder
        u = self.up_block(u, t_emb)
        return self.out(u)
```
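Before wiring everything into train.py, it can help to confirm that the pieces above fit together. The following is a throwaway sketch (not part of the project layout, and the file name, batch size, and image size are made up for illustration) that assumes the DDPM class and DDPMUNet above are saved under diffusion/ddpm.py and models/unet.py and are importable from the project root. It runs the forward process and the UNet on a random batch on CPU and checks the shapes:

```python
# shape_check.py -- hypothetical sanity check, run from the project root
import torch

from diffusion.ddpm import DDPM
from models.unet import DDPMUNet

device = torch.device("cpu")  # CPU is enough for a shape check

diffusion = DDPM(timesteps=1000, beta_start=1e-4, beta_end=0.02, device=device)
model = DDPMUNet(channels=1)

x0 = torch.rand(8, 1, 64, 64)                    # fake "clean" batch in [0, 1]
t = torch.randint(0, diffusion.timesteps, (8,))  # one random timestep per image

xt, noise = diffusion.q_sample(x0, t)            # forward process: add noise at step t
pred = model(xt, t)                              # UNet predicts the injected noise

# All three tensors should have shape [8, 1, 64, 64].
print(xt.shape, noise.shape, pred.shape)
```

If the shapes do not line up here, the training loop in the next section will fail in the same way, so this is the cheapest place to catch it.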
## 8. Training Script: train.py

```python
import os

import torch
from torch.utils.data import DataLoader

from configs.train_config import TrainConfig
from dataset import ImageFolderDataset
from models.unet import DDPMUNet
from diffusion.ddpm import DDPM


def train():
    cfg = TrainConfig()
    os.makedirs(cfg.save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = ImageFolderDataset(
        root_dir=cfg.data_dir,
        image_size=cfg.image_size,
        channels=cfg.channels,
    )
    loader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=cfg.num_workers,
    )

    model = DDPMUNet(channels=cfg.channels).to(device)
    diffusion = DDPM(
        timesteps=cfg.timesteps,
        beta_start=cfg.beta_start,
        beta_end=cfg.beta_end,
        device=device,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr)
    criterion = torch.nn.MSELoss()

    for epoch in range(1, cfg.epochs + 1):
        model.train()
        total_loss = 0.0
        for x0 in loader:
            x0 = x0.to(device)
            # Random timestep per image, add noise, then predict that noise.
            t = torch.randint(0, cfg.timesteps, (x0.size(0),), device=device)
            xt, noise = diffusion.q_sample(x0, t)
            pred_noise = model(xt, t)
            loss = criterion(pred_noise, noise)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch [{epoch}/{cfg.epochs}], Loss: {avg_loss:.6f}")

        if epoch % cfg.save_interval == 0:
            path = os.path.join(cfg.save_dir, f"ddpm_epoch_{epoch}.pth")
            torch.save(model.state_dict(), path)


if __name__ == "__main__":
    train()
```

## 9. Sampling Script: sample.py

```python
import torch
import torchvision.utils as vutils

from configs.train_config import TrainConfig
from models.unet import DDPMUNet
from diffusion.ddpm import DDPM


@torch.no_grad()
def sample():
    cfg = TrainConfig()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = DDPMUNet(channels=cfg.channels).to(device)
    model.load_state_dict(torch.load("checkpoints/ddpm_epoch_100.pth", map_location=device))
    model.eval()

    diffusion = DDPM(
        timesteps=cfg.timesteps,
        beta_start=cfg.beta_start,
        beta_end=cfg.beta_end,
        device=device,
    )

    # Start from pure Gaussian noise and denoise step by step from T-1 down to 0.
    x = torch.randn(16, cfg.channels, cfg.image_size, cfg.image_size).to(device)
    for t in reversed(range(cfg.timesteps)):
        x = diffusion.p_sample(model, x, t)

    x = torch.clamp(x, 0.0, 1.0)
    vutils.save_image(x.cpu(), "ddpm_sample.png", nrow=4)


if __name__ == "__main__":
    sample()
```

## 10. Why Split the Project Up

A lot of diffusion code starts out in a single file. It runs, but it is hard to maintain. Splitting the project brings concrete benefits:

- the DDPM class is reusable;
- the UNet can be swapped out;
- the config makes parameter tuning easy;
- training and sampling do not interfere with each other;
- DDIM can be added later without touching the rest.

This is also the key step from "a demo that runs" to "a project you can build on".

## 11. Pitfalls

**Pitfall 1: the samples are pure noise.** Common causes: the model has not been trained long enough; the timestep is passed in incorrectly; the beta schedule is too aggressive; the sampling formula is wrong. It helps to first verify that the model can overfit a small dataset.

**Pitfall 2: the loss goes down but the samples look bad.** A falling DDPM loss does not immediately mean good images; sample quality usually needs more training epochs.

**Pitfall 3: sampling is too slow.** Slow DDPM sampling is normal, because the reverse process has to step through all T timesteps. Later we can use DDIM or reduce the number of timesteps.

## 12. Summary Worth Bookmarking

The DDPM project workflow:

- the config file manages parameters;
- the Dataset loads images;
- the DDPM class handles adding noise and sampling;
- the UNet predicts noise;
- train.py trains the model;
- sample.py generates results.

Pitfall checklist:

- do not put all the code in one file;
- the timestep must be passed in correctly;
- keep the beta schedule stable;
- poor samples are not necessarily a loss problem;
- get things running on small images first.

## 13. Further Improvements

Things to build on next:

- DDIM accelerated sampling;
- conditional diffusion for denoising;
- color image support;
- EMA model weights;
- mixed-precision training.

## Closing Summary

DDPM is not a single model but a complete framework for diffusion training and sampling. If you only want a demo, it is easy to get something running; if you plan to run a series of experiments over time, you need a clean project structure from the start. The point of this post is not maximum quality, but setting DDPM up as a stable, reproducible project skeleton.

**Next post:** PyTorch Image Denoising in Practice (13): DDIM Accelerated Sampling, Cutting Diffusion Denoising from 1000 Steps to 50.