CANN/catlass Block MMAD开发详解
Block MMAD 代码开发详解【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass1. Block MMAD 概述Block MMADBlock Matrix Multiply-Add是CATLASS模板库中负责块级矩阵乘法的核心组件位于计算架构的中间层。它上接Kernel层下接Tile层负责将全局内存GM中的数据高效地加载到本地内存L1/L0并调度Tile级别的矩阵乘法运算。Block MMAD采用高度模块化和模板化的设计支持多种调度策略、Tile形状和数据类型能够灵活适应不同的硬件架构和计算需求本文将以BlockMmadPingpong为例详细讲解。2. 模板组装机制Block MMAD实现基于以下基础模板结构template class DispatchPolicy, class L1TileShape, class L0TileShape, class AType, class BType, class CType, class BiasType void, class TileCopy Gemm::Tile::TileCopytypename DispatchPolicy::ArchTag, AType, BType, CType, BiasType, class TileMmad Gemm::Tile::TileMmadtypename DispatchPolicy::ArchTag, AType, BType, BiasType struct BlockMmad { // 类型定义和核心实现 };2.1 核心模板参数参数名描述DispatchPolicy调度策略控制计算任务的分配和执行流程L1TileShapeL1缓存级别Tile的形状定义M、N、K三个维度的大小L0TileShapeL0缓存级别Tile的形状定义M、N、K三个维度的大小AType/BType/CType输入矩阵A、B和输出矩阵C的数据类型和布局信息BiasType偏置数据类型可选TileCopyTile级别的数据拷贝组件负责不同内存层级间的数据传输TileMmadTile级别的矩阵乘法组件负责实际的计算操作2.2 类型导出Block MMAD通过类型导出系统建立统一的类型接口方便上层组件使用public: // Type Aliases using DispatchPolicy MmadAtlasA2PingpongENABLE_UNIT_FLAG_; using ArchTag typename DispatchPolicy::ArchTag; using L1TileShape L1TileShape_; using L0TileShape L0TileShape_; using ElementA typename AType_::Element; using LayoutA typename AType_::Layout; using ElementB typename BType_::Element; using LayoutB typename BType_::Layout; using ElementC typename CType_::Element; using LayoutC typename CType_::Layout; using TileMmad TileMmad_; using CopyGmToL1A typename TileCopy_::CopyGmToL1A; using CopyGmToL1B typename TileCopy_::CopyGmToL1B; using CopyL1ToL0A typename TileCopy_::CopyL1ToL0A; using CopyL1ToL0B typename TileCopy_::CopyL1ToL0B; using CopyL0CToGm typename TileCopy_::CopyL0CToGm; using ElementAccumulator typename Gemm::helper::ElementAccumulatorSelectorElementA, ElementB::ElementAccumulator; using LayoutAInL1 typename CopyL1ToL0A::LayoutSrc; using LayoutBInL1 typename CopyL1ToL0B::LayoutSrc; using LayoutAInL0 typename CopyL1ToL0A::LayoutDst; using LayoutBInL0 typename CopyL1ToL0B::LayoutDst; using LayoutCInL0 layout::zN;3. 内存管理与缓存设计Block MMAD负责管理不同层级的内存包括全局内存GM、L1缓存和L0缓存3.1 静态常量定义static constexpr bool ENABLE_UNIT_FLAG DispatchPolicy::ENABLE_UNIT_FLAG; static constexpr uint32_t STAGES DispatchPolicy::STAGES; static constexpr uint32_t L1A_SIZE L1TileShape::M * L1TileShape::K * sizeof(ElementA); static constexpr uint32_t L1B_SIZE L1TileShape::N * L1TileShape::K * sizeof(ElementB); static constexpr uint32_t L0A_SIZE ArchTag::L0A_SIZE; static constexpr uint32_t L0B_SIZE ArchTag::L0B_SIZE; static constexpr uint32_t L0C_SIZE ArchTag::L0C_SIZE; static constexpr uint32_t L0A_PINGPONG_BUF_SIZE L0A_SIZE / STAGES; static constexpr uint32_t L0B_PINGPONG_BUF_SIZE L0B_SIZE / STAGES;3.2 内存检查// Check LayoutC static_assert(std::is_same_vLayoutC, layout::RowMajor, LayoutC only support RowMajor yet!); // Check L1TileShape static_assert((L1A_SIZE * STAGES L1B_SIZE * STAGES) ArchTag::L1_SIZE, L1TileShape exceeding the L1 space!); // Check L0TileShape static constexpr uint32_t L0A_TILE_SIZE L0TileShape::M * L0TileShape::K * sizeof(ElementA); static constexpr uint32_t L0B_TILE_SIZE L0TileShape::K * L0TileShape::N * sizeof(ElementB); static constexpr uint32_t L0C_TILE_SIZE L0TileShape::M * L0TileShape::N * sizeof(ElementAccumulator); static_assert((L0A_TILE_SIZE * STAGES) L0A_SIZE, L0TileShape exceeding the L0A space!); static_assert((L0B_TILE_SIZE * STAGES) L0B_SIZE, L0TileShape exceeding the L0B space!); static_assert(L0C_TILE_SIZE L0C_SIZE, L0TileShape exceeding the L0C space!);3.3 多阶段缓存设计Block MMAD采用多阶段流水线设计使用pingpong技术隐藏内存访问延迟protected: // Multi-stage tensors list AscendC::LocalTensorElementA l1ATensorList[STAGES]; AscendC::LocalTensorElementB l1BTensorList[STAGES]; AscendC::LocalTensorElementA l0ATensorList[STAGES]; AscendC::LocalTensorElementB l0BTensorList[STAGES]; AscendC::LocalTensorElementAccumulator l0CTensor; // Multi-stage event id list int32_t l1AEventList[STAGES]; int32_t l1BEventList[STAGES]; int32_t l0AEventList[STAGES]; int32_t l0BEventList[STAGES]; // The id of current stage uint32_t l1ListId{0}; uint32_t l0AListId{0}; uint32_t l0BListId{0};4. 核心接口实现4.1 构造函数CATLASS_DEVICE BlockMmad(Arch::ResourceArchTag resource, uint32_t l1BufAddrStart 0) { uint32_t l1AOffset l1BufAddrStart; uint32_t l1BOffset l1BufAddrStart L1A_SIZE * STAGES; // Init buffers for (uint32_t i 0; i STAGES; i) { // Assign L1/L0A/L0B space for each stages l1ATensorList[i] resource.l1Buf.template GetBufferByByteElementA(l1AOffset L1A_SIZE * i); l1BTensorList[i] resource.l1Buf.template GetBufferByByteElementB(l1BOffset L1B_SIZE * i); l0ATensorList[i] resource.l0ABuf.template GetBufferByByteElementA(L0A_PINGPONG_BUF_SIZE * i); l0BTensorList[i] resource.l0BBuf.template GetBufferByByteElementB(L0B_PINGPONG_BUF_SIZE * i); // Assign event ID for each stages l1AEventList[i] i; l1BEventList[i] i STAGES; l0AEventList[i] i; l0BEventList[i] i STAGES; // The event id that needs to be set before the loop AscendC::SetFlagAscendC::HardEvent::MTE1_MTE2(l1AEventList[i]); AscendC::SetFlagAscendC::HardEvent::MTE1_MTE2(l1BEventList[i]); AscendC::SetFlagAscendC::HardEvent::M_MTE1(l0AEventList[i]); AscendC::SetFlagAscendC::HardEvent::M_MTE1(l0BEventList[i]); } l0CTensor resource.l0CBuf.template GetBufferByByteElementAccumulator(0); AscendC::SetFlagAscendC::HardEvent::FIX_M(EVENT_ID0); }4.2 析构函数CATLASS_DEVICE ~BlockMmad() { for (uint32_t i 0; i STAGES; i) { AscendC::WaitFlagAscendC::HardEvent::MTE1_MTE2(l1AEventList[i]); AscendC::WaitFlagAscendC::HardEvent::MTE1_MTE2(l1BEventList[i]); AscendC::WaitFlagAscendC::HardEvent::M_MTE1(l0AEventList[i]); AscendC::WaitFlagAscendC::HardEvent::M_MTE1(l0BEventList[i]); } AscendC::WaitFlagAscendC::HardEvent::FIX_M(EVENT_ID0); }4.3 核心计算接口 operator()operator()是Block MMAD的核心接口负责执行块级矩阵乘法CATLASS_DEVICE void operator()( AscendC::GlobalTensorElementA const gmA, LayoutA const layoutA, AscendC::GlobalTensorElementB const gmB, LayoutB const layoutB, AscendC::GlobalTensorElementC const gmC, LayoutC const layoutC, GemmCoord const actualShape) { // 1. 初始化参数和布局 // 2. 预加载第一批数据到L1缓存 // 3. 主循环 // a. 预加载下一批数据到L1缓存 // b. 将当前L1缓存中的数据加载到L0缓存 // c. 执行Tile级矩阵乘法 // d. 管理多阶段流水线 // 4. 将结果从L0缓存写回全局内存 }5. 执行流程分析以BlockMmadPingpong为例其执行流程如下5.1 数据预加载// load first matrix A tile from GM to L1 AscendC::WaitFlagAscendC::HardEvent::MTE1_MTE2(l1AEventList[l1ListId]); auto layoutTileA layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA); AscendC::SetFlagAscendC::HardEvent::MTE2_MTE1(l1AEventList[l1ListId]); // load first matrix B tile from GM to L1 AscendC::WaitFlagAscendC::HardEvent::MTE1_MTE2(l1BEventList[l1ListId]); auto layoutTileB layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); copyGmToL1B(l1BTensorList[l1ListId], gmB, layoutBInL1, layoutTileB); AscendC::SetFlagAscendC::HardEvent::MTE2_MTE1(l1BEventList[l1ListId]);5.2 主循环多阶段流水线// main loop uint32_t kTileCount CeilDivL1TileShape::K(actualShape.k()); for (uint32_t kLoopIdx 0; kLoopIdx kTileCount; kLoopIdx) { uint32_t l1ListIdNext (l1ListId 1 STAGES) ? (l1ListId 1) : 0; uint32_t kActualNext{0}; // 预加载下一批数据到L1缓存 if (kLoopIdx kTileCount - 1) { // ... 预加载逻辑 ... } // 处理当前L1缓存中的数据 // Get L1 tensor for current stage auto l1ATensor l1ATensorList[l1ListId]; auto l1BTensor l1BTensorList[l1ListId]; // 将L1缓存中的数据加载到L0缓存并执行计算 // ... L0级处理逻辑 ... // 切换到下一个阶段 l1ListId l1ListIdNext; kActual kActualNext; }5.3 L0级处理与计算for (int mPartIdx 0; mPartIdx mPartLoop; mPartIdx) { // ... M维度循环 ... for (int kPartIdx 0; kPartIdx kPartLoop; kPartIdx) { // ... K维度循环 ... for (int nPartIdx 0; nPartIdx nPartLoop; nPartIdx) { // ... N维度循环 ... // 执行Tile级矩阵乘法 bool initC ((kLoopIdx 0) (kPartIdx 0)); uint8_t unitFlag 0b00; // ... unitFlag设置 ... tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); } } }5.4 结果写回// copy block out LayoutC layoutBlock layoutC.GetTileLayout(actualShape.GetCoordMN()); if constexpr (!ENABLE_UNIT_FLAG) { AscendC::SetFlagAscendC::HardEvent::M_FIX(EVENT_ID0); AscendC::WaitFlagAscendC::HardEvent::M_FIX(EVENT_ID0); copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C); AscendC::SetFlagAscendC::HardEvent::FIX_M(EVENT_ID0); } else { copyL0CToGm(gmC, l0CTensor, layoutBlock, layoutInL0C, 0b11); }6. 多阶段流水线与事件同步Block MMAD使用事件驱动的同步机制确保多阶段流水线的正确执行// 等待事件 AscendC::WaitFlagAscendC::HardEvent::MTE1_MTE2(l1AEventList[l1ListId]); // 执行操作 copyGmToL1A(l1ATensorList[l1ListId], gmA, layoutAInL1, layoutTileA); // 触发事件 AscendC::SetFlagAscendC::HardEvent::MTE2_MTE1(l1AEventList[l1ListId]);7. 不同Block MMAD实现的共性与差异CATLASS提供了多种Block MMAD实现它们具有以下共性统一的模板接口所有实现都基于相同的模板参数结构多阶段流水线设计都采用多阶段设计来隐藏内存访问延迟事件驱动同步都使用事件机制确保数据传输和计算的正确顺序内存层级管理都负责管理GM、L1和L0之间的数据传输不同实现的主要差异在于调度策略如pingpong、preload、streamk等不同的调度方式硬件适配针对不同硬件架构的优化实现特殊功能支持如量化、稀疏计算、偏置处理等性能优化不同的流水线深度和内存访问模式8. 总结Block MMAD是CATLASS模板库中连接Kernel层和Tile层的关键组件它通过以下设计实现了高效的块级矩阵乘法高度模块化通过模板参数组装不同的功能组件多阶段流水线使用pingpong技术隐藏内存访问延迟事件驱动同步确保数据传输和计算的正确顺序自动内存管理高效管理不同层级的内存资源灵活适配支持不同的硬件架构和计算需求Block MMAD的设计体现了现代高性能计算库的核心思想通过精心的流水线设计和内存管理充分发挥硬件的计算能力实现高效的矩阵乘法运算。【免费下载链接】catlass本项目是CANN的算子模板库提供NPU上高性能矩阵乘及其相关融合类算子模板样例。项目地址: https://gitcode.com/cann/catlass创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考