Transpose【免费下载链接】asc-devkit本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言原生支持C和C标准规范主要由类库和语言扩展层构成提供多层级API满足多维场景算子开发诉求。项目地址: https://gitcode.com/cann/asc-devkit产品支持情况产品是否支持Ascend 950PR/Ascend 950DT√Atlas A3 训练系列产品 / Atlas A3 推理系列产品√Atlas A2 训练系列产品 / Atlas A2 推理系列产品√Atlas 200I/500 A2 推理产品√Atlas 推理系列产品 AI Core√Atlas 推理系列产品 Vector CorexAtlas 训练系列产品√Kirin X90√Kirin 9030√功能说明头文件路径为basic_api/kernel_operator_vec_transpose_intf.h。Transpose接口用于实现16*16的二维矩阵数据块转置或者[N,C,H,W]与[N,H,W,C]数据格式互相转换。16*16的普通转置接口计算原理和参考伪代码如下import numpy as np src np.random.randn(16, 16).astype(np.float16) dst src.T[N,C,H,W]与[N,H,W,C]数据格式互相转换的增强转置计算原理和参考伪代码如下import numpy as np # transposeParams.transposeType : TRANSPOSE_NCHW2NHWC src_nchw np.random.randn(transposeParams.nSize, transposeParams.cSize, transposeParams.hSize, transposeParams.wSize).astype(np.float16) dst_nhwc np.transpose(src_nchw, axes(0,2,3,1)) # transposeParams.transposeType : TRANSPOSE_NHWC2NCHW src_nhwc np.random.randn(transposeParams.nSize, transposeParams.hSize, transposeParams.wSize, transposeParams.cSize).astype(np.float16) dst_nchw np.transpose(src_nhwc, axes(0,3,1,2))函数原型普通转置支持16*16的二维矩阵数据块进行转置。template typename T __aicore__ inline void Transpose(const LocalTensorT dst, const LocalTensorT src)增强转置支持16*16的二维矩阵数据块转置支持[N,C,H,W]与[N,H,W,C]互相转换。template typename T __aicore__ inline void Transpose(const LocalTensorT dst, const LocalTensorT src, const LocalTensoruint8_t sharedTmpBuffer, const TransposeParamsExt transposeParams)参数说明表模板参数说明参数名描述T操作数的数据类型。表接口参数说明参数名称输入/输出含义dst输出目的操作数。类型为LocalTensor支持的TPosition为VECIN/VECCALC/VECOUT存储位置为Unified Buffer。LocalTensor的起始地址需要32字节对齐。src输入源操作数。类型为LocalTensor支持的TPosition为VECIN/VECCALC/VECOUT存储位置为Unified Buffer。LocalTensor的起始地址需要32字节对齐。数据类型需要与dst保持一致。sharedTmpBuffer输入共享的临时BuffersharedTmpBuffer的大小参考表 sharedTmpBuffer所需的内存。transposeParams输入控制Transpose的数据结构。结构体内包含输入的shape信息和transposeType参数。该数据结构的定义请参考表 TransposeParamsExt结构体内参数说明。struct TransposeParamsExt {__aicore__ TransposeParamsExt() {}__aicore__ TransposeParamsExt(const uint16_t nSizeIn, const uint16_t cSizeIn, const uint16_t hSizeIn,const uint16_t wSizeIn, const TransposeType transposeTypeIn): nSize(nSizeIn),cSize(cSizeIn),hSize(hSizeIn),wSize(wSizeIn),transposeType(transposeTypeIn){}uint16_t nSize 0;uint16_t cSize 0;uint16_t hSize 0;uint16_t wSize 0;TransposeType transposeType TransposeType::TRANSPOSE_ND2ND_B16;};表TransposeParamsExt结构体内参数说明参数名称含义nSizen轴长度。默认值为0。•二维矩阵数据块转置无需传入传入数值无效。•[N,C,H,W]与[N,H,W,C]数据格式互相转换取值范围nSize∈[0, 65535]。cSizec轴长度。默认值为0。•二维矩阵数据块转置无需传入传入数值无效。•[N,C,H,W]与[N,H,W,C]数据格式互相转换取值范围cSize∈[0, 4095]。hSizeh轴长度。默认值为0。•二维矩阵数据块转置固定传入16。•[N,C,H,W]与[N,H,W,C]数据格式互相转换取值范围hSize * wSize ∈[0, 4095]hSize * wSize * sizeof(T)需要保证32B对齐。wSizew轴长度。默认值为0。•二维矩阵数据块转置固定传入16。•[N,C,H,W]与[N,H,W,C]数据格式互相转换取值范围hSize * wSize ∈[0, 4095]hSize * wSize * sizeof(T)需要保证32B对齐。transposeType数据排布及reshape的类型类型为TransposeType枚举类。默认值为TRANSPOSE_ND2ND_B16。enum class TransposeType : uint8_t {TRANSPOSE_TYPE_NONE, // API不做任何处理TRANSPOSE_NZ2ND_0213, // 当前不支持TRANSPOSE_NZ2NZ_0213, // 当前不支持TRANSPOSE_NZ2NZ_012_WITH_N, // 当前不支持TRANSPOSE_NZ2ND_012_WITH_N, // 当前不支持TRANSPOSE_NZ2ND_012_WITHOUT_N, // 当前不支持TRANSPOSE_NZ2NZ_012_WITHOUT_N, // 当前不支持TRANSPOSE_ND2ND_ONLY, // 当前不支持TRANSPOSE_ND_UB_GM, // 当前不支持TRANSPOSE_GRAD_ND_UB_GM, // 当前不支持TRANSPOSE_ND2ND_B16, // [16,16]二维矩阵转置TRANSPOSE_NCHW2NHWC, // [N,C,H,W]-[N,H,W,C]TRANSPOSE_NHWC2NCHW // [N,H,W,C]-[N,C,H,W]};表Ascend 950PR/Ascend 950DT sharedTmpBuffer所需的内存transposeTypesharedTmpBuffer所需的大小TRANSPOSE_ND2ND_B16不需要临时Buffer。TRANSPOSE_NCHW2NHWC临时Buffer的大小按照下述计算规则伪代码进行计算。auto h0 16; // 当数据类型的位宽为8时h0 32其他情况下h0 16auto w0 32 / sizeof(type); // type代表数据类型auto tmpBufferSize (cSize 2) * h0 * w0 * sizeof(type);TRANSPOSE_NHWC2NCHW临时Buffer的大小按照下述计算规则伪代码进行计算。auto h0 16; // 当数据类型的位宽为8时h0 32其他情况下h0 16auto w0 32 / sizeof(type); // type代表数据类型auto tmpBufferSize (cSize * 2 1) * h0 * w0 * sizeof(type);表Atlas A3 训练系列产品/Atlas A3 推理系列产品sharedTmpBuffer所需的内存transposeTypesharedTmpBuffer所需的大小TRANSPOSE_ND2ND_B16不需要临时Buffer。TRANSPOSE_NCHW2NHWC临时Buffer的大小按照下述计算规则伪代码进行计算。auto h0 16; // 当数据类型的位宽为8时h0 32其他情况下h0 16auto w0 32 / sizeof(type); // type代表数据类型auto tmpBufferSize (cSize 2) * h0 * w0 * sizeof(type);TRANSPOSE_NHWC2NCHW临时Buffer的大小按照下述计算规则伪代码进行计算。auto h0 16; // 当数据类型的位宽为8时h0 32其他情况下h0 16auto w0 32 / sizeof(type); // type代表数据类型auto tmpBufferSize (cSize * 2 1) * h0 * w0 * sizeof(type);表Atlas A2 训练系列产品/Atlas A2 推理系列产品sharedTmpBuffer所需的内存transposeTypesharedTmpBuffer所需的大小TRANSPOSE_ND2ND_B16不需要临时Buffer。TRANSPOSE_NCHW2NHWC临时Buffer的大小按照下述计算规则伪代码进行计算。auto h0 16; // 当数据类型的位宽为8时h0 32其他情况下h0 16auto w0 32 / sizeof(type); // type代表数据类型auto tmpBufferSize (cSize 2) * h0 * w0 * sizeof(type);TRANSPOSE_NHWC2NCHW临时Buffer的大小按照下述计算规则伪代码进行计算。auto h0 16; // 当数据类型的位宽为8时h0 32其他情况下h0 16auto w0 32 / sizeof(type); // type代表数据类型auto tmpBufferSize (cSize * 2 1) * h0 * w0 * sizeof(type);表Atlas 200I/500 A2 推理产品sharedTmpBuffer所需的内存transposeTypesharedTmpBuffer所需的大小TRANSPOSE_ND2ND_B16不需要临时Buffer。表Atlas 推理系列产品AI Core sharedTmpBuffer所需的内存transposeTypesharedTmpBuffer所需的大小TRANSPOSE_ND2ND_B16不需要临时Buffer。TRANSPOSE_NCHW2NHWC不需要临时Buffer。TRANSPOSE_NHWC2NCHW不需要临时Buffer。表Kirin X90 sharedTmpBuffer所需的内存transposeTypesharedTmpBuffer所需的大小TRANSPOSE_ND2ND_B16不需要临时Buffer。TRANSPOSE_NCHW2NHWC临时Buffer的大小按照下述计算规则伪代码进行计算。auto h0 16; // 当数据类型的位宽为8时h0 32其他情况下h0 16auto w0 32 / sizeof(type); // type代表数据类型auto tmpBufferSize (cSize 2) * h0 * w0 * sizeof(type);TRANSPOSE_NHWC2NCHW临时Buffer的大小按照下述计算规则伪代码进行计算。auto h0 16; // 当数据类型的位宽为8时h0 32其他情况下h0 16auto w0 32 / sizeof(type); // type代表数据类型auto tmpBufferSize (cSize * 2 1) * h0 * w0 * sizeof(type);表Kirin 9030 sharedTmpBuffer所需的内存transposeTypesharedTmpBuffer所需的大小TRANSPOSE_ND2ND_B16不需要临时Buffer。TRANSPOSE_NCHW2NHWC临时Buffer的大小按照下述计算规则伪代码进行计算。auto h0 16; // 当数据类型的位宽为8时h0 32其他情况下h0 16auto w0 32 / sizeof(type); // type代表数据类型auto tmpBufferSize (cSize 2) * h0 * w0 * sizeof(type);TRANSPOSE_NHWC2NCHW临时Buffer的大小按照下述计算规则伪代码进行计算。auto h0 16; // 当数据类型的位宽为8时h0 32其他情况下h0 16auto w0 32 / sizeof(type); // type代表数据类型auto tmpBufferSize (cSize * 2 1) * h0 * w0 * sizeof(type);数据类型普通转置Ascend 950PR/Ascend 950DT操作数支持的数据类型为int16_t、uint16_t、half。Atlas A3 训练系列产品/Atlas A3 推理系列产品操作数支持的数据类型为int16_t、uint16_t、half。Atlas A2 训练系列产品/Atlas A2 推理系列产品操作数支持的数据类型为int16_t、uint16_t、half。Atlas 200I/500 A2 推理产品操作数支持的数据类型为int16_t、uint16_t、half。Atlas 推理系列产品AI Core操作数支持的数据类型为int16_t、uint16_t、half。Atlas 训练系列产品操作数支持的数据类型为int16_t、uint16_t、half。Kirin X90操作数支持的数据类型为int16_t、uint16_t、half。Kirin 9030操作数支持的数据类型为int16_t、uint16_t、half。增强转置transposeType为TRANSPOSE_ND2ND_B16Ascend 950PR/Ascend 950DT操作数支持的数据类型为int16_t、uint16_t、half。Atlas A3 训练系列产品/Atlas A3 推理系列产品操作数支持的数据类型为uint16_t。Atlas A2 训练系列产品/Atlas A2 推理系列产品操作数支持的数据类型为uint16_t。Atlas 200I/500 A2 推理产品操作数支持的数据类型为uint16_t。Atlas 推理系列产品AI Core操作数支持的数据类型为uint16_t。transposeType为TRANSPOSE_NCHW2NHWC或TRANSPOSE_NHWC2NCHWAscend 950PR/Ascend 950DT操作数支持的数据类型为int8_t、uint8_t、fp4x2_e2m1_t、fp4x2_e1m2_t、hifloat8_t、fp8_e8m0_t、fp8_e5m2_t、fp8_e4m3fn_t、int4x2_t、int16_t、uint16_t、half、bfloat16_t、int32_t、uint32_t、float、complex32。Atlas A3 训练系列产品/Atlas A3 推理系列产品操作数支持的数据类型为int8_t、uint8_t、int16_t、uint16_t、half、int32_t、uint32_t、float。Atlas A2 训练系列产品/Atlas A2 推理系列产品操作数支持的数据类型为int8_t、uint8_t、int16_t、uint16_t、half、int32_t、uint32_t、float。Atlas 推理系列产品AI Core操作数支持的数据类型为int8_t、uint8_t、int16_t、uint16_t、half、int32_t、uint32_t、float。Kirin X90操作数支持的数据类型为int8_t、uint8_t、int16_t、uint16_t、half、int32_t、uint32_t、float。Kirin 9030操作数支持的数据类型为int8_t、uint8_t、int16_t、uint16_t、half、int32_t、uint32_t、float。返回值说明无。约束说明操作数地址对齐要求请参见Unified Buffer地址对齐约束。普通转置接口支持src和dst复用。增强转置接口transposeType为TRANSPOSE_ND2ND_B16时支持src和dst复用transposeType为TRANSPOSE_NCHW2NHWC、TRANSPOSE_NHWC2NCHW时不支持src和dst复用。二维矩阵数据块转置时nSize、cSize无需传入传入数值无效hSize、wSize固定传入16。增强转置接口transposeType为TRANSPOSE_NCHW2NHWC、TRANSPOSE_NHWC2NCHW时如果nSize、cSize、hSize、wSize为0不会执行计算操作不会对目的操作数进行写入。[N,C,H,W]与[N,H,W,C]数据格式互相转换参数取值范围nSize∈[0, 65535]cSize∈[0, 4095]hSize * wSize ∈[0, 4095]hSize * wSize * sizeof(T)需要保证32B对齐。转置增强接口中入参sharedTmpBuffer的大小不得小于计算所需的最小阈值。调用示例普通接口调用示例片段完整片段请参考Transpose类样例场景一该示例对[16,16]的half类型矩阵进行转置。// dstLocal目的操作数tensor // srcLocal源操作数tensor AscendC::Transposehalf(dstLocal, srcLocal);输入数据src_gm: [[ 0. 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 11. 12. 13. 14. 15.] [ 16. 17. 18. 19. 20. 21. 22. 23. 24. 25. 26. 27. 28. 29. 30. 31.] [ 32. 33. 34. 35. 36. 37. 38. 39. 40. 41. 42. 43. 44. 45. 46. 47.] [ 48. 49. 50. 51. 52. 53. 54. 55. 56. 57. 58. 59. 60. 61. 62. 63.] [ 64. 65. 66. 67. 68. 69. 70. 71. 72. 73. 74. 75. 76. 77. 78. 79.] [ 80. 81. 82. 83. 84. 85. 86. 87. 88. 89. 90. 91. 92. 93. 94. 95.] [ 96. 97. 98. 99. 100. 101. 102. 103. 104. 105. 106. 107. 108. 109. 110. 111.] [112. 113. 114. 115. 116. 117. 118. 119. 120. 121. 122. 123. 124. 125. 126. 127.] [128. 129. 130. 131. 132. 133. 134. 135. 136. 137. 138. 139. 140. 141. 142. 143.] [144. 145. 146. 147. 148. 149. 150. 151. 152. 153. 154. 155. 156. 157. 158. 159.] [160. 161. 162. 163. 164. 165. 166. 167. 168. 169. 170. 171. 172. 173. 174. 175.] [176. 177. 178. 179. 180. 181. 182. 183. 184. 185. 186. 187. 188. 189. 190. 191.] [192. 193. 194. 195. 196. 197. 198. 199. 200. 201. 202. 203. 204. 205. 206. 207.] [208. 209. 210. 211. 212. 213. 214. 215. 216. 217. 218. 219. 220. 221. 222. 223.] [224. 225. 226. 227. 228. 229. 230. 231. 232. 233. 234. 235. 236. 237. 238. 239.] [240. 241. 242. 243. 244. 245. 246. 247. 248. 249. 250. 251. 252. 253. 254. 255.]] 输出数据dst_gm: [[ 0. 16. 32. 48. 64. 80. 96. 112. 128. 144. 160. 176. 192. 208. 224. 240.] [ 1. 17. 33. 49. 65. 81. 97. 113. 129. 145. 161. 177. 193. 209. 225. 241.] [ 2. 18. 34. 50. 66. 82. 98. 114. 130. 146. 162. 178. 194. 210. 226. 242.] [ 3. 19. 35. 51. 67. 83. 99. 115. 131. 147. 163. 179. 195. 211. 227. 243.] [ 4. 20. 36. 52. 68. 84. 100. 116. 132. 148. 164. 180. 196. 212. 228. 244.] [ 5. 21. 37. 53. 69. 85. 101. 117. 133. 149. 165. 181. 197. 213. 229. 245.] [ 6. 22. 38. 54. 70. 86. 102. 118. 134. 150. 166. 182. 198. 214. 230. 246.] [ 7. 23. 39. 55. 71. 87. 103. 119. 135. 151. 167. 183. 199. 215. 231. 247.] [ 8. 24. 40. 56. 72. 88. 104. 120. 136. 152. 168. 184. 200. 216. 232. 248.] [ 9. 25. 41. 57. 73. 89. 105. 121. 137. 153. 169. 185. 201. 217. 233. 249.] [ 10. 26. 42. 58. 74. 90. 106. 122. 138. 154. 170. 186. 202. 218. 234. 250.] [ 11. 27. 43. 59. 75. 91. 107. 123. 139. 155. 171. 187. 203. 219. 235. 251.] [ 12. 28. 44. 60. 76. 92. 108. 124. 140. 156. 172. 188. 204. 220. 236. 252.] [ 13. 29. 45. 61. 77. 93. 109. 125. 141. 157. 173. 189. 205. 221. 237. 253.] [ 14. 30. 46. 62. 78. 94. 110. 126. 142. 158. 174. 190. 206. 222. 238. 254.] [ 15. 31. 47. 63. 79. 95. 111. 127. 143. 159. 175. 191. 207. 223. 239. 255.]]增强接口调用示例片段完整代码请参考Transpose类样例场景二完成half类型的[N,C,H,W]-[N,H,W,C]转置。AscendC::TransposeParamsExt transposeParams; transposeParams.nSize N; // N轴长度 transposeParams.cSize C; // C轴长度 transposeParams.hSize H; // H轴长度 transposeParams.wSize W; // W轴长度 transposeParams.transposeType transposeType; AscendC::Transpose(dstLocal, srcLocal, stackBuffer, transposeParams);【免费下载链接】asc-devkit本项目是CANN 推出的昇腾AI处理器专用的算子程序开发语言原生支持C和C标准规范主要由类库和语言扩展层构成提供多层级API满足多维场景算子开发诉求。项目地址: https://gitcode.com/cann/asc-devkit创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考