diff --git a/csrc/rotary/rotary.cpp b/csrc/rotary/rotary.cpp
index 77037e8..206fda3 100644
--- a/csrc/rotary/rotary.cpp
+++ b/csrc/rotary/rotary.cpp
@@ -1,10 +1,8 @@
 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>
 
-#define CHECK_DEVICE(x) \
-    TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
-#define CHECK_SHAPE(x, ...) \
-    TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
-                #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
 
 void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
                        const torch::Tensor cos, const torch::Tensor sin,
@@ -26,6 +24,11 @@ void apply_rotary(const torch::Tensor x1, const torch::Tensor x2,
     TORCH_CHECK(x1.sizes() == x2.sizes());
     TORCH_CHECK(cos.sizes() == sin.sizes());
     TORCH_CHECK(out1.sizes() == out2.sizes());
+
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x1.get_device()};
+
     apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj);
 }
diff --git a/flash_attn/layers/rotary.py b/flash_attn/layers/rotary.py
index dadee32..db3319c 100644
--- a/flash_attn/layers/rotary.py
+++ b/flash_attn/layers/rotary.py
@@ -137,17 +137,19 @@ class RotaryEmbedding(torch.nn.Module):
     """
 
-    def __init__(self, dim: int, base=10000, scale_base=0, *_, **__):
+    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
         """
         If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
         Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
         """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
+                                                dtype=torch.float32) / dim))
         self.register_buffer("inv_freq", inv_freq)
         self.scale_base = scale_base
-        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) if scale_base > 0 else None
+        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
+                 / (1.4 * dim) if scale_base > 0 else None)
         self.register_buffer("scale", scale)
 
         self._seq_len_cached = 0
@@ -168,14 +170,14 @@ class RotaryEmbedding(torch.nn.Module):
             t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
             # Don't do einsum, it converts fp32 to fp16
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             if self.scale is None:
                 self._cos_cached = torch.cos(freqs).to(x.dtype)
                 self._sin_cached = torch.sin(freqs).to(x.dtype)
             else:
                 power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                           - seqlen // 2) / self.scale_base)
-                scale = self.scale ** rearrange(power, 's -> s 1')
+                scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
                 # We want the multiplication by scale to happen in fp32
                 self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                 self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
diff --git a/flash_attn/modules/mha.py b/flash_attn/modules/mha.py
index 3536b0d..6bc3321 100644
--- a/flash_attn/modules/mha.py
+++ b/flash_attn/modules/mha.py
@@ -21,9 +21,9 @@ except ImportError:
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None
 
 try:
-    from flash_attn.ops.fused_dense import FusedDense
+    from flash_attn.ops.fused_dense import FusedDense, ColumnParallelLinear, RowParallelLinear
 except ImportError:
-    FusedDense = None
+    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None
 
 try:
     from flash_attn.layers.rotary import RotaryEmbedding
@@ -42,7 +42,7 @@ class FlashSelfAttention(nn.Module):
                            (default: 0.0)
     """
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 triton=False, device=None, dtype=None):
+                 triton=False):
         super().__init__()
         if attention_dropout != 0.0 or not triton:
             assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
@@ -109,7 +109,7 @@ class FlashCrossAttention(nn.Module):
                            (default: 0.0)
     """
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 triton=False, device=None, dtype=None):
+                 triton=False):
         super().__init__()
         if attention_dropout != 0.0 or not triton:
             assert flash_attn_unpadded_kvpacked_func is not None, 'FlashAttention is not installed'
@@ -181,8 +181,7 @@ class SelfAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.0)
     """
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 device=None, dtype=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
@@ -228,8 +227,7 @@ class CrossAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.0)
     """
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 device=None, dtype=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
@@ -309,7 +307,8 @@ class MHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+                                              device=device)
 
         if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
@@ -338,7 +337,7 @@ class MHA(nn.Module):
                                            groups=2 * embed_dim)
         inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
-                                         attention_dropout=dropout, **factory_kwargs)
+                                         attention_dropout=dropout)
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
@@ -378,7 +377,7 @@ class MHA(nn.Module):
             if self.dwconv:
                 qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                 'b d s -> b s d').contiguous()
-            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, h=self.num_heads)
+            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
             if self.rotary_emb_dim > 0:
                 qkv = self.rotary_emb(qkv)
             if not self.checkpointing:
@@ -395,8 +394,8 @@ class MHA(nn.Module):
             else:
                 kv, x = self.Wkv(x)
                 q = self.Wq(x)
-            q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
-            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, h=self.num_heads)
+            q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
+            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
             if self.dwconv:
                 q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                               'b d s -> b s d').contiguous()
@@ -408,3 +407,66 @@ class MHA(nn.Module):
             context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
         out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
+
+
+class ParallelMHA(nn.Module):
+    """Multi-head self-attention and cross-attention
+    """
+
+    def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
+                 softmax_scale=None, causal=False, rotary_emb_dim=0, rotary_emb_scale_base=0,
+                 use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.process_group = process_group
+        self.embed_dim = embed_dim
+        self.causal = causal
+        self.rotary_emb_dim = rotary_emb_dim
+        self.use_flash_attn = use_flash_attn
+        self.checkpointing = checkpointing
+
+        self.num_heads = num_heads
+        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
+        self.head_dim = self.embed_dim // num_heads
+
+        if self.rotary_emb_dim > 0:
+            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+                                              device=device)
+
+        if ColumnParallelLinear is None or RowParallelLinear is None:
+            raise ImportError('fused_dense is not installed')
+        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
+                                         **factory_kwargs)
+        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
+        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
+                                         attention_dropout=dropout)
+        # output projection always have the bias (for now)
+        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, **factory_kwargs)
+
+    def forward(self, x, seqlen=None, **kwargs):
+        """
+        Arguments:
+            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
+                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
+                split x during sequence parallel, we split the batch * seqlen dimension
+                (in case batch is small).
+        """
+        qkv = self.Wqkv(x)
+        if seqlen is None:
+            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
+        else:
+            qkv = rearrange(qkv, '(b s) (three h d) -> b s three h d', s=seqlen, three=3,
+                            d=self.head_dim)
+        if self.rotary_emb_dim > 0:
+            qkv = self.rotary_emb(qkv)
+        if not self.checkpointing:
+            context = self.inner_attn(qkv, **kwargs)
+        else:
+            context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+        if seqlen is None:
+            context = rearrange(context, 'b s h d -> b s (h d)')
+        else:
+            context = rearrange(context, 'b s h d -> (b s) (h d)')
+        out = self.out_proj(context)
+        return out
diff --git a/tests/modules/test_mha_parallel.py b/tests/modules/test_mha_parallel.py
new file mode 100644
index 0000000..2d3d01f
--- /dev/null
+++ b/tests/modules/test_mha_parallel.py
@@ -0,0 +1,109 @@
+# Run test with:
+# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mha_parallel.py
+
+import math
+
+import torch
+import torch.nn.functional as F
+import pytest
+
+from einops import rearrange
+
+from apex.transformer import parallel_state
+from apex.transformer import tensor_parallel
+
+from flash_attn.modules.mha import MHA, ParallelMHA
+
+is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
+
+
+@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
+# @pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
+# @pytest.mark.parametrize('world_size', [2])
+@pytest.mark.parametrize('head_dim', [64, 128])
+# @pytest.mark.parametrize('head_dim', [64])
+@pytest.mark.parametrize('embed_dim', [1024, 4096])
+# @pytest.mark.parametrize('embed_dim', [1024])
+def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
+    assert embed_dim % head_dim == 0
+    num_heads = embed_dim // head_dim
+    assert num_heads % world_size == 0
+    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
+    if not torch.distributed.is_initialized():
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
+    device = f'cuda:{torch.distributed.get_rank()}'
+    assert world_size <= torch.distributed.get_world_size()
+    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
+    rank = parallel_state.get_tensor_model_parallel_rank()
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 8
+    seqlen = 1024
+    assert (batch_size * seqlen) % world_size == 0
+    x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype,
+                       requires_grad=True)
+    # We need to generate g here so that all processes get the same gradient,
+    # as rank 0 will have an extra bias that changes the RNG.
+    # If we don't divide by batch_size, the gradient gets a bit too large.
+    g = torch.randn_like(x_pt) / 32
+    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
+
+    model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2),
+                   use_flash_attn=True, device=device, dtype=dtype)
+    partition_dim = embed_dim // world_size
+    model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(),
+                        rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
+                        device=device, dtype=dtype)
+
+    with torch.no_grad():
+        model.Wqkv.weight.copy_(
+            rearrange(rearrange(model_pt.Wqkv.weight, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
+                      'three o i -> (three o) i')
+        )
+        model.Wqkv.bias.copy_(
+            rearrange(rearrange(model_pt.Wqkv.bias, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
+                      'three o -> (three o)')
+        )
+        model.out_proj.weight.copy_(
+            model_pt.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
+        )
+        if rank == 0:
+            model.out_proj.bias.copy_(model_pt.out_proj.bias)
+
+    out = model(x, seqlen=seqlen)
+    out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d')
+    partition_batch_dim = batch_size * seqlen // world_size
+    assert torch.allclose(
+        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        rtol=rtol, atol=atol
+    )
+
+    out_pt.backward(g)
+    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
+    parallel_state.destroy_model_parallel()
+
+    assert torch.allclose(
+        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
+        rtol=rtol, atol=atol
+    )
+    # The error for d_weight and d_bias is quite a bit higher
+    assert torch.allclose(
+        model.Wqkv.weight.grad,
+        rearrange(rearrange(model_pt.Wqkv.weight.grad, '(three o) i -> three o i', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
+                  'three o i -> (three o) i'),
+        rtol=rtol, atol=atol * 10
+    )
+    assert torch.allclose(
+        model.Wqkv.bias.grad,
+        rearrange(rearrange(model_pt.Wqkv.bias.grad, '(three o) -> three o', three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
+                  'three o -> (three o)'),
+        rtol=rtol, atol=atol * 5
+    )
+    assert torch.allclose(
+        model.out_proj.weight.grad,
+        model_pt.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
+        rtol=rtol, atol=atol * 10
+    )
+    if rank == 0:
+        assert torch.allclose(model.out_proj.bias.grad, model_pt.out_proj.bias.grad, rtol=rtol, atol=atol * 5)
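Note: the test above doubles as a usage reference for the new ParallelMHA module. The following is a minimal usage sketch distilled from it, not part of this patch; it assumes the process was launched with torchrun, that apex's tensor-parallel utilities and FlashAttention are installed, and the model sizes are purely illustrative.

# Minimal usage sketch (not part of this patch), following tests/modules/test_mha_parallel.py.
# Assumes launch via torchrun and that apex + flash-attn are installed; sizes are illustrative.
import torch
from apex.transformer import parallel_state
from flash_attn.modules.mha import ParallelMHA

torch.distributed.init_process_group(backend='nccl', init_method='env://')
world_size = torch.distributed.get_world_size()
device = f'cuda:{torch.distributed.get_rank()}'
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)

# num_heads must be divisible by the tensor-parallel world size,
# and batch_size * seqlen must be divisible by it as well.
embed_dim, num_heads, batch_size, seqlen = 1024, 16, 8, 1024
model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(),
                    rotary_emb_dim=32, use_flash_attn=True,
                    device=device, dtype=torch.float16)

# Each rank holds its (batch * seqlen) // world_size shard of the tokens,
# matching the sequence-parallel convention described in ParallelMHA.forward.
x_local = torch.randn(batch_size * seqlen // world_size, embed_dim,
                      device=device, dtype=torch.float16)
out_local = model(x_local, seqlen=seqlen)  # (batch * seqlen // world_size, embed_dim)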