058、BaseTrainer训练循环源码拆解:Epoch 迭代、前向传播、Loss 计算
058、BaseTrainer训练循环源码拆解Epoch 迭代、前向传播、Loss 计算从一次诡异的梯度爆炸说起去年有个项目训练YOLOv8检测无人机跑了200个epochmAP一直卡在0.3上不去。我翻来覆去调学习率、改数据增强都没用。最后实在没辙把训练循环里每个环节的tensor shape和数值范围都打印出来才发现问题出在loss计算时的一个维度广播错误——某个分支的target tensor在batch维度上多复制了一次导致loss被放大了4倍梯度直接炸穿。那次之后我养成了一个习惯训练循环里的每一行代码都要知道它在干什么以及它可能怎么坑你。今天我们就来拆解Ultralytics中BaseTrainer的train方法把epoch迭代、前向传播、loss计算这三个核心环节扒个底朝天。训练循环的骨架call与 train 方法打开ultralytics/engine/trainer.py找到BaseTrainer类。它的入口是__call__方法但真正干活的是train方法。别被那些装饰器和回调函数吓到核心逻辑其实就三层循环deftrain(self):# 外层epoch循环forepochinrange(self.epochs):self.model.train()# 中层batch循环fori,batchinenumerate(self.train_loader):# 内层单步训练lossself.loss_fn(self.model(batch[img]),batch)self.scaler.scale(loss).backward()self.optimizer_step()这里有个容易忽略的细节self.model.train()放在epoch循环开始处而不是batch循环里。为什么因为有些模型比如带BatchNorm的在训练和评估模式下行为不同但每个epoch开始时确保是训练模式就够了。别在batch循环里反复切换模式我见过有人这么写每次forward前都调一次model.train()虽然不影响结果但多了几百次函数调用开销。Epoch迭代那些你以为是废话的细节epoch循环看起来简单但YOLO的训练器在这里埋了不少坑。看这段代码forepochinrange(self.start_epoch,self.epochs):# 这里踩过坑start_epoch可能是恢复训练时的值# 如果直接range(self.epochs)断点续训就废了self.epochepoch# 学习率调度每个epoch开始前更新self.scheduler.step()# 数据加载器的shuffle每个epoch重新打乱# 但注意如果用了分布式采样器这里不能手动shuffleifhasattr(self.train_loader.sampler,set_epoch):self.train_loader.sampler.set_epoch(epoch)start_epoch这个参数在断点续训时从checkpoint里读出来。如果你写死了range(self.epochs)恢复训练时会从epoch 0重新开始前面白练了。正确做法是存一个start_epoch到checkpoint里Ultralytics就是这么干的。学习率调度器scheduler.step()放在epoch开始前还是batch后这取决于调度器类型。YOLO默认用余弦退火每个epoch更新一次所以放在epoch循环开头。如果你用ReduceLROnPlateau那得在验证后根据指标更新。别把调度器位置搞反了否则学习率曲线会和你预期完全不一样。前向传播模型到底输出了什么进入batch循环后第一件事是前向传播。但YOLO的前向传播不是简单的model(images)它做了很多包装# 实际代码在_loss方法里但核心逻辑是这样的imagesbatch[img].to(self.device,non_blockingTrue)# non_blockingTrue异步传输别等继续干别的# 混合精度上下文withtorch.cuda.amp.autocast(self.amp):predsself.model(images)# 这里返回的是list of tensors# 注意preds不是最终输出是中间特征# 对于YOLOv8preds[0]是检测头输出shape: [batch, 84, 8400]# 84 4(bbox) 80(class) 1(obj) 不对YOLOv8没有obj分支# 实际上是 4(bbox) 80(class) 84这里有个容易误解的地方self.model(images)返回的是什么YOLOv8的Model类在forward方法里做了判断——如果是训练模式返回的是self.model(images)的原始输出一个包含多个尺度的list如果是推理模式会经过后处理返回检测结果。别在训练时用model(images, augmentTrue)我见过有人把推理时的数据增强参数传进去结果模型返回了增强后的预测loss计算直接崩了。训练时前向传播就是纯前向不要加任何后处理。Loss计算三个分支的恩怨情仇Loss计算是训练循环里最复杂也最容易出错的部分。YOLOv8的loss函数在ultralytics/utils/loss.py里叫v8DetectionLoss。它的核心逻辑是def__call__(self,preds,batch):# preds: list of 3 tensors, 对应三个检测尺度# batch: dict, 包含cls, bbox等losstorch.zeros(3,deviceself.device)# [box_loss, cls_loss, dfl_loss]fori,predinenumerate(preds):# 对每个尺度分别计算loss# 1. 正样本匹配找到哪些anchor负责预测哪些gttarget_bbox,target_cls,fg_maskself.assigner(pred,batch)# 2. 计算box loss (CIoU)loss[0]self.bce(pred_bbox,target_bbox)# 这里简化了实际是CIoU# 3. 计算class loss (BCE)loss[1]self.bce(pred_cls,target_cls,weightfg_mask)# 4. 计算DFL loss (Distribution Focal Loss)loss[2]self.dfl(pred_dfl,target_bbox)returnloss.sum()*self.batch_size# 注意这里乘了batch_size这里踩过坑loss.sum() * self.batch_size这行。为什么乘batch_size因为YOLO的loss是每个样本的平均但梯度累积时需要对batch求和。如果你用分布式训练这个操作会导致loss被放大world_size倍。Ultralytics在分布式场景下会除以world_size来抵消但单卡训练时没问题。正样本匹配self.assigner是loss计算的核心。YOLOv8用的TaskAlignedAssigner它会根据预测和gt的alignment程度动态分配正样本。别试图自己写正样本匹配我试过边界情况多到怀疑人生直接用官方的就好。反向传播与优化器步进Loss计算完后的反向传播看起来简单但混合精度训练时有很多细节# 梯度清零self.optimizer.zero_grad()# 反向传播混合精度self.scaler.scale(loss).backward()# 梯度裁剪可选但YOLO默认不裁ifself.args.clip_grad0:torch.nn.utils.clip_grad_norm_(self.model.parameters(),self.args.clip_grad)# 优化器步进self.scaler.step(self.optimizer)self.scaler.update()别这样写loss.backward()直接调用而不通过scaler。混合精度训练时loss会被scaler放大直接backward会导致梯度数值不稳定。一定要用scaler.scale(loss).backward()。梯度裁剪在YOLO里默认是关闭的因为YOLO的loss设计已经比较稳定。但如果你改动了loss函数或者加了新的分支建议加上梯度裁剪max_norm10.0是个安全的起点。日志记录与验证别让训练成了黑盒每个batch结束后训练器会记录loss和指标。但YOLO的日志记录有个特点——它记录的是滑动平均而不是当前batch的原始值# 在train.py的_loss方法里self.lossloss.item()# 当前batch的lossself.tloss(self.tloss*iself.loss)/(i1)ifi0elseself.loss# tloss是滑动平均用于打印和tensorboard这个tloss在打印时看起来更平滑但如果你要调试loss是否异常一定要看原始self.loss。滑动平均会掩盖短时间的loss尖峰让你以为训练很稳定实际上可能已经出问题了。验证环节在每个epoch结束后执行但YOLO默认是每10个epoch验证一次通过self.args.val控制。验证时模型切换到eval模式关闭dropout和batch norm的统计更新self.model.eval()withtorch.no_grad():forbatchinself.val_loader:predsself.model(batch[img])# 后处理NMS等# 计算mAP注意验证时不要用torch.cuda.amp.autocast因为验证不需要混合精度而且amp在某些后处理操作如NMS上可能不兼容。个人经验训练循环调试三板斧打印每个环节的tensor shape和数值范围。在loss计算前后加print(preds[0].shape, preds[0].min(), preds[0].max())在反向传播前加print(loss.item())。如果loss突然变成nan或inf看是哪个分支导致的。用小的overfit测试。取一个batch的数据训练100个step看loss能不能降到接近0。如果不能说明模型或loss函数有问题。我一般用--batch 1 --epochs 1快速验证。梯度检查。在backward()后加torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)然后打印梯度范数。如果梯度范数突然变大说明某个分支的梯度爆炸了需要检查loss计算或学习率。最后说一句训练循环是YOLO里最枯燥但最重要的部分。别嫌麻烦把每一行代码都理解透遇到问题才能快速定位。下次遇到loss不收敛别急着调学习率先看看训练循环里有没有bug。