简述:小数据集照片分类的模型训练
简述小数据集照片分类的模型训练如20 个分类 × 每类 500 张 总共 1 万张图属于小数据集多分类。一、先确定你该用什么模型直接选MobileNetV2 或 ResNet34理由适合小数据集1 万张以内训练快、不容易过拟合分类效果稳20 类完全够用你之前已经在用 ResNet无缝衔接二、训练前准备文件夹结构必须这样放plaintext dataset/ train/ 类别1/ 001.jpg 002.jpg... 类别2/... 类别20/ val/ 类别1/ 类别2/... 类别20/数据划分规则非常重要每类 500 张400 张 → train 训练100 张 → val 验证三、最简单训练方案推荐你直接用核心技术迁移学习小数据集必用用ImageNet 预训练权重只微调最后几层训练速度快、准确率高训练超参数直接照抄输入尺寸224×224批次大小16 或 32学习率1e-4优化器Adam损失函数CrossEntropyLoss多分类标准训练轮数30~50 轮四、完整训练步骤从 0 到 1第 1 步安装环境1 行命令pipinstalltorch torchvision pillow tqdm第 2 步训练代码复制直接运行importtorchimporttorch.nnasnnfromtorchvisionimportdatasets,models,transformsfromtorch.utils.dataimportDataLoaderimportos# 1. 配置 devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)num_classes20# 你的分类数量batch_size16epochs30lr1e-4# 2. 数据增强小数据集必须加data_transforms{train:transforms.Compose([transforms.Resize((224,224)),transforms.RandomHorizontalFlip(),# 随机翻转transforms.RandomRotation(15),# 随机旋转transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]),val:transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])}# 3. 加载数据 data_dirdataset# 你的数据集路径image_datasets{x:datasets.ImageFolder(os.path.join(data_dir,x),data_transforms[x])forxin[train,val]}dataloaders{x:DataLoader(image_datasets[x],batch_sizebatch_size,shuffleTrue,num_workers0)forxin[train,val]}# 4. 模型ResNet34 迁移学习 modelmodels.resnet34(pretrainedTrue)# 修改最后一层为20分类num_ftrsmodel.fc.in_features model.fcnn.Linear(num_ftrs,num_classes)modelmodel.to(device)# 5. 损失函数 优化器 criterionnn.CrossEntropyLoss()optimizertorch.optim.Adam(model.parameters(),lrlr)# 6. 开始训练 forepochinrange(epochs):model.train()running_loss0.0correct0total0forinputs,labelsindataloaders[train]:inputs,labelsinputs.to(device),labels.to(device)optimizer.zero_grad()outputsmodel(inputs)losscriterion(outputs,labels)loss.backward()optimizer.step()running_lossloss.item()_,predictedtorch.max(outputs,1)totallabels.size(0)correct(predictedlabels).sum().item()train_acc100*correct/totalprint(fEpoch{epoch1}/{epochs}| Loss:{running_loss:.4f}| Acc:{train_acc:.2f}%)# 验证model.eval()val_correct0val_total0withtorch.no_grad():forinputs,labelsindataloaders[val]:inputs,labelsinputs.to(device),labels.to(device)outputsmodel(inputs)_,predictedtorch.max(outputs,1)val_totallabels.size(0)val_correct(predictedlabels).sum().item()val_acc100*val_correct/val_totalprint(f验证集准确率:{val_acc:.2f}%\n)# 保存模型torch.save(model.state_dict(),photo_classify_20classes.pth)print(训练完成模型已保存)第 3 步运行训练把数据集放好直接运行代码30~60 分钟训练完成。五、预测代码训练完直接用importtorchfromtorchvisionimportmodels,transformsfromPILimportImage# 加载模型devicetorch.device(cudaiftorch.cuda.is_available()elsecpu)modelmodels.resnet34()model.fctorch.nn.Linear(model.fc.in_features,20)model.load_state_dict(torch.load(photo_classify_20classes.pth))modelmodel.to(device)model.eval()# 预处理transformtransforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])# 预测一张图defpredict_image(img_path):imgImage.open(img_path).convert(RGB)imgtransform(img).unsqueeze(0).to(device)withtorch.no_grad():outputmodel(img)_,predtorch.max(output,1)returnpred.item()# 使用print(predict_image(test.jpg))六、这个数据规模预期效果训练集准确率95%~99%验证集准确率85%~92%20 分类完全够用每类 500 张属于刚刚好的小样本七、总结最关键的 3 点任务类型图像分类模型ResNet34不是分割模型数据每类分 400 训练 100 验证文件夹分类放好训练用迁移学习30 轮直接运行我给的代码