cann/catlass MX FP8批量矩阵乘
MXFP8BatchMatmulTla Example Readme【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass功能介绍演示 Ascend 950 上的MX FP8 矩阵乘A、B 为 MX FP8经float8_e8m0缩放后做矩阵乘输出 BF16。本示例中 A、B 元素类型为float8_e4m3_t缩放因子为float8_e8m0_t。暂不支持 Bias本 batched kernel 无 Bias 通路。默认布局为 ARowMajor、BColumnMajor、CRowMajor与gen_data.py在trans_a0, trans_b1时生成的数据一致。代码组织├── 58_ascend950_fp8_mx_batch_matmul │ ├── CMakeLists.txt # CMake 编译配置 │ ├── README.md │ ├── gen_data.py # 生成 input/ 与 golden/ │ └── fp8_mx_matmul.cpp # 主程序使用示例获取代码之后编译相应的算子可执行文件可参考quickstart本用例为 Ascend9503510算子编译时需加-DCATLASS_ARCH3510执行算子# 编译指定用例 bash scripts/build.sh 58_ascend950_fp8_mx_batch_matmul -DCATLASS_ARCH3510 # 生成测试样例在 examples/58_ascend950_fp8_mx_batch_matmul/data 下生成 input/ 与 golden/ python3 examples/58_ascend950_fp8_mx_batch_matmul/gen_data.py 5 256 512 1024 0 1 # 输入参数分别对应 b, m, n, k, trans_a, trans_b # trans_a表示A矩阵是否转置0是不转置1是转置 # trans_b表示B矩阵是否转置0是不转置1是转置 # 执行测试样例 ./output/bin/58_ascend950_fp8_mx_batch_matmul 5 256 512 1024 0 # 可执行文件名 |batch_size|矩阵m轴|n轴|k轴|Device ID # Device ID可选默认为0执行结果如下说明精度比对成功。Compare success.使用说明1、gen_data.py的输入支持trans_a和trans_b但58_ascend950_fp8_mx_batch_matmul可执行文件不支持仅仅是trans_a为0及trans_b为1的example示例。若要对应转置情况请修改example示例中的layout因为layout隐式表征转置状态即layout::RowMajor表示不转置layout::ColumnMajor表示转置。其对应关系如下表trans_atrans_bLayoutALayoutB00layout::RowMajorlayout::RowMajor01layout::RowMajorlayout::ColumnMajor10layout::ColumnMajorlayout::RowMajor11layout::ColumnMajorlayout::ColumnMajor2、 本example完成mx batch量化矩阵乘 C (MxScaleA x A) * (MxScaleB x B) Bias A、B支持数据类型为float8_e4m3或float8_e5m2 MxScaleA、MxScaleB支持数据类型为float8_e8m0其中对于MxScaleA、MxScaleB的数据排布要求如下 当A为RowMajor时MxScaleA的shape为m, ceil(k/64), 2 当A为ColumnMajor时MxScaleA的shape为ceil(k/64), m, 2 当B为RowMajor时MxScaleB的shape为ceil(k/64), n, 2 当B为ColumnMajor时MxScaleB的shape为n, ceil(k/64), 23、 MxMatmul默认使用的DispatchPolicy MxMmad支持以下几个模板参数模板参数默认值参数说明ArchTag无指定架构型号enableUnitFlagfalse是否开启Unitflag开启L0C多缓冲时必须设置为falsel0CStages1指定L0C的缓冲区数量设置为2即可开启L0C双缓冲enableL1Residentfalse是否开启L1常驻l1AStages2L1上加载矩阵A的Buffer数量l1BStages2L1上加载矩阵B的Buffer数量l0AStages2L0上加载矩阵A的Buffer数量l0BStages2L0上加载矩阵B的Buffer数量设矩阵Shape为M N K, L1上的分块大小为m1 n1 k1M方向的分块数量mTiles CeilDiv(M, m1)N方向的分块数量nTiles CeilDiv(N, n1)总任务数为taskBlocks mTiles * nTiles在以下两种情况下可以选择开启enableL1Resident1.mTiles 1且nTiles CoreNum且K 2 * k1。此时还可以设置l0CStages2(需要关闭enableUnitFlag)如果空间不足无法设置l0CStages2则将n1设置为原来的一半。2.nTiles 1且mTiles CoreNum, 且K 2 * k1。此时还可以设置l0CStages2(需要关闭enableUnitFlag)如果空间不足无法设置l0CStages2则将m1设置为原来的一半。【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考