[Kernel] Change interface to Mamba selective_state_update for continuous batching (#8039)
This commit is contained in:
parent
b3195bc9e4
commit
db9120cded
@ -323,3 +323,149 @@ def test_selective_state_update(dim, dstate, has_z, itype):
|
|||||||
|
|
||||||
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
assert torch.allclose(state, state_ref, rtol=rtol, atol=atol)
|
||||||
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("itype",
|
||||||
|
[torch.float32, torch.float16, torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize("has_z", [False, True])
|
||||||
|
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||||
|
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
||||||
|
def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype):
|
||||||
|
device = "cuda"
|
||||||
|
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
|
||||||
|
if itype == torch.bfloat16:
|
||||||
|
rtol, atol = 7e-2, 7e-2
|
||||||
|
if torch.version.hip:
|
||||||
|
atol *= 2
|
||||||
|
# set seed
|
||||||
|
torch.random.manual_seed(0)
|
||||||
|
batch_size = 16
|
||||||
|
|
||||||
|
total_entries = 10 * batch_size
|
||||||
|
state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device)
|
||||||
|
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||||
|
dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
x = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||||
|
dt = torch.randn(batch_size, dim, device=device, dtype=itype)
|
||||||
|
dt_bias = torch.rand(dim, device=device) - 4.0
|
||||||
|
A = -torch.rand(dim, dstate, device=device) - 1.0
|
||||||
|
B = torch.randn(batch_size, dstate, device=device)
|
||||||
|
C = torch.randn(batch_size, dstate, device=device)
|
||||||
|
D = torch.randn(dim, device=device)
|
||||||
|
z = torch.randn_like(x) if has_z else None
|
||||||
|
state_ref = state[state_indices, :].detach().clone()
|
||||||
|
out = selective_state_update(state,
|
||||||
|
x,
|
||||||
|
dt,
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
D=D,
|
||||||
|
z=z,
|
||||||
|
dt_bias=dt_bias,
|
||||||
|
dt_softplus=True,
|
||||||
|
state_batch_indices=state_indices)
|
||||||
|
out_ref = selective_state_update_ref(state_ref,
|
||||||
|
x,
|
||||||
|
dt,
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
D=D,
|
||||||
|
z=z,
|
||||||
|
dt_bias=dt_bias,
|
||||||
|
dt_softplus=True)
|
||||||
|
|
||||||
|
assert torch.allclose(state[state_indices, :],
|
||||||
|
state_ref,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
|
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("itype",
|
||||||
|
[torch.float32, torch.float16, torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize("has_z", [False, True])
|
||||||
|
@pytest.mark.parametrize("tie_hdim", [False, True])
|
||||||
|
@pytest.mark.parametrize("ngroups", [1, 2, 4])
|
||||||
|
@pytest.mark.parametrize("dstate", [16, 32, 64])
|
||||||
|
@pytest.mark.parametrize("dim", [2048, 4096])
|
||||||
|
def test_selective_state_update_with_heads_with_batch_indices(
|
||||||
|
dim, dstate, ngroups, has_z, tie_hdim, itype):
|
||||||
|
device = "cuda"
|
||||||
|
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 3e-2)
|
||||||
|
if itype == torch.bfloat16:
|
||||||
|
rtol, atol = 1e-1, 1e-1
|
||||||
|
# set seed
|
||||||
|
torch.random.manual_seed(0)
|
||||||
|
batch_size = 16
|
||||||
|
headdim = 64
|
||||||
|
nheads = dim // headdim
|
||||||
|
|
||||||
|
total_entries = 10 * batch_size
|
||||||
|
state = torch.randn(total_entries,
|
||||||
|
nheads,
|
||||||
|
headdim,
|
||||||
|
dstate,
|
||||||
|
dtype=itype,
|
||||||
|
device=device)
|
||||||
|
state_indices = torch.randperm(total_entries)[:batch_size].to(
|
||||||
|
dtype=torch.int32, device=device)
|
||||||
|
|
||||||
|
x = torch.randn(batch_size, nheads, headdim, device=device, dtype=itype)
|
||||||
|
if not tie_hdim:
|
||||||
|
dt = torch.randn(batch_size,
|
||||||
|
nheads,
|
||||||
|
headdim,
|
||||||
|
device=device,
|
||||||
|
dtype=itype)
|
||||||
|
dt_bias = torch.rand(nheads, headdim, device=device) - 4.0
|
||||||
|
A = -torch.rand(nheads, headdim, dstate, device=device) - 1.0
|
||||||
|
D = torch.randn(nheads, headdim, device=device)
|
||||||
|
else:
|
||||||
|
dt = repeat(torch.randn(batch_size, nheads, device=device,
|
||||||
|
dtype=itype),
|
||||||
|
"b h -> b h p",
|
||||||
|
p=headdim)
|
||||||
|
dt_bias = repeat(torch.rand(nheads, device=device) - 4.0,
|
||||||
|
"h -> h p",
|
||||||
|
p=headdim)
|
||||||
|
A = repeat(-torch.rand(nheads, device=device) - 1.0,
|
||||||
|
"h -> h p n",
|
||||||
|
p=headdim,
|
||||||
|
n=dstate)
|
||||||
|
D = repeat(torch.randn(nheads, device=device), "h -> h p", p=headdim)
|
||||||
|
B = torch.randn(batch_size, ngroups, dstate, device=device)
|
||||||
|
C = torch.randn(batch_size, ngroups, dstate, device=device)
|
||||||
|
z = torch.randn_like(x) if has_z else None
|
||||||
|
state_ref = state[state_indices, :].detach().clone()
|
||||||
|
out = selective_state_update(state,
|
||||||
|
x,
|
||||||
|
dt,
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
D=D,
|
||||||
|
z=z,
|
||||||
|
dt_bias=dt_bias,
|
||||||
|
dt_softplus=True,
|
||||||
|
state_batch_indices=state_indices)
|
||||||
|
out_ref = selective_state_update_ref(state_ref,
|
||||||
|
x,
|
||||||
|
dt,
|
||||||
|
A,
|
||||||
|
B,
|
||||||
|
C,
|
||||||
|
D=D,
|
||||||
|
z=z,
|
||||||
|
dt_bias=dt_bias,
|
||||||
|
dt_softplus=True)
|
||||||
|
|
||||||
|
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
||||||
|
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
||||||
|
assert torch.allclose(state[state_indices, :],
|
||||||
|
state_ref,
|
||||||
|
rtol=rtol,
|
||||||
|
atol=atol)
|
||||||
|
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
# Copyright (c) 2024, Tri Dao, Albert Gu.
|
||||||
|
# Adapted from https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/triton/selective_state_update.py
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
@ -27,6 +28,10 @@ else:
|
|||||||
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
{"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
|
||||||
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
|
||||||
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
|
||||||
|
@triton.heuristics({
|
||||||
|
"HAS_STATE_BATCH_INDICES":
|
||||||
|
lambda args: args["state_batch_indices_ptr"] is not None
|
||||||
|
})
|
||||||
@triton.heuristics(
|
@triton.heuristics(
|
||||||
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
{"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
|
||||||
@triton.jit
|
@triton.jit
|
||||||
@ -42,6 +47,7 @@ def _selective_scan_update_kernel(
|
|||||||
D_ptr,
|
D_ptr,
|
||||||
z_ptr,
|
z_ptr,
|
||||||
out_ptr,
|
out_ptr,
|
||||||
|
state_batch_indices_ptr,
|
||||||
# Matrix dimensions
|
# Matrix dimensions
|
||||||
batch,
|
batch,
|
||||||
nheads,
|
nheads,
|
||||||
@ -85,12 +91,24 @@ def _selective_scan_update_kernel(
|
|||||||
HAS_DT_BIAS: tl.constexpr,
|
HAS_DT_BIAS: tl.constexpr,
|
||||||
HAS_D: tl.constexpr,
|
HAS_D: tl.constexpr,
|
||||||
HAS_Z: tl.constexpr,
|
HAS_Z: tl.constexpr,
|
||||||
|
HAS_STATE_BATCH_INDICES: tl.constexpr,
|
||||||
BLOCK_SIZE_DSTATE: tl.constexpr,
|
BLOCK_SIZE_DSTATE: tl.constexpr,
|
||||||
):
|
):
|
||||||
pid_m = tl.program_id(axis=0)
|
pid_m = tl.program_id(axis=0)
|
||||||
pid_b = tl.program_id(axis=1)
|
pid_b = tl.program_id(axis=1)
|
||||||
pid_h = tl.program_id(axis=2)
|
pid_h = tl.program_id(axis=2)
|
||||||
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
|
||||||
|
# If HAS_STATE_BATCH_INDICES is true, then the ssm state's batch coordinate
|
||||||
|
# is taken from the state_batch_indices_ptr Otherwise, the state coordinate
|
||||||
|
# is the same as the batch id.
|
||||||
|
if HAS_STATE_BATCH_INDICES:
|
||||||
|
state_batch_indices_ptr += pid_b
|
||||||
|
state_batch_idx = tl.load(state_batch_indices_ptr)
|
||||||
|
state_ptr += (state_batch_idx * stride_state_batch +
|
||||||
|
pid_h * stride_state_head)
|
||||||
|
else:
|
||||||
|
state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
|
||||||
|
|
||||||
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
|
||||||
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
|
||||||
if HAS_DT_BIAS:
|
if HAS_DT_BIAS:
|
||||||
@ -177,7 +195,8 @@ def selective_state_update(state,
|
|||||||
D=None,
|
D=None,
|
||||||
z=None,
|
z=None,
|
||||||
dt_bias=None,
|
dt_bias=None,
|
||||||
dt_softplus=False):
|
dt_softplus=False,
|
||||||
|
state_batch_indices=None):
|
||||||
"""
|
"""
|
||||||
Argument:
|
Argument:
|
||||||
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
|
||||||
@ -211,7 +230,10 @@ def selective_state_update(state,
|
|||||||
z = z.unsqueeze(1)
|
z = z.unsqueeze(1)
|
||||||
if dt_bias is not None and dt_bias.dim() == 1:
|
if dt_bias is not None and dt_bias.dim() == 1:
|
||||||
dt_bias = dt_bias.unsqueeze(0)
|
dt_bias = dt_bias.unsqueeze(0)
|
||||||
batch, nheads, dim, dstate = state.shape
|
|
||||||
|
_, nheads, dim, dstate = state.shape
|
||||||
|
batch = x.shape[0]
|
||||||
|
|
||||||
assert x.shape == (batch, nheads, dim)
|
assert x.shape == (batch, nheads, dim)
|
||||||
assert dt.shape == x.shape
|
assert dt.shape == x.shape
|
||||||
assert A.shape == (nheads, dim, dstate)
|
assert A.shape == (nheads, dim, dstate)
|
||||||
@ -225,6 +247,8 @@ def selective_state_update(state,
|
|||||||
assert z.shape == x.shape
|
assert z.shape == x.shape
|
||||||
if dt_bias is not None:
|
if dt_bias is not None:
|
||||||
assert dt_bias.shape == (nheads, dim)
|
assert dt_bias.shape == (nheads, dim)
|
||||||
|
if state_batch_indices is not None:
|
||||||
|
assert state_batch_indices.shape == (batch, )
|
||||||
out = torch.empty_like(x)
|
out = torch.empty_like(x)
|
||||||
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
|
||||||
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else
|
||||||
@ -249,6 +273,7 @@ def selective_state_update(state,
|
|||||||
D,
|
D,
|
||||||
z,
|
z,
|
||||||
out,
|
out,
|
||||||
|
state_batch_indices,
|
||||||
batch,
|
batch,
|
||||||
nheads,
|
nheads,
|
||||||
dim,
|
dim,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user