diff --git a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py index ec1c37c5..727a470b 100644 --- a/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py +++ b/vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py @@ -157,19 +157,22 @@ def _fwd_kernel_inner( k = tl.load( k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < k_seqlen, + other=0.0, ) else: k = tl.load( k_ptrs + start_n * stride_kt, mask=(offs_n[None, :] + start_n < k_seqlen) & (offs_d[:, None] < D_HEAD), + other=0.0, ) else: if EVEN_D: k = tl.load(k_ptrs + start_n * stride_kt) else: k = tl.load(k_ptrs + start_n * stride_kt, - mask=offs_d[:, None] < D_HEAD) + mask=offs_d[:, None] < D_HEAD, + other=0.0) qk = tl.zeros([BLOCK_M_LOADING, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) @@ -200,19 +203,22 @@ def _fwd_kernel_inner( v = tl.load( v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < k_seqlen, + other=0.0, ) else: v = tl.load( v_ptrs + start_n * stride_vt, mask=(offs_n[:, None] + start_n < k_seqlen) & (offs_d[None, :] < D_HEAD), + other=0.0, ) else: if EVEN_D: v = tl.load(v_ptrs + start_n * stride_vt) else: v = tl.load(v_ptrs + start_n * stride_vt, - mask=offs_d[None, :] < D_HEAD) + mask=offs_d[None, :] < D_HEAD, + other=0.0) acc += tl.dot(p, v) @@ -318,12 +324,13 @@ def _fwd_kernel_batch_inference( q = tl.load( Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, mask=offs_m[:, None] < q_seqlen, + other=0.0, ) else: q = tl.load( Q + offs_m[:, None] * stride_qt + offs_d[None, :] * stride_qd, mask=(offs_m[:, None] < q_seqlen) & (offs_d[None, :] < D_HEAD), - other=0, + other=0.0, ) sparse_crow_ptr = (layout_crow_ptr + off_h * layout_crow_stride_h + diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index 6a32387a..f176259f 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -75,7 +75,9 @@ def _bgmv_expand_kernel( other=0.0, ) # [BLOCK_N,BLOCK_K] if ADD_INPUTS: - tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + tiled_out = tl.load(c_ptr + current_n * cn_stride, + mask=c_mask, + other=0.0) accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out else: accumulator = tl.sum(tiled_a * tiled_b, 1) diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index 73628fd2..2c6ed96c 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -78,7 +78,13 @@ def _bgmv_expand_slice_kernel( ) # [BLOCK_N,BLOCK_K] if ADD_INPUTS: - tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask) + # explicitly pass in other=None to tell triton that masked values + # can be uninitialized. This is OK because the later tl.store + # operation uses the same mask, eliminating the risk of garbage + # values propagating + tiled_out = tl.load(c_ptr + current_n * cn_stride, + mask=c_mask, + other=None) accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out else: accumulator = tl.sum(tiled_a * tiled_b, 1) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 4910cb40..ee2cd2e0 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -88,7 +88,10 @@ def _sgmv_expand_kernel( c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N) if ADD_INPUTS: - tiled_out = tl.load(c_ptr, mask=c_mask) + # explicitly pass in other=None to tell triton that masked values + # can be uninitialized. This is OK because the later tl.store operation + # uses the same mask, eliminating the risk of garbage values propagating + tiled_out = tl.load(c_ptr, mask=c_mask, other=None) tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 844f5cec..5244fa14 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -94,7 +94,10 @@ def _sgmv_expand_slice_kernel( c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < (slice_offset + N)) if ADD_INPUTS: - tiled_out = tl.load(c_ptr, mask=c_mask) + # explicitly pass in other=None to tell triton that masked values + # can be uninitialized. This is OK because the later tl.store operation + # uses the same mask, eliminating the risk of garbage values propagating + tiled_out = tl.load(c_ptr, mask=c_mask, other=None) tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py index bbb7fc8a..ace8f4a3 100644 --- a/vllm/model_executor/layers/quantization/awq_triton.py +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -42,7 +42,7 @@ def awq_dequantize_kernel( result_masks = result_masks_y[:, None] & result_masks_x[None, :] # Load the weights. - iweights = tl.load(qweight_ptr + offsets, masks) + iweights = tl.load(qweight_ptr + offsets, masks, 0.0) iweights = tl.interleave(iweights, iweights) iweights = tl.interleave(iweights, iweights) iweights = tl.interleave(iweights, iweights) @@ -71,7 +71,7 @@ def awq_dequantize_kernel( zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] # Load the zeros. - zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) @@ -91,7 +91,7 @@ def awq_dequantize_kernel( scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] # Load the scales. - scales = tl.load(scales_ptr + scale_offsets, scale_masks) + scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0) scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) # Dequantize. @@ -165,10 +165,10 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): masks_k = offsets_k < K masks_a = masks_am[:, None] & masks_k[None, :] - a = tl.load(a_ptrs, mask=masks_a) + a = tl.load(a_ptrs, mask=masks_a, other=0.0) masks_b = masks_k[:, None] & masks_bn[None, :] - b = tl.load(b_ptrs, mask=masks_b) + b = tl.load(b_ptrs, mask=masks_b, other=0.0) b = tl.interleave(b, b) b = tl.interleave(b, b) b = tl.interleave(b, b) @@ -181,7 +181,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, masks_zk = offsets_szk < K // group_size masks_z = masks_zk[:, None] & masks_zn[None, :] zeros_ptrs = zeros_ptr + offsets_z - zeros = tl.load(zeros_ptrs, mask=masks_z) + zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) @@ -191,7 +191,7 @@ def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, masks_sk = offsets_szk < K // group_size masks_s = masks_sk[:, None] & masks_sn[None, :] scales_ptrs = scales_ptr + offsets_s - scales = tl.load(scales_ptrs, mask=masks_s) + scales = tl.load(scales_ptrs, mask=masks_s, other=0.0) scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) b = (b >> shifts) & 0xF