1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
@triton.jit
def _layer_norm_bwd_dx_fused(DY, DX, DW, DB, X, W, Mean, Rstd, Lock, stride, N,
                             GROUP_SIZE_ROW: tl.constexpr, BLOCK_SIZE_COL: tl.constexpr):
    """Layer-norm backward, stage 1: write dx for one row and accumulate dW/dB partials.

    Each program handles the row given by ``tl.program_id(0)``.

    Args:
        DY: base pointer of the incoming gradient; row ``r`` starts at ``DY + r * stride``.
        DX: base pointer the input gradient is written to (same row stride as ``DY``).
        DW: partial weight-gradient buffer, indexed as ``[row % GROUP_SIZE_ROW, col]``
            with a row pitch of ``N``.
        DB: partial bias-gradient buffer, same indexing as ``DW``.
        X: base pointer of the forward input (same row stride as ``DY``).
        W: pointer to the per-column weights (``N`` entries are read).
        Mean: pointer to one mean per row.
        Rstd: pointer to one reciprocal standard deviation per row.
        Lock: integer buffer; entries ``[0, GROUP_SIZE_ROW)`` are spin-locks and
            entries ``[GROUP_SIZE_ROW, 2 * GROUP_SIZE_ROW)`` are "already written"
            flags.  NOTE(review): the kernel relies on all of them being 0 at
            launch -- confirm the caller zero-initialises this buffer.
        stride: element distance between consecutive rows of ``X`` / ``DY`` / ``DX``.
        N: number of columns per row.
        GROUP_SIZE_ROW: number of partial-sum rows (and locks) rows are folded into.
        BLOCK_SIZE_COL: width of the single column tile; columns at or beyond
            ``BLOCK_SIZE_COL`` are never touched, so it must be >= ``N``.
    """
    # Move the row-strided pointers to this program's row.
    row_pid = tl.program_id(0)
    X += row_pid * stride
    DY += row_pid * stride
    DX += row_pid * stride

    off_col = tl.arange(0, BLOCK_SIZE_COL)
    mask_col = off_col < N

    # Rows are folded modulo GROUP_SIZE_ROW onto a shared partial-sum row, each
    # guarded by its own lock; the matching "written" flag lives GROUP_SIZE_ROW
    # entries further into the same buffer.
    lock_id = row_pid % GROUP_SIZE_ROW
    Lock += lock_id
    Count = Lock + GROUP_SIZE_ROW
    DW = DW + lock_id * N + off_col
    DB = DB + lock_id * N + off_col

    # All arithmetic below is done in float32 regardless of the storage dtype.
    x = tl.load(X + off_col, mask=mask_col, other=0.0).to(tl.float32)
    dy = tl.load(DY + off_col, mask=mask_col, other=0.0).to(tl.float32)
    w = tl.load(W + off_col, mask=mask_col).to(tl.float32)
    mean = tl.load(Mean + row_pid)
    rstd = tl.load(Rstd + row_pid)

    # Zero the padding lanes so they do not contribute to the row reductions
    # (x_hat would otherwise be -mean * rstd there, and w is unspecified).
    x_hat = (x - mean) * rstd
    x_hat = tl.where(mask_col, x_hat, 0.0)
    wdy = w * dy
    wdy = tl.where(mask_col, wdy, 0.0)

    term1 = tl.sum(wdy, axis=0) / N
    term2 = tl.sum(wdy * x_hat, axis=0) / N
    dx = (wdy - (x_hat * term2 + term1)) * rstd
    # Cast to the element type DX points to (DX.dtype itself is a pointer type).
    tl.store(DX + off_col, dx.to(DX.dtype.element_ty), mask=mask_col)

    # Per-column contributions of this row; they must stay vectors because the
    # reduction over rows happens across programs, not across columns.
    group_partial_dw = dy * x_hat
    group_partial_db = dy

    # Spin until this program owns the group's lock.
    while tl.atomic_cas(Lock, 0, 1) == 1:
        pass
    count = tl.load(Count)
    if count == 0:
        # First writer for this group: overwrite whatever the buffer holds and
        # mark the row as initialised.
        tl.atomic_xchg(Count, 1)
    else:
        group_partial_dw += tl.load(DW, mask=mask_col).to(tl.float32)
        group_partial_db += tl.load(DB, mask=mask_col).to(tl.float32)
    tl.store(DW, group_partial_dw, mask=mask_col)
    tl.store(DB, group_partial_db, mask=mask_col)

    # Make sure every thread of the program has finished its stores before the
    # lock is handed to the next program.
    tl.debug_barrier()
    tl.atomic_xchg(Lock, 0)
@triton.jit
def _layer_norm_bwd_dwdb(DW, DB, FINAL_DW, FINAL_DB, M, N,
                         BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    """Layer-norm backward, stage 2: sum the partial dW/dB rows into the final vectors.

    Each program owns the ``BLOCK_SIZE_N`` columns starting at
    ``tl.program_id(0) * BLOCK_SIZE_N`` and reduces them over all ``M`` rows.

    Args:
        DW: partial weight gradients, read as ``M`` rows with a row pitch of ``N``.
        DB: partial bias gradients, same layout as ``DW``.
        FINAL_DW: output pointer; receives one summed value per column.
        FINAL_DB: output pointer; receives one summed value per column.
        M: number of partial rows to reduce.
        N: number of columns.
        BLOCK_SIZE_M: rows consumed per loop iteration.
        BLOCK_SIZE_N: columns handled by one program.
    """
    col_idx = tl.program_id(0) * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    col_in_range = col_idx < N

    # float32 tile accumulators; the row axis is collapsed only after the loop.
    acc_dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    acc_db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for row_start in range(0, M, BLOCK_SIZE_M):
        row_idx = row_start + tl.arange(0, BLOCK_SIZE_M)
        tile_offsets = row_idx[:, None] * N + col_idx[None, :]
        # Out-of-range rows/columns load as 0 so they leave the sums unchanged.
        tile_mask = (row_idx[:, None] < M) & col_in_range[None, :]
        acc_dw += tl.load(DW + tile_offsets, mask=tile_mask, other=0.0)
        acc_db += tl.load(DB + tile_offsets, mask=tile_mask, other=0.0)

    tl.store(FINAL_DW + col_idx, tl.sum(acc_dw, axis=0), mask=col_in_range)
    tl.store(FINAL_DB + col_idx, tl.sum(acc_db, axis=0), mask=col_in_range)
|