[MHA] Run black on mha.py
parent cb0daccc41
commit bec5b3d374
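This commit only reformats mha.py; there are no behavior changes. It was presumably produced by running the black code formatter on the file (e.g. `black mha.py`). As a minimal, hedged sketch of the kind of rewrite black applies, the snippet below pushes a small made-up piece of code through black's Python API (`black.format_str` / `black.Mode`); the input string is illustrative and not taken from mha.py:

```python
# Minimal sketch: apply black programmatically to an illustrative snippet.
# Assumes the `black` package is installed; the source string below is made up.
import black

src = 'x = {  "a":37,"b":42,\n"c":927}\n'
formatted = black.format_str(src, mode=black.Mode())
print(formatted)  # black normalizes spacing and joins the small dict onto one line
```

In the diff below the same mechanics show up as single-quoted strings becoming double-quoted and long call sites being exploded to one argument per line with a trailing comma.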
@@ -1,4 +1,4 @@
-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2023, Tri Dao.

 import math
 from functools import partial
@@ -6,18 +6,21 @@ from functools import partial
 import torch
 import torch.nn as nn
 import torch.nn.functional as F

 from einops import rearrange, repeat

 try:
-    from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func
-    from flash_attn import flash_attn_qkvpacked_func, flash_attn_kvpacked_func
+    from flash_attn import (
+        flash_attn_kvpacked_func,
+        flash_attn_qkvpacked_func,
+        flash_attn_varlen_kvpacked_func,
+        flash_attn_varlen_qkvpacked_func,
+    )
 except ImportError:
     flash_attn_varlen_qkvpacked_func, flash_attn_varlen_kvpacked_func = None, None
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

 try:
-    from flash_attn.ops.fused_dense import FusedDense, ColumnParallelLinear, RowParallelLinear
+    from flash_attn.ops.fused_dense import ColumnParallelLinear, FusedDense, RowParallelLinear
 except ImportError:
     FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

@@ -42,10 +45,11 @@ class FlashSelfAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.0)
     """
+
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         super().__init__()
-        assert flash_attn_varlen_qkvpacked_func is not None, 'FlashAttention is not installed'
-        assert flash_attn_qkvpacked_func is not None, 'FlashAttention is not installed'
+        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
+        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)
@@ -76,12 +80,20 @@ class FlashSelfAttention(nn.Module):
             assert max_seqlen is not None
             assert isinstance(max_seqlen, int)
             return flash_attn_varlen_qkvpacked_func(
-                qkv, cu_seqlens, max_seqlen, self.drop.p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=causal
+                qkv,
+                cu_seqlens,
+                max_seqlen,
+                self.drop.p if self.training else 0.0,
+                softmax_scale=self.softmax_scale,
+                causal=causal,
             )
         else:
-            return flash_attn_qkvpacked_func(qkv, self.drop.p if self.training else 0.0,
-                                             softmax_scale=self.softmax_scale, causal=causal)
+            return flash_attn_qkvpacked_func(
+                qkv,
+                self.drop.p if self.training else 0.0,
+                softmax_scale=self.softmax_scale,
+                causal=causal,
+            )


 class FlashCrossAttention(nn.Module):
@@ -94,16 +106,25 @@ class FlashCrossAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.0)
     """
+
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         super().__init__()
-        assert flash_attn_varlen_kvpacked_func is not None, 'FlashAttention is not installed'
-        assert flash_attn_kvpacked_func is not None, 'FlashAttention is not installed'
+        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
+        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
         self.causal = causal
         self.softmax_scale = softmax_scale
         self.drop = nn.Dropout(attention_dropout)

-    def forward(self, q, kv, causal=None, cu_seqlens=None, max_seqlen=None,
-                cu_seqlens_k=None, max_seqlen_k=None):
+    def forward(
+        self,
+        q,
+        kv,
+        causal=None,
+        cu_seqlens=None,
+        max_seqlen=None,
+        cu_seqlens_k=None,
+        max_seqlen_k=None,
+    ):
         """Implements the multihead softmax attention.
         Arguments
         ---------
@@ -130,16 +151,27 @@ class FlashCrossAttention(nn.Module):
             assert max_seqlen_k is not None
             assert isinstance(max_seqlen, int)
             return flash_attn_varlen_kvpacked_func(
-                q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
+                q,
+                kv,
+                cu_seqlens,
+                cu_seqlens_k,
+                max_seqlen,
+                max_seqlen_k,
                 self.drop.p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=causal
+                softmax_scale=self.softmax_scale,
+                causal=causal,
             )
         else:
             batch_size, seqlen_q = q.shape[0], q.shape[1]
             seqlen_k = kv.shape[1]
             assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
-            return flash_attn_kvpacked_func(q, kv, self.drop.p if self.training else 0.0,
-                                            causal=causal, softmax_scale=self.softmax_scale)
+            return flash_attn_kvpacked_func(
+                q,
+                kv,
+                self.drop.p if self.training else 0.0,
+                causal=causal,
+                softmax_scale=self.softmax_scale,
+            )


 class SelfAttention(nn.Module):
class SelfAttention(nn.Module):
|
class SelfAttention(nn.Module):
|
||||||
@@ -152,6 +184,7 @@ class SelfAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.0)
     """
+
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         super().__init__()
         self.causal = causal
@@ -171,22 +204,25 @@ class SelfAttention(nn.Module):
         causal = self.causal if causal is None else causal
         q, k, v = qkv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
         if key_padding_mask is not None:
-            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
-                                      device=scores.device)
+            padding_mask = torch.full(
+                (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device
+            )
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
+            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
-            causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+            causal_mask = torch.triu(
+                torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1
+            )
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + causal_mask.to(dtype=scores.dtype)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         attention_drop = self.drop(attention)
-        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
         return output


@@ -200,6 +236,7 @@ class CrossAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.0)
     """
+
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         super().__init__()
         self.causal = causal
@@ -224,43 +261,48 @@ class CrossAttention(nn.Module):
             kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
         k, v = kv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
-        scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale)
         if key_padding_mask is not None:
-            padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
-                                      device=scores.device)
+            padding_mask = torch.full(
+                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
+            )
             padding_mask.masked_fill_(key_padding_mask, 0.0)
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
-            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
+            scores = scores + rearrange(padding_mask, "b s -> b 1 1 s")
         if causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
-            causal_mask = torch.triu(torch.full((seqlen_q, seqlen_k), -10000.0,
-                                                device=scores.device), 1)
+            causal_mask = torch.triu(
+                torch.full((seqlen_q, seqlen_k), -10000.0, device=scores.device), 1
+            )
             # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
             scores = scores + causal_mask.to(dtype=scores.dtype)
         attention = torch.softmax(scores, dim=-1, dtype=v.dtype)
         attention_drop = self.drop(attention)
-        output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
+        output = torch.einsum("bhts,bshd->bthd", attention_drop, v)
         return output


 class LinearResidual(nn.Linear):
-    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense.
-    """
+    """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense."""

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         return super().forward(input), input


 def _update_kv_cache(kv, inference_params, layer_idx):
-    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
-    """
+    """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
     # Pre-allocate memory for key-values for inference.
     num_heads, head_dim = kv.shape[-2:]
     if layer_idx not in inference_params.key_value_memory_dict:
         kv_cache = torch.empty(
-            inference_params.max_batch_size, inference_params.max_sequence_len, 2,
-            num_heads, head_dim, dtype=kv.dtype, device=kv.device
+            inference_params.max_batch_size,
+            inference_params.max_sequence_len,
+            2,
+            num_heads,
+            head_dim,
+            dtype=kv.dtype,
+            device=kv.device,
         )
         inference_params.key_value_memory_dict[layer_idx] = kv_cache
     else:
@@ -292,22 +334,30 @@ def _update_kv_cache(kv, inference_params, layer_idx):
         packsize = 4 if kv.dtype == torch.float32 else 8
         if kv_cache is not None:
             kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-            k_cache = rearrange(kv_cache[:, :, 0], 'b s h (d packsize) -> b h d s packsize',
-                                packsize=packsize).contiguous()
-            v_cache = rearrange(kv_cache[:, :, 1], 'b s h d -> b h s d').contiguous()
+            k_cache = rearrange(
+                kv_cache[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
+            ).contiguous()
+            v_cache = rearrange(kv_cache[:, :, 1], "b s h d -> b h s d").contiguous()
             inference_params.key_value_memory_dict[layer_idx] = (k_cache, v_cache)
         else:
             k_cache[batch_start:batch_end, :, :, :sequence_end, :] = rearrange(
-                kv[:, :, 0], 'b s h (d packsize) -> b h d s packsize', packsize=packsize
+                kv[:, :, 0], "b s h (d packsize) -> b h d s packsize", packsize=packsize
             )
             v_cache[batch_start:batch_end, :, :sequence_end, :] = rearrange(
-                kv[:, :, 1], 'b s h d -> b h s d'
+                kv[:, :, 1], "b s h d -> b h s d"
             )
     return kv


-def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotary_emb_dim,
-                                         rotary_emb_base, kv=None, rotary_emb_interleaved=False):
+def _apply_rotary_single_query_attention(
+    qkv,
+    inference_params,
+    layer_idx,
+    rotary_emb_dim,
+    rotary_emb_base,
+    kv=None,
+    rotary_emb_interleaved=False,
+):
     """
     qkv: (batch_size, 1, 3, nheads, head_dim) if kv is None else it's just
         q of shape (batch_size, 1, nheads, head_dim)
@@ -316,17 +366,22 @@ def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotar
     assert inference_params.fused_ft_kernel
     assert ft_attention is not None
     if kv is None:
-        q, k, v = rearrange(qkv, 'b 1 three h d -> b three h d').unbind(dim=1)
+        q, k, v = rearrange(qkv, "b 1 three h d -> b three h d").unbind(dim=1)
     else:
-        q = rearrange(qkv, 'b 1 h d -> b h d')
-        k, v = rearrange(kv, 'b 1 two h d -> b two h d').unbind(dim=1)
+        q = rearrange(qkv, "b 1 h d -> b h d")
+        k, v = rearrange(kv, "b 1 two h d -> b two h d").unbind(dim=1)
     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + q.shape[0]
     k_cache, v_cache = inference_params.key_value_memory_dict[layer_idx]
-    lengths_per_sample = (inference_params.lengths_per_sample[batch_start:batch_end]
-                          if inference_params.lengths_per_sample is not None else None)
+    lengths_per_sample = (
+        inference_params.lengths_per_sample[batch_start:batch_end]
+        if inference_params.lengths_per_sample is not None
+        else None
+    )
     context = ft_attention.single_query_attention(
-        q, k, v,
+        q,
+        k,
+        v,
         k_cache[batch_start:batch_end],
         v_cache[batch_start:batch_end],
         lengths_per_sample,
@@ -334,29 +389,47 @@ def _apply_rotary_single_query_attention(qkv, inference_params, layer_idx, rotar
         None,  # rotary_sin_
         None,  # nnz_head_idx
         inference_params.sequence_len_offset,
-        rotary_emb_dim, rotary_emb_base,
-        not rotary_emb_interleaved  # neox_rotary_style
+        rotary_emb_dim,
+        rotary_emb_base,
+        not rotary_emb_interleaved,  # neox_rotary_style
     )
-    return rearrange(context, 'b h d -> b 1 h d')
+    return rearrange(context, "b h d -> b 1 h d")


 class MHA(nn.Module):
-    """Multi-head self-attention and cross-attention
-    """
+    """Multi-head self-attention and cross-attention"""

-    def __init__(self, embed_dim, num_heads, num_heads_kv=None, cross_attn=False,
-                 qkv_proj_bias=True, out_proj_bias=True,
-                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None, dwconv=False,
-                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
-                 rotary_emb_interleaved=False, fused_bias_fc=False, use_flash_attn=False,
-                 return_residual=False, checkpointing=False, device=None, dtype=None) -> None:
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        num_heads_kv=None,
+        cross_attn=False,
+        qkv_proj_bias=True,
+        out_proj_bias=True,
+        dropout=0.0,
+        softmax_scale=None,
+        causal=False,
+        layer_idx=None,
+        dwconv=False,
+        rotary_emb_dim=0,
+        rotary_emb_base=10000.0,
+        rotary_emb_scale_base=None,
+        rotary_emb_interleaved=False,
+        fused_bias_fc=False,
+        use_flash_attn=False,
+        return_residual=False,
+        checkpointing=False,
+        device=None,
+        dtype=None,
+    ) -> None:
         """
         num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads.
         return_residual: whether to return the input x along with the output. This is for
             performance reason: for post-norm architecture, returning the input allows us
             to fuse the backward of nn.Linear with the residual connection.
         """
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.embed_dim = embed_dim
         self.cross_attn = cross_attn
@@ -370,24 +443,31 @@ class MHA(nn.Module):

         self.num_heads = num_heads
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
-        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
+        assert (
+            self.num_heads % self.num_heads_kv == 0
+        ), "num_heads must be divisible by num_heads_kv"
         assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
         kv_dim = 2 * self.head_dim * self.num_heads_kv

         if self.rotary_emb_dim > 0:
-            assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
-            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
-                                              scale_base=rotary_emb_scale_base,
-                                              interleaved=rotary_emb_interleaved, device=device)
+            assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet"
+            assert RotaryEmbedding is not None, "rotary_emb is not installed"
+            self.rotary_emb = RotaryEmbedding(
+                self.rotary_emb_dim,
+                base=rotary_emb_base,
+                scale_base=rotary_emb_scale_base,
+                interleaved=rotary_emb_interleaved,
+                device=device,
+            )

         if fused_bias_fc and FusedDense is None:
-            raise ImportError('fused_dense is not installed')
+            raise ImportError("fused_dense is not installed")
         linear_cls = nn.Linear if not fused_bias_fc else FusedDense
-        linear_resid_cls = (LinearResidual if not fused_bias_fc
-                            else partial(FusedDense, return_residual=True))
+        linear_resid_cls = (
+            LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True)
+        )
         wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
@@ -398,40 +478,57 @@ class MHA(nn.Module):
             self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs)
         if self.dwconv:
             if self.num_heads_kv == self.num_heads:
-                self.dwconv_qkv = nn.Conv1d(qkv_dim, qkv_dim, kernel_size=3, padding=2,
-                                            groups=qkv_dim)
+                self.dwconv_qkv = nn.Conv1d(
+                    qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim
+                )
             else:
-                self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
-                                          groups=embed_dim)
-                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2,
-                                           groups=kv_dim)
-        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
-                                         attention_dropout=dropout)
-        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
-                                                     attention_dropout=dropout)
+                self.dwconv_q = nn.Conv1d(
+                    embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim
+                )
+                self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim)
+        self.inner_attn = inner_attn_cls(
+            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
+        )
+        self.inner_cross_attn = inner_cross_attn_cls(
+            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
+        )
         self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs)

     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
         if not fused_ft_kernel:
-            return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv, self.head_dim,
-                               dtype=dtype, device=device)
+            return torch.empty(
+                batch_size,
+                max_seqlen,
+                2,
+                self.num_heads_kv,
+                self.head_dim,
+                dtype=dtype,
+                device=device,
+            )
         else:
             assert dtype in [torch.float16, torch.bfloat16, torch.float32]
             packsize = 4 if dtype == torch.float32 else 8
             assert self.head_dim % packsize == 0
-            k_cache = torch.empty(batch_size, self.num_heads_kv, self.head_dim // packsize,
-                                  max_seqlen, packsize, dtype=dtype, device=device)
-            v_cache = torch.empty(batch_size, self.num_heads_kv, max_seqlen, self.head_dim,
-                                  dtype=dtype, device=device)
+            k_cache = torch.empty(
+                batch_size,
+                self.num_heads_kv,
+                self.head_dim // packsize,
+                max_seqlen,
+                packsize,
+                dtype=dtype,
+                device=device,
+            )
+            v_cache = torch.empty(
+                batch_size, self.num_heads_kv, max_seqlen, self.head_dim, dtype=dtype, device=device
+            )
             return k_cache, v_cache

     def _update_kv_cache(self, kv, inference_params):
-        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
-        """
-        assert not self.dwconv, 'Generation does not support dwconv yet'
-        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
+        assert not self.dwconv, "Generation does not support dwconv yet"
+        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
         return _update_kv_cache(kv, inference_params, self.layer_idx)

     def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
@@ -442,12 +539,28 @@ class MHA(nn.Module):
         """
         rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
         return _apply_rotary_single_query_attention(
-            qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv,
-            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            qkv,
+            inference_params,
+            self.layer_idx,
+            self.rotary_emb_dim,
+            rotary_emb_base,
+            kv=kv,
+            rotary_emb_interleaved=self.rotary_emb.interleaved
+            if self.rotary_emb_dim > 0
+            else False,
         )

-    def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
-                mixer_subset=None, inference_params=None, **kwargs):
+    def forward(
+        self,
+        x,
+        x_kv=None,
+        key_padding_mask=None,
+        cu_seqlens=None,
+        max_seqlen=None,
+        mixer_subset=None,
+        inference_params=None,
+        **kwargs,
+    ):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
@@ -481,8 +594,11 @@ class MHA(nn.Module):
             assert cu_seqlens is None and max_seqlen is None
             assert not self.dwconv

-        kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
-                  if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
+        kwargs = (
+            {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs}
+            if self.use_flash_attn
+            else {"key_padding_mask": key_padding_mask, **kwargs}
+        )
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
         if not self.cross_attn and self.num_heads_kv == self.num_heads:
             assert x_kv is None and mixer_subset is None
@@ -491,19 +607,22 @@ class MHA(nn.Module):
             else:
                 qkv, x = self.Wqkv(x)
             if self.dwconv:
-                qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
-                                'b d s -> b s d').contiguous()
-            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
-            if (inference_params is None or inference_params.sequence_len_offset == 0
-                    or not inference_params.fused_ft_kernel):
+                qkv = rearrange(
+                    self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
+                ).contiguous()
+            qkv = rearrange(qkv, "... (three h d) -> ... three h d", three=3, d=self.head_dim)
+            if (
+                inference_params is None
+                or inference_params.sequence_len_offset == 0
+                or not inference_params.fused_ft_kernel
+            ):
                 if self.rotary_emb_dim > 0:
                     qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_attn(qkv, **kwargs)
                     else:
-                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv,
-                                                                    **kwargs)
+                        context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
                 else:
                     q = qkv[:, :, 0]
                     kv = self._update_kv_cache(qkv[:, :, 1:], inference_params)
@@ -530,25 +649,31 @@ class MHA(nn.Module):
                     qkv = self.Wqkv(x)
                 else:
                     qkv, x = self.Wqkv(x)
-                q = qkv[..., :self.num_heads * self.head_dim]
-                kv = qkv[..., self.num_heads * self.head_dim:]
-            q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
-            kv = rearrange(kv, '... (two hkv d) -> ... two hkv d', two=2, d=self.head_dim)
+                q = qkv[..., : self.num_heads * self.head_dim]
+                kv = qkv[..., self.num_heads * self.head_dim :]
+            q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim)
+            kv = rearrange(kv, "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
             if self.dwconv:
-                q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
-                              'b d s -> b s d').contiguous()
-                kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
-                               'b d s -> b s d').contiguous()
-            if (inference_params is None or inference_params.sequence_len_offset == 0
-                    or not inference_params.fused_ft_kernel):
+                q = rearrange(
+                    self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
+                ).contiguous()
+                kv = rearrange(
+                    self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d"
+                ).contiguous()
+            if (
+                inference_params is None
+                or inference_params.sequence_len_offset == 0
+                or not inference_params.fused_ft_kernel
+            ):
                 if self.rotary_emb_dim > 0:
                     q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
                     else:
-                        context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv,
-                                                                    **kwargs)
+                        context = torch.utils.checkpoint.checkpoint(
+                            self.inner_cross_attn, q, kv, **kwargs
+                        )
                 else:
                     kv = self._update_kv_cache(kv, inference_params)
                     # If we're processing the prompt, causal=None (use self.causal).
@@ -557,21 +682,36 @@ class MHA(nn.Module):
                     context = self.inner_cross_attn(q, kv, causal=causal)
             else:
                 context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
-        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
+        out = self.out_proj(rearrange(context, "... h d -> ... (h d)"))
         return out if not self.return_residual else (out, x)


 class ParallelMHA(nn.Module):
-    """Multi-head self-attention and cross-attention
-    """
+    """Multi-head self-attention and cross-attention"""

-    def __init__(self, embed_dim, num_heads, process_group, num_heads_kv=None,
-                 qkv_proj_bias=True, out_proj_bias=True,
-                 dropout=0.0, softmax_scale=None, causal=False, layer_idx=None,
-                 rotary_emb_dim=0, rotary_emb_base=10000.0, rotary_emb_scale_base=None,
-                 rotary_emb_interleaved=False, use_flash_attn=False, checkpointing=False,
-                 sequence_parallel=True, device=None, dtype=None) -> None:
-        factory_kwargs = {'device': device, 'dtype': dtype}
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        process_group,
+        num_heads_kv=None,
+        qkv_proj_bias=True,
+        out_proj_bias=True,
+        dropout=0.0,
+        softmax_scale=None,
+        causal=False,
+        layer_idx=None,
+        rotary_emb_dim=0,
+        rotary_emb_base=10000.0,
+        rotary_emb_scale_base=None,
+        rotary_emb_interleaved=False,
+        use_flash_attn=False,
+        checkpointing=False,
+        sequence_parallel=True,
+        device=None,
+        dtype=None,
+    ) -> None:
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.embed_dim = embed_dim
         self.causal = causal
@@ -586,55 +726,93 @@ class ParallelMHA(nn.Module):
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
         self.num_heads_per_rank = num_heads // self.world_size
         self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
-        assert self.num_heads % self.num_heads_kv == 0, "num_heads must be divisible by num_heads_kv"
+        assert (
+            self.num_heads % self.num_heads_kv == 0
+        ), "num_heads must be divisible by num_heads_kv"
         assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
-        assert self.num_heads_kv % self.world_size == 0, "num_heads_kv must be divisible by world_size"
+        assert (
+            self.num_heads_kv % self.world_size == 0
+        ), "num_heads_kv must be divisible by world_size"
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
         kv_dim = 2 * self.head_dim * self.num_heads_kv

         if self.rotary_emb_dim > 0:
-            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, base=rotary_emb_base,
-                                              scale_base=rotary_emb_scale_base,
-                                              interleaved=rotary_emb_interleaved, device=device)
+            assert RotaryEmbedding is not None, "rotary_emb is not installed"
+            self.rotary_emb = RotaryEmbedding(
+                self.rotary_emb_dim,
+                base=rotary_emb_base,
+                scale_base=rotary_emb_scale_base,
+                interleaved=rotary_emb_interleaved,
+                device=device,
+            )

         if ColumnParallelLinear is None or RowParallelLinear is None:
-            raise ImportError('fused_dense is not installed')
-        self.Wqkv = ColumnParallelLinear(embed_dim, qkv_dim, process_group,
-                                         bias=qkv_proj_bias,
-                                         sequence_parallel=sequence_parallel, **factory_kwargs)
+            raise ImportError("fused_dense is not installed")
+        self.Wqkv = ColumnParallelLinear(
+            embed_dim,
+            qkv_dim,
+            process_group,
+            bias=qkv_proj_bias,
+            sequence_parallel=sequence_parallel,
+            **factory_kwargs,
+        )
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         inner_cross_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
-        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
-                                         attention_dropout=dropout)
-        self.inner_cross_attn = inner_cross_attn_cls(causal=causal, softmax_scale=softmax_scale,
-                                                     attention_dropout=dropout)
-        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group,
-                                          bias=out_proj_bias,
-                                          sequence_parallel=sequence_parallel, **factory_kwargs)
+        self.inner_attn = inner_attn_cls(
+            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
+        )
+        self.inner_cross_attn = inner_cross_attn_cls(
+            causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout
+        )
+        self.out_proj = RowParallelLinear(
+            embed_dim,
+            embed_dim,
+            process_group,
+            bias=out_proj_bias,
+            sequence_parallel=sequence_parallel,
+            **factory_kwargs,
+        )

     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, fused_ft_kernel=True):
         dtype = self.out_proj.weight.dtype if dtype is None else dtype
         device = self.out_proj.weight.device
         if not fused_ft_kernel:
-            return torch.empty(batch_size, max_seqlen, 2, self.num_heads_kv_per_rank,
-                               self.head_dim, dtype=dtype, device=device)
+            return torch.empty(
+                batch_size,
+                max_seqlen,
+                2,
+                self.num_heads_kv_per_rank,
+                self.head_dim,
+                dtype=dtype,
+                device=device,
+            )
         else:
             assert dtype in [torch.float16, torch.bfloat16, torch.float32]
             packsize = 4 if dtype == torch.float32 else 8
             assert self.head_dim % packsize == 0
-            k_cache = torch.empty(batch_size, self.num_heads_kv_per_rank,
-                                  self.head_dim // packsize,
-                                  max_seqlen, packsize, dtype=dtype, device=device)
-            v_cache = torch.empty(batch_size, self.num_heads_kv_per_rank, max_seqlen,
-                                  self.head_dim, dtype=dtype, device=device)
+            k_cache = torch.empty(
+                batch_size,
+                self.num_heads_kv_per_rank,
+                self.head_dim // packsize,
+                max_seqlen,
+                packsize,
+                dtype=dtype,
+                device=device,
+            )
+            v_cache = torch.empty(
+                batch_size,
+                self.num_heads_kv_per_rank,
+                max_seqlen,
+                self.head_dim,
+                dtype=dtype,
+                device=device,
+            )
             return k_cache, v_cache

     def _update_kv_cache(self, kv, inference_params):
-        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)
-        """
-        assert self.layer_idx is not None, 'Generation requires layer_idx in the constructor'
+        """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)"""
+        assert self.layer_idx is not None, "Generation requires layer_idx in the constructor"
         return _update_kv_cache(kv, inference_params, self.layer_idx)

     def _apply_rotary_single_query_attention(self, qkv, inference_params, kv=None):
@@ -645,8 +823,15 @@ class ParallelMHA(nn.Module):
         """
         rotary_emb_base = self.rotary_emb.base if self.rotary_emb_dim > 0 else 0
         return _apply_rotary_single_query_attention(
-            qkv, inference_params, self.layer_idx, self.rotary_emb_dim, rotary_emb_base, kv=kv,
-            rotary_emb_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False,
+            qkv,
+            inference_params,
+            self.layer_idx,
+            self.rotary_emb_dim,
+            rotary_emb_base,
+            kv=kv,
+            rotary_emb_interleaved=self.rotary_emb.interleaved
+            if self.rotary_emb_dim > 0
+            else False,
         )

     def forward(self, x, seqlen=None, inference_params=None, **kwargs):
@@ -662,9 +847,12 @@ class ParallelMHA(nn.Module):
             qkv = rearrange(qkv, "(b s) ... -> b s ...", s=seqlen)
         seqlen_offset = 0 if inference_params is None else inference_params.sequence_len_offset
         if self.num_heads_kv == self.num_heads:
-            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
-            if (inference_params is None or inference_params.sequence_len_offset == 0
-                    or not inference_params.fused_ft_kernel):
+            qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim)
+            if (
+                inference_params is None
+                or inference_params.sequence_len_offset == 0
+                or not inference_params.fused_ft_kernel
+            ):
                 if self.rotary_emb_dim > 0:
                     qkv = self.rotary_emb(qkv, seqlen_offset=seqlen_offset)
                 if inference_params is None:
@@ -682,20 +870,31 @@ class ParallelMHA(nn.Module):
            else:
                context = self._apply_rotary_single_query_attention(qkv, inference_params)
         else:
-            q = rearrange(qkv[..., :self.num_heads_per_rank * self.head_dim],
-                          "... (h d) -> ... h d", d=self.head_dim)
-            kv = rearrange(qkv[..., self.num_heads_per_rank * self.head_dim:],
-                           "... (two hkv d) -> ... two hkv d", two=2, d=self.head_dim)
-            if (inference_params is None or inference_params.sequence_len_offset == 0
-                    or not inference_params.fused_ft_kernel):
+            q = rearrange(
+                qkv[..., : self.num_heads_per_rank * self.head_dim],
+                "... (h d) -> ... h d",
+                d=self.head_dim,
+            )
+            kv = rearrange(
+                qkv[..., self.num_heads_per_rank * self.head_dim :],
+                "... (two hkv d) -> ... two hkv d",
+                two=2,
+                d=self.head_dim,
+            )
+            if (
+                inference_params is None
+                or inference_params.sequence_len_offset == 0
+                or not inference_params.fused_ft_kernel
+            ):
                 if self.rotary_emb_dim > 0:
                     q, kv = self.rotary_emb(q, kv, seqlen_offset=seqlen_offset)
                 if inference_params is None:
                     if not self.checkpointing:
                         context = self.inner_cross_attn(q, kv, **kwargs)
                     else:
-                        context = torch.utils.checkpoint.checkpoint(self.inner_cross_attn, q, kv,
-                                                                    **kwargs)
+                        context = torch.utils.checkpoint.checkpoint(
+                            self.inner_cross_attn, q, kv, **kwargs
+                        )
                 else:
                     kv = self._update_kv_cache(kv, inference_params)
                     # If we're processing the prompt, causal=None (use self.causal).
@@ -704,8 +903,8 @@ class ParallelMHA(nn.Module):
                    context = self.inner_cross_attn(q, kv, causal=causal)
            else:
                context = self._apply_rotary_single_query_attention(q, inference_params, kv=kv)
-        context = rearrange(context, 'b s h d -> b s (h d)')
+        context = rearrange(context, "b s h d -> b s (h d)")
         if seqlen is not None:
-            context = rearrange(context, 'b s d -> (b s) d')
+            context = rearrange(context, "b s d -> (b s) d")
         out = self.out_proj(context)
         return out
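The CrossAttention path in this file repeats key/value heads when num_heads_kv < num_heads (MQA/GQA). A standalone sketch of that einops pattern, with made-up shapes (not taken from the repo's tests):

```python
import torch
from einops import repeat

# GQA/MQA head expansion as in CrossAttention.forward: kv carries fewer heads than q,
# so each kv head is repeated g times to match. Shapes below are illustrative only.
batch, seqlen, nheads, nheads_kv, head_dim = 2, 16, 8, 2, 64
q = torch.randn(batch, seqlen, nheads, head_dim)
kv = torch.randn(batch, seqlen, 2, nheads_kv, head_dim)  # packed (k, v) pairs
kv = repeat(kv, "... hkv d -> ... (hkv g) d", g=q.shape[2] // kv.shape[3])
print(kv.shape)  # torch.Size([2, 16, 2, 8, 64])
```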
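For context, a minimal usage sketch of the MHA module this file defines, using only arguments visible in the diff. With use_flash_attn=False it takes the pure-PyTorch SelfAttention path, so it runs on CPU; the import path flash_attn.modules.mha is an assumption, since the commit only names mha.py:

```python
import torch
from flash_attn.modules.mha import MHA  # assumed module path; the commit only names mha.py

# Causal self-attention over (batch, seqlen, embed_dim), pure-PyTorch path (no flash kernels).
mha = MHA(embed_dim=256, num_heads=8, causal=True, use_flash_attn=False)
x = torch.randn(2, 128, 256)
out = mha(x)
print(out.shape)  # torch.Size([2, 128, 256])
```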