基于萤火虫算法优化算法优化XGBoost(FA-XGBoost)的数据分类预测 FA-XGBoost数据分类 采用交叉验证抑制过拟合问题 优化参数为迭代次数、最大深度和学习率 matlab代码 注暂无Matlab版本要求 -- 推荐 2018B 版本及以上 注采用 XGBoost 工具箱仅支持 Windows 64位系统最近在折腾分类算法优化的时候发现萤火虫算法和XGBoost的组合有点意思。特别是当数据集存在噪声或者样本分布不均衡时这个FA-XGBoost组合拳打出来的效果竟然比网格搜索调参还稳。今天咱们就手把手实现一个基于Matlab的版本顺便聊聊怎么用交叉验证防止模型走火入魔。基于萤火虫算法优化算法优化XGBoost(FA-XGBoost)的数据分类预测 FA-XGBoost数据分类 采用交叉验证抑制过拟合问题 优化参数为迭代次数、最大深度和学习率 matlab代码 注暂无Matlab版本要求 -- 推荐 2018B 版本及以上 注采用 XGBoost 工具箱仅支持 Windows 64位系统先说说萤火虫算法的核心——参数编码部分。这里我们把XGBoost的三个命门参数打包成萤火虫的位置向量% 萤火虫位置编码 [迭代次数, 最大深度, 学习率] firefly.position [... randi([50,200]), % n_estimators范围50-200 randi([3,10]), % max_depth范围3-10 rand()*0.3 0.1]; % learning_rate范围0.1-0.4这个范围设置是之前实验踩坑得出的经验值学习率低于0.1容易收敛过慢超过0.4又会导致震荡。接下来是亮度计算函数这里直接拿交叉验证的准确率当评价指标function accuracy fitness_func(position, X, y) params struct(max_depth, position(2), learning_rate, position(3),... n_estimators, position(1), objective, binary:logistic); % 5折交叉验证防过拟合 cv cvpartition(y, KFold, 5); accuracies zeros(cv.NumTestSets, 1); for i 1:cv.NumTestSets trainIdx cv.training(i); testIdx cv.test(i); model xgb_train(X(trainIdx,:), y(trainIdx), params); [~, accuracies(i)] xgb_predict(model, X(testIdx,:), y(testIdx)); end accuracy mean(accuracies); end这里有个细节需要注意XGBoost的Matlab工具箱在数据格式转换上有点坑。如果输入的是table类型记得先转换成double矩阵% 数据预处理防报错 if istable(X_train) X_train table2array(X_train); end y_train double(y_train);萤火虫的位置更新是算法的精髓所在。相比标准公式我加了个参数范围的约束处理防止搜索跑偏% 带约束的位置更新 new_pos firefly.pos beta * rand(1,3) .* (neighbor.pos - firefly.pos) alpha * (rand(1,3)-0.5); % 参数边界约束 new_pos(1) min(max(round(new_pos(1)), 50), 200); % 迭代次数取整 new_pos(2) min(max(round(new_pos(2)), 3), 10); % 深度取整 new_pos(3) min(max(new_pos(3), 0.1), 0.4); % 学习率裁剪实际跑起来会发现最大深度这个参数最容易引发过拟合。有次实验忘加交叉验证测试集准确率直接从95%暴跌到70%。后来在迭代过程中加入早停机制才算稳住% 早停机制 if iter 10 mean(acc_history(end-9:end)) max_acc disp([Early stopping at iteration , num2str(iter)]); break; end最终得到的参数组合往往比默认参数提升3-5个点准确率。不过要注意Matlab的XGBoost工具箱对类别特征的处理方式如果遇到特征类型不匹配的报错试试强制转换% 处理分类特征 if ismember(Category, varTypes) X.(varNames{idx}) grp2idx(X.(varNames{idx})); end这个方案的缺点是挺吃算力的50只萤火虫跑20代差不多要半小时。建议在循环里加个进度条提升体验% 进度条 if ~exist(progress,var) progress waitbar(0,Fireflies are dancing...); end waitbar(iter/params.max_iter, progress);最后说下环境问题亲测在Win10Matlab2021a下运行稳定但Mac用户可能需要虚拟机或者双系统。另外数据量超过10万条时建议先做降采样不然内存分分钟爆炸。