[Kernel] Explicitly specify other value in tl.load calls (#9014)
Signed-off-by: Angus Wang <wangjadehao@gmail.com>
This commit is contained in:
parent
6b2d25efc7
commit
c2170a5b39
@@ -157,19 +157,22 @@ def _fwd_kernel_inner(
|
|||||||
k = tl.load(
|
k = tl.load(
|
||||||
k_ptrs + start_n * stride_kt,
|
k_ptrs + start_n * stride_kt,
|
||||||
mask=offs_n[None, :] + start_n < k_seqlen,
|
mask=offs_n[None, :] + start_n < k_seqlen,
|
||||||
|
other=0.0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
k = tl.load(
|
k = tl.load(
|
||||||
k_ptrs + start_n * stride_kt,
|
k_ptrs + start_n * stride_kt,
|
||||||
mask=(offs_n[None, :] + start_n < k_seqlen) &
|
mask=(offs_n[None, :] + start_n < k_seqlen) &
|
||||||
(offs_d[:, None] < D_HEAD),
|
(offs_d[:, None] < D_HEAD),
|
||||||
|
other=0.0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if EVEN_D:
|
if EVEN_D:
|
||||||
k = tl.load(k_ptrs + start_n * stride_kt)
|
k = tl.load(k_ptrs + start_n * stride_kt)
|
||||||
else:
|
else:
|
||||||
k = tl.load(k_ptrs + start_n * stride_kt,
|
k = tl.load(k_ptrs + start_n * stride_kt,
|
||||||
mask=offs_d[:, None] < D_HEAD)
|
mask=offs_d[:, None] < D_HEAD,
|
||||||
|
other=0.0)
|
||||||
|
|
||||||
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32)
|
||||||
qk += tl.dot(q, k)
|
qk += tl.dot(q, k)
|
||||||
@@ -200,19 +203,22 @@ def _fwd_kernel_inner(
|
|||||||
v = tl.load(
|
v = tl.load(
|
||||||
v_ptrs + start_n * stride_vt,
|
v_ptrs + start_n * stride_vt,
|
||||||
mask=offs_n[:, None] + start_n < k_seqlen,
|
mask=offs_n[:, None] + start_n < k_seqlen,
|
||||||
|
other=0.0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
v = tl.load(
|
v = tl.load(
|
||||||
v_ptrs + start_n * stride_vt,
|
v_ptrs + start_n * stride_vt,
|
||||||
mask=(offs_n[:, None] + start_n < k_seqlen) &
|
mask=(offs_n[:, None] + start_n < k_seqlen) &
|
||||||
(offs_d[None, :] < D_HEAD),
|
(offs_d[None, :] < D_HEAD),
|
||||||
|
other=0.0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if EVEN_D:
|
if EVEN_D:
|
||||||
v = tl.load(v_ptrs + start_n * stride_vt)
|
v = tl.load(v_ptrs + start_n * stride_vt)
|
||||||
else:
|
else:
|
||||||
v = tl.load(v_ptrs + start_n * stride_vt,
|
v = tl.load(v_ptrs + start_n * stride_vt,
|
||||||
mask=offs_d[None, :] < D_HEAD)
|
mask=offs_d[None, :] < D_HEAD,
|
||||||
|
other=0.0)
|
||||||
|
|
||||||
acc += tl.dot(p, v)
|
acc += tl.dot(p, v)
|
||||||
|
|
||||||
@@ -318,12 +324,13 @@ def _fwd_kernel_batch_inference(
|
|||||||
q = tl.load(
|
q = tl.load(
|
||||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||||
mask=offs_m[:, None] < q_seqlen,
|
mask=offs_m[:, None] < q_seqlen,
|
||||||
|
other=0.0,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
q = tl.load(
|
q = tl.load(
|
||||||
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd,
|
||||||
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD),
|
||||||
other=0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
|
sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h +
|
||||||
|
|||||||
@@ -75,7 +75,9 @@ def _bgmv_expand_kernel(
|
|||||||
other=0.0,
|
other=0.0,
|
||||||
) # [BLOCK_N,BLOCK_K]
|
) # [BLOCK_N,BLOCK_K]
|
||||||
if ADD_INPUTS:
|
if ADD_INPUTS:
|
||||||
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
|
tiled_out = tl.load(c_ptr + current_n * cn_stride,
|
||||||
|
mask=c_mask,
|
||||||
|
other=0.0)
|
||||||
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
|
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
|
||||||
else:
|
else:
|
||||||
accumulator = tl.sum(tiled_a * tiled_b, 1)
|
accumulator = tl.sum(tiled_a * tiled_b, 1)
|
||||||
|
|||||||
@@ -78,7 +78,13 @@ def _bgmv_expand_slice_kernel(
|
|||||||
) # [BLOCK_N,BLOCK_K]
|
) # [BLOCK_N,BLOCK_K]
|
||||||
|
|
||||||
if ADD_INPUTS:
|
if ADD_INPUTS:
|
||||||
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
|
# explicitly pass in other=None to tell triton that masked values
|
||||||
|
# can be uninitialized. This is OK because the later tl.store
|
||||||
|
# operation uses the same mask, eliminating the risk of garbage
|
||||||
|
# values propagating
|
||||||
|
tiled_out = tl.load(c_ptr + current_n * cn_stride,
|
||||||
|
mask=c_mask,
|
||||||
|
other=None)
|
||||||
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
|
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
|
||||||
else:
|
else:
|
||||||
accumulator = tl.sum(tiled_a * tiled_b, 1)
|
accumulator = tl.sum(tiled_a * tiled_b, 1)
|
||||||
|
|||||||
@@ -88,7 +88,10 @@ def _sgmv_expand_kernel(
|
|||||||
c_mask = (offset_cm[:, None] <
|
c_mask = (offset_cm[:, None] <
|
||||||
(cur_seq_start + M)) & (offset_cn[None, :] < N)
|
(cur_seq_start + M)) & (offset_cn[None, :] < N)
|
||||||
if ADD_INPUTS:
|
if ADD_INPUTS:
|
||||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
# explicitly pass in other=None to tell triton that masked values
|
||||||
|
# can be uninitialized. This is OK because the later tl.store operation
|
||||||
|
# uses the same mask, eliminating the risk of garbage values propagating
|
||||||
|
tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
|
||||||
tiled_c += tiled_out
|
tiled_c += tiled_out
|
||||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||||
|
|
||||||
|
|||||||
@@ -94,7 +94,10 @@ def _sgmv_expand_slice_kernel(
|
|||||||
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
|
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
|
||||||
(slice_offset + N))
|
(slice_offset + N))
|
||||||
if ADD_INPUTS:
|
if ADD_INPUTS:
|
||||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
# explicitly pass in other=None to tell triton that masked values
|
||||||
|
# can be uninitialized. This is OK because the later tl.store operation
|
||||||
|
# uses the same mask, eliminating the risk of garbage values propagating
|
||||||
|
tiled_out = tl.load(c_ptr, mask=c_mask, other=None)
|
||||||
tiled_c += tiled_out
|
tiled_c += tiled_out
|
||||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||||
|
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ def awq_dequantize_kernel(
|
|||||||
result_masks = result_masks_y[:, None] & result_masks_x[None, :]
|
result_masks = result_masks_y[:, None] & result_masks_x[None, :]
|
||||||
|
|
||||||
# Load the weights.
|
# Load the weights.
|
||||||
iweights = tl.load(qweight_ptr + offsets, masks)
|
iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
|
||||||
iweights = tl.interleave(iweights, iweights)
|
iweights = tl.interleave(iweights, iweights)
|
||||||
iweights = tl.interleave(iweights, iweights)
|
iweights = tl.interleave(iweights, iweights)
|
||||||
iweights = tl.interleave(iweights, iweights)
|
iweights = tl.interleave(iweights, iweights)
|
||||||
@@ -71,7 +71,7 @@ def awq_dequantize_kernel(
|
|||||||
zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
|
zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
|
||||||
|
|
||||||
# Load the zeros.
|
# Load the zeros.
|
||||||
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)
|
zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
|
||||||
zeros = tl.interleave(zeros, zeros)
|
zeros = tl.interleave(zeros, zeros)
|
||||||
zeros = tl.interleave(zeros, zeros)
|
zeros = tl.interleave(zeros, zeros)
|
||||||
zeros = tl.interleave(zeros, zeros)
|
zeros = tl.interleave(zeros, zeros)
|
||||||
@@ -91,7 +91,7 @@ def awq_dequantize_kernel(
|
|||||||
scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
|
scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
|
||||||
|
|
||||||
# Load the scales.
|
# Load the scales.
|
||||||
scales = tl.load(scales_ptr + scale_offsets, scale_masks)
|
scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
|
||||||
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
|
||||||
|
|
||||||
# Dequantize.
|
# Dequantize.
|
||||||
@@ -165,10 +165,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
|
|||||||
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
|
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
|
||||||
masks_k = offsets_k < K
|
masks_k = offsets_k < K
|
||||||
masks_a = masks_am[:, None] & masks_k[None, :]
|
masks_a = masks_am[:, None] & masks_k[None, :]
|
||||||
a = tl.load(a_ptrs, mask=masks_a)
|
a = tl.load(a_ptrs, mask=masks_a, other=0.0)
|
||||||
|
|
||||||
masks_b = masks_k[:, None] & masks_bn[None, :]
|
masks_b = masks_k[:, None] & masks_bn[None, :]
|
||||||
b = tl.load(b_ptrs, mask=masks_b)
|
b = tl.load(b_ptrs, mask=masks_b, other=0.0)
|
||||||
b = tl.interleave(b, b)
|
b = tl.interleave(b, b)
|
||||||
b = tl.interleave(b, b)
|
b = tl.interleave(b, b)
|
||||||
b = tl.interleave(b, b)
|
b = tl.interleave(b, b)
|
||||||
@@ -181,7 +181,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
|
|||||||
masks_zk = offsets_szk < K // group_size
|
masks_zk = offsets_szk < K // group_size
|
||||||
masks_z = masks_zk[:, None] & masks_zn[None, :]
|
masks_z = masks_zk[:, None] & masks_zn[None, :]
|
||||||
zeros_ptrs = zeros_ptr + offsets_z
|
zeros_ptrs = zeros_ptr + offsets_z
|
||||||
zeros = tl.load(zeros_ptrs, mask=masks_z)
|
zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
|
||||||
zeros = tl.interleave(zeros, zeros)
|
zeros = tl.interleave(zeros, zeros)
|
||||||
zeros = tl.interleave(zeros, zeros)
|
zeros = tl.interleave(zeros, zeros)
|
||||||
zeros = tl.interleave(zeros, zeros)
|
zeros = tl.interleave(zeros, zeros)
|
||||||
@@ -191,7 +191,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K,
|
|||||||
masks_sk = offsets_szk < K // group_size
|
masks_sk = offsets_szk < K // group_size
|
||||||
masks_s = masks_sk[:, None] & masks_sn[None, :]
|
masks_s = masks_sk[:, None] & masks_sn[None, :]
|
||||||
scales_ptrs = scales_ptr + offsets_s
|
scales_ptrs = scales_ptr + offsets_s
|
||||||
scales = tl.load(scales_ptrs, mask=masks_s)
|
scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
|
||||||
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
|
||||||
|
|
||||||
b = (b >> shifts) & 0xF
|
b = (b >> shifts) & 0xF
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user