CANN EasyAsc DSL a2 Cube-Vec-Cube-Vec Pattern
# a2 Cube-to-Vec-to-Cube-to-Vec Pattern (Triple Bridge, Normalized Online Softmax)

[Free download link] cannbot-skills: CANNBot is a series of agents for CANN development that improve development efficiency; this repository provides reusable Skills modules for it. Project address: https://gitcode.com/cann/cannbot-skills

Read this file when writing an a2 (`easyasc.a2`, `deviceb3`) kernel with:

- one cube stage that produces a score tile
- vec logic that updates a running row max and a running row sum
- a later cube stage that consumes the delayed probability tile
- a final vec stage that accumulates the delayed cube output
- one final vec-only divide by the accumulated row sum

Typical target formula:

    score_j   = q.float() @ k_j.float().t() * scale
    curr_m    = maximum(prev_m, rowmax(score_j))
    expdiff_j = exp(prev_m - curr_m)
    p_j       = exp(score_j - curr_m)
    row_sum   = row_sum * expdiff_j + p_j.sum(-1)
    pv_j      = p_j.half().float() @ v_j.float()
    out       = out * expdiff_j + pv_j
    out       = out / row_sum

This is the normalized counterpart to `a2-cube-vec-cube-vec.md`. Use that older pattern only when the kernel stops at the unnormalized numerator.

## One-page route for the common case

If this file matches your contract, do **not** preload all of:

- `agent/references/constraints/reduction.md`
- `agent/references/constraints/vec-reduction-a2.md`
- `agent/references/constraints/vec-stride.md`
- `agent/references/constraints/online-softmax-tail.md`

This page now owns the common normalized-online-softmax authoring rules.
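The typical target formula near the top of this page can be checked on the host before any kernel work. Below is a minimal pure-Python sketch, not EasyAsc code. The names `online_attention`, `direct_attention`, `scores_for` and `to_half` are hypothetical helpers. `to_half` uses the standard `struct` `"e"` (IEEE float16) format to stand in for `.half().float()`. The tile loop also runs without the one-tile delay that the real kernel uses.

```python
import math
import struct


def to_half(x):
    """Round a float to the nearest float16 value (stands in for .half().float())."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def scores_for(q, k_rows, scale):
    """score[i][c] = dot(q[i], k_rows[c]) * scale, i.e. q @ k.t() * scale."""
    return [[sum(a * b for a, b in zip(qi, kc)) * scale for kc in k_rows] for qi in q]


def online_attention(q, k, v, scale, tile_n, cast_p=True):
    """Normalized online softmax over k/v streamed in tiles of tile_n rows."""
    m_rows, d = len(q), len(v[0])
    prev_m = [-math.inf] * m_rows
    row_sum = [0.0] * m_rows
    out = [[0.0] * d for _ in range(m_rows)]
    for start in range(0, len(k), tile_n):
        k_j, v_j = k[start:start + tile_n], v[start:start + tile_n]
        score_j = scores_for(q, k_j, scale)
        for i in range(m_rows):
            curr_m = max(prev_m[i], max(score_j[i]))
            expdiff = math.exp(prev_m[i] - curr_m)           # 0.0 on the first tile
            p = [math.exp(s - curr_m) for s in score_j[i]]
            row_sum[i] = row_sum[i] * expdiff + sum(p)       # float p_j, before any cast
            p_pv = [to_half(x) for x in p] if cast_p else p  # p_j.half().float()
            pv = [sum(p_pv[c] * v_j[c][col] for c in range(len(v_j))) for col in range(d)]
            out[i] = [o * expdiff + x for o, x in zip(out[i], pv)]
            prev_m[i] = curr_m
    return [[o / row_sum[i] for o in out[i]] for i in range(m_rows)]  # single final divide


def direct_attention(q, k, v, scale):
    """Non-streamed reference: softmax(q @ k.t() * scale) @ v, all in float."""
    result = []
    for row in scores_for(q, k, scale):
        mx = max(row)
        e = [math.exp(s - mx) for s in row]
        total = sum(e)
        result.append([sum(e[c] * v[c][col] for c in range(len(v))) / total
                       for col in range(len(v[0]))])
    return result
```

With `cast_p=False`, the streamed result should match the direct softmax up to floating-point rounding. With `cast_p=True`, any remaining difference comes only from rounding `p_j` to `half`, which is exactly the contract stage 2 reproduces.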
Open the smaller constraint pages listed above only when a specific failure mode is still unclear after reading this file.

## Why this needs its own a2 pattern

The a2 hardware constraints are the same as in the unnormalized case:

- cube -> vec cannot use `l0c_to_ub`
- vec -> cube cannot use `ub_to_l1_*`
- the delayed cube output must come back to vec for final accumulation

Normalized online softmax adds two stability-sensitive requirements:

- the running `row_sum` must be updated from the float `exp(...)` tile before any cast to `half`
- the final divide must happen only once, after all delayed numerator tiles have been accumulated

So the stable a2 flow is:

    GM(q,k,v) -> L1 -> L0 -> L0C(score) -> GM(score_ws) -> UB(score)
      -> vec(max, expdiff, exp, row_sum, cast p) -> GM(p_ws) -> L1 -> L0 -> L0C(pv)
      -> GM(pv_ws) -> UB(pv) -> UB(accum) -> final UB divide by row_sum -> GM(out)

## Workspaces and ownership edges

Use the same three GM workspaces as the unnormalized pattern:

| Workspace | dtype | shape | purpose |
|---|---|---|---|
| `score_ws` | `float` | `[GetCubeNum(), 2, TILE_M, TILE_N]` | `L0C(score)` -> `UB(score)` |
| `p_ws` | `half` | `[GetCubeNum(), 2, TILE_M, TILE_N]` | `UB(p_j.half())` -> `L1(p_j)` |
| `pv_ws` | `float` | `[GetCubeNum(), 2, TILE_M, D]` | `L0C(pv_j)` -> `UB(pv_j)` |

Ownership edges:

- stage 1 cube -> vec: `CvMutex(0, src_end_pipe=Pipe.FIX, dst_end_pipe=Pipe.MTE2)`
- stage 1 vec -> stage 2 cube: `VcMutex(1, src_end_pipe=Pipe.MTE3, dst_end_pipe=Pipe.FIX)`
- stage 2 cube -> stage 3 vec: `CvMutex(2, src_end_pipe=Pipe.FIX, dst_end_pipe=Pipe.MTE2)`

## Stable schedule

Use the same one-tile lookahead loop as the unnormalized pattern:

    for ni in range(0, tiles_n + 1):
        if ni < tiles_n:
            # stage 1: produce tile j = ni
            ...
        if ni > 0:
            # stage 2 + stage 3: consume tile j = ni - 1
            ...

That gives:

- warmup: the first iteration only produces
- steady state: produce `j` while consuming `j - 1`
- drain: the final iteration only consumes the last delayed tile

## Shared `L0C` rule

Reuse one physical `L0C` family across the two cube stages.

This is the same capacity-driven choice as in the unnormalized pattern:

- stage 1 needs `float[TILE_M, TILE_N]`
- stage 2 needs `float[TILE_M, D]` with the validated `D = 128`
- a2 still has only 128 KB of `L0C`

Keep one shared `l0c_cnt`, but do not merge unrelated counters just because `L0C` is shared.

## Counter layout

Keep these lifetimes separate:

- `l1qk_cnt`: stage-1 `q`/`k` loads
- `l1pv_cnt`: stage-2 `p`/`v` loads
- `l0c_cnt`: the shared physical `L0C` family across the two cube stages
- `stage1_cnt`: delayed slot rhythm for `score_ws`, `p_ws`, and `expdiff`
- `stage2_cnt`: delayed slot rhythm for `p_ws` consumption and `pv_ws`

The running `row_sum` does not need its own delayed counter. It stays vec-resident for the whole inner loop and is updated immediately in stage 1.

## Vec-resident persistent state

Keep these values in per-subblock UB across the whole inner loop:

- running row max: `[HALF_M, 1]`
- running row sum: `[HALF_M, 1]`
- delayed `expdiff` slots: `DBuff(DT.float, [HALF_M, 1], Position.UB)`
- final numerator accumulation: `[HALF_M, D]`

Use `GetSubBlockIdx()` so that each vec lane owns only its own `HALF_M` rows.

## Stable stage-1 update order

The normalized online update order matters:

1. compute `rowmax(score_j)` in `[HALF_M, 1]`
2. snapshot `prev_m` into the delayed `expdiff` slot with `add(..., zero)`
3. update `running_max = maximum(running_max, tile_max)`
4. turn the delayed slot into `exp(prev_m - curr_m)`
5. broadcast `running_max` and subtract it from the score tile
6. compute the float probability tile `p_j = exp(score_j - curr_m)`
7. reduce `sum_j` from that float tile with `add` + `cadd`
8. update `running_sum = running_sum * expdiff_j + sum_j` in `[HALF_M, 1]`
9. cast `p_j` to `half` only now, because stage 2 wants the exact `p_j.half().float()` contract

Do not move the `row_sum` update after the cast.
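The one-tile lookahead schedule and the stage-1 update order can be checked together in a small pure-Python model. It is a sketch under stated assumptions. Score tiles are precomputed with `scale` already applied. The two-entry lists `p_ws` and `expdiff_slots` stand in for the double-buffered workspace and the `DBuff` slots, indexed by `j % 2`. `delayed_pipeline` is a hypothetical name, not part of EasyAsc.

```python
import math
import struct


def to_half(x):
    """Round a float to the nearest float16 value (stands in for the p_j cast)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def delayed_pipeline(scores, v_tiles):
    """One-tile lookahead model: stage 1 produces tile ni, stages 2+3 consume ni - 1.

    scores[j]: [M][TILE_N] score tile j; v_tiles[j]: [TILE_N][D] value tile j.
    Returns (normalized output [M][D], trace of (ni, produced, consumed)).
    """
    tiles_n = len(scores)
    m_rows, d = len(scores[0]), len(v_tiles[0][0])
    running_max = [-math.inf] * m_rows
    running_sum = [0.0] * m_rows
    accum = [[0.0] * d for _ in range(m_rows)]
    p_ws = [None, None]            # two delayed slots for the half probability tile
    expdiff_slots = [None, None]   # two delayed slots for expdiff_j
    trace = []
    for ni in range(tiles_n + 1):
        produced = consumed = None
        if ni < tiles_n:                                   # stage 1: produce j = ni
            slot = ni % 2                                  # stage1_cnt rhythm
            expdiff, p_half = [], []
            for i, row in enumerate(scores[ni]):
                prev_m = running_max[i]                    # snapshot before the update
                running_max[i] = max(prev_m, max(row))
                e = math.exp(prev_m - running_max[i])
                p = [math.exp(s - running_max[i]) for s in row]
                running_sum[i] = running_sum[i] * e + sum(p)   # uses the float p_j
                expdiff.append(e)
                p_half.append([to_half(x) for x in p])          # cast only now
            p_ws[slot], expdiff_slots[slot] = p_half, expdiff
            produced = ni
        if ni > 0:                                         # stages 2 + 3: consume j = ni - 1
            j = ni - 1
            slot = j % 2                                   # stage2_cnt rhythm
            p_half, expdiff, v_j = p_ws[slot], expdiff_slots[slot], v_tiles[j]
            for i in range(m_rows):
                pv = [sum(p_half[i][c] * v_j[c][col] for c in range(len(v_j)))
                      for col in range(d)]
                accum[i] = [a * expdiff[i] + x for a, x in zip(accum[i], pv)]
            consumed = j
        trace.append((ni, produced, consumed))
    out = [[a / running_sum[i] for a in accum[i]] for i in range(m_rows)]  # final divide
    return out, trace
```

For three tiles the returned trace is `[(0, 0, None), (1, 1, 0), (2, 2, 1), (3, None, 2)]`. That is one warmup iteration, two steady-state iterations and one drain iteration. Because each `expdiff_j` is saved in its own slot, stage 3 applies the right rescale even though stage 1 has already advanced `running_max` to the next tile.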
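A quick byte count shows why one shared `L0C` family is the capacity-driven choice. The numbers assume `TILE_M = 128` and two rotating `L0C` slots per family. This page does not fix either value, so treat the sketch as an illustration rather than the validated tiling.

```python
FLOAT_BYTES = 4
L0C_BYTES = 128 * 1024      # a2 L0C capacity quoted in this pattern

# Assumptions for illustration only: TILE_M = 128 and two rotating L0C slots.
TILE_M, TILE_N, D = 128, 128, 128
SLOTS = 2


def tile_bytes(rows, cols, elem_bytes=FLOAT_BYTES):
    """Bytes needed for one float32 L0C tile of shape [rows, cols]."""
    return rows * cols * elem_bytes


stage1_bytes = tile_bytes(TILE_M, TILE_N)   # float[TILE_M, TILE_N] score tile
stage2_bytes = tile_bytes(TILE_M, D)        # float[TILE_M, D] pv tile
separate_families = SLOTS * (stage1_bytes + stage2_bytes)
shared_family = SLOTS * max(stage1_bytes, stage2_bytes)

print(f"tiles: {stage1_bytes // 1024} KB + {stage2_bytes // 1024} KB; "
      f"separate: {separate_families // 1024} KB; shared: {shared_family // 1024} KB; "
      f"L0C: {L0C_BYTES // 1024} KB")
```

Under these assumptions, each tile needs 64 KB. Two separately double-buffered families would need 256 KB, while one shared two-slot family fits exactly in the 128 KB `L0C`.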
**Why the order matters:** summing the `half`-cast tile instead of the float `p_j` would silently change the reference contract.

## Vec rules you usually need without extra docs

For the common `TILE_N = 128`, `D = 128` path, this section already answers the usual extra questions:

- keep `running_max`, `running_sum`, and the delayed `expdiff` in scalar format `[HALF_M, 1]`
- snapshot scalar state with `add(dst, src, zero)`, not `ub_to_ub`
- `cmax`/`cadd` output dense scalars, so broadcast them with `brcb(dst, src, dst_blk_stride=1, dst_rep_stride=8)`
- when a wide `[HALF_M, 128]` buffer is paired with a narrow `[HALF_M, 8]` broadcast row, operate on `buf[:, 0:64]` and `buf[:, 64:128]` separately rather than on the full 128-column view in one vec call
- update `running_sum` from the float `p_j` tile before any cast to `half` or `hif8`
- for non-aligned `S2`, invalidate score columns before `cmax` with a sufficiently negative finite sentinel; `valid_n` on the GM load alone is not enough

These six rules cover the usual reasons people would otherwise open the separate reduction, vec-reduction, vec-stride, and tail files.

## Critical scalar-state rule on a2

Do **not** copy `[HALF_M, 1]` scalar-format state with `ub_to_ub`.

That applies to both:

- `prev_m`
- any temporary scalar snapshot you might be tempted to use for `row_sum`

Use `add(dst, src, zero)` for scalar-format copies, and keep both `running_max` and `running_sum` in `[HALF_M, 1]` format until you explicitly need a broadcast.

## Final vec accumulation and divide

Stage 3 still matches the unnormalized pattern:

1. load the delayed `pv_j` back into UB
2. `brcb` the delayed `expdiff` slot to `[HALF_M, 8]`
3. scale the two 64-column halves of `accum`
4. `add(accum, accum, pv_j)`

After the inner loop finishes:

1. `brcb` the final `running_sum` to `[HALF_M, 8]`
2. `div(accum[:, 0:64], accum[:, 0:64], row_sum_broadcast)`
3. `div(accum[:, 64:128], accum[:, 64:128], row_sum_broadcast)`
4. write the normalized result to GM

Why the divide happens at the end:

- `accum` must finish all delayed `pv_j` contributions first
- `row_sum` is the denominator for the whole streamed softmax, not for one tile

## Extending the pattern to non-aligned `S2`

The initial validated contract for this pattern kept `S2 % 128 == 0`, so the first implementation could ignore score-tail masking.

When `S2` is not aligned, do **not** stop at GM-boundary `valid_n` slicing. For normalized online softmax, padded score columns can still corrupt:

- `rowmax(score_j)`
- `curr_m`
- the delayed `expdiff`
- `row_sum`

Stable rule:

- load `k`/`v` through `valid_n`
- keep local score buffers full-sized
- before `cmax`, force invalid score columns to behave like `-inf`
- when materializing that mask, use a sufficiently large finite negative fill value instead of a literal `-inf`
- after `exp`, those same columns naturally behave like `0`

For the current `TILE_N = 128` layout, the simplest a2 implementation is:

- split the score tile into two `[HALF_M, 64]` halves
- use a vec mask + finite-negative `dup(...)` on the affected half
- recompute `prev_valid_n` for the delayed `v` load in stage 2

Read next for the exact rule and the mask-construction trick: `agent/references/constraints/online-softmax-tail.md`

## Validation target

Keep the first validated contract narrow:

- `D = 128`
- `S1 % 128 == 0`
- `S2 % 128 == 0`
- input `q`/`k`/`v` are `float16`
- output is `float32`

Suggested cases:

- `(1, 3, 256, 256, 128)` for the smallest two-tile online update
- `(1, 1, 256, 512, 128)`
- `(1, 3, 256, 512, 128)`
- `(1, 3, 2048, 4096, 128)`

For non-aligned `S2` extensions, add at least:

- one aligned baseline: `S2 % 128 == 0`
- one left-half tail: `S2 % 128 == 10`
- one cross-boundary case: `S2 % 128 == 65`
- one mid-right-half case: `S2 % 128 == 96`
- one last-column case: `S2 % 128 == 127`

## Files to study / deeper fallbacks

- `agent/example/kernels/a2/flash_attn_full.py`
- `agent/example/kernels/a2/flash_attn_unnorm.py`
- `agent/example/kernels/a2/flash_attn_score_pv.py`
- `agent/references/patterns/a2-cube-vec-cube-vec.md`
- `agent/references/constraints/reduction.md`: fallback only when the online update order is still unclear
- `agent/references/constraints/vec-reduction-a2.md`: fallback only when the `cmax`/`cadd` -> `brcb` detail is still unclear
- `agent/references/constraints/vec-stride.md`: fallback only when a sliced wide/narrow vec op is still unclear
- `agent/references/constraints/online-softmax-tail.md`: fallback only when the non-aligned `S2` mask construction itself is the question

Creation statement: parts of this article were generated with AI assistance (AIGC) and are for reference only.
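The non-aligned `S2` masking rule from earlier in this pattern can be illustrated with a host-side model of one score row. This is only a sketch. `SENTINEL = -1.0e30` is a hypothetical finite fill value; the real choice belongs in `online-softmax-tail.md`. `tile_valid_n`, `masked_halves`, `mask_tail` and `row_stats` are illustrative helpers, not EasyAsc APIs.

```python
import math

TILE_N = 128
HALF = TILE_N // 2
SENTINEL = -1.0e30   # hypothetical finite stand-in for -inf


def tile_valid_n(s2, j, tile_n=TILE_N):
    """Number of valid score columns in tile j for sequence length s2."""
    return min(tile_n, s2 - j * tile_n)


def masked_halves(valid_n):
    """64-column halves that contain at least one padded column."""
    halves = []
    if valid_n < HALF:
        halves.append("left")
    if valid_n < TILE_N:
        halves.append("right")
    return halves


def mask_tail(row, valid_n, fill=SENTINEL):
    """Overwrite padded columns [valid_n, TILE_N) with a finite negative fill."""
    return [s if c < valid_n else fill for c, s in enumerate(row)]


def row_stats(row, prev_m=-math.inf, prev_sum=0.0):
    """One online update for a row: returns (curr_m, expdiff, row_sum, p_row)."""
    curr_m = max(prev_m, max(row))
    expdiff = math.exp(prev_m - curr_m)
    p = [math.exp(s - curr_m) for s in row]
    return curr_m, expdiff, prev_sum * expdiff + sum(p), p
```

If stale values sit in the padded columns, the unmasked row picks them up as its max. After `mask_tail`, the max and row sum match those of the valid prefix, and every padded `p` value is exactly `0.0`. In the kernel, an expression like `tile_valid_n(S2, ni - 1)` is one way to recompute `prev_valid_n` for the delayed `v` load in stage 2.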
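The final vec stage described in this pattern can be sketched in pure Python to show the data layout. Each `[HALF_M, 1]` scalar is widened to an 8-float block (32 bytes of `float32`), and every elementwise op runs as two passes, over columns `0:64` and `64:128`. The helpers `brcb_model`, `apply_halfwise` and `final_stage` are hypothetical. The `(c - lo) % 8` index only mimics reusing one broadcast block across a 64-column pass; it does not reproduce EasyAsc's `brcb`/`div` signatures or stride semantics.

```python
BLOCK_FLOATS = 8   # one 32-byte block of float32 per row after the broadcast
TILE_COLS = 128    # D = 128 numerator columns
HALF = 64          # columns handled per vec pass


def brcb_model(scalars):
    """Widen [HALF_M, 1] scalar state to [HALF_M, 8] by repeating each row value."""
    return [[s] * BLOCK_FLOATS for s in scalars]


def apply_halfwise(accum, bcast, op):
    """In place: accum[i][c] = op(accum[i][c], row value), one pass per 64-column half."""
    for lo, hi in ((0, HALF), (HALF, TILE_COLS)):
        for i, row in enumerate(accum):
            for c in range(lo, hi):
                row[c] = op(row[c], bcast[i][(c - lo) % BLOCK_FLOATS])
    return accum


def final_stage(accum, pv_tiles, expdiffs, running_sum):
    """Stage 3 for each delayed tile, then one divide by the final running_sum."""
    for pv, expdiff in zip(pv_tiles, expdiffs):
        apply_halfwise(accum, brcb_model(expdiff), lambda a, e: a * e)  # scale both halves
        for i, row in enumerate(accum):
            for c in range(TILE_COLS):
                row[c] += pv[i][c]                                      # add(accum, accum, pv_j)
    return apply_halfwise(accum, brcb_model(running_sum), lambda a, s: a / s)
```

The divide appears exactly once, after every delayed `pv_j` has been folded into `accum`. That matches the rule that `row_sum` is the denominator for the whole streamed softmax.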