# Inspired by https://github.com/zhuzilin/ring-flash-attention
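# Context parallelism via ring attention: each rank keeps its local shard of
# the sequence and passes K/V blocks around the process ring, so every query
# block eventually attends to the full sequence. Partial results are merged
# using their log-sum-exp statistics (see update_out_and_lse).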

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import distributed as dist
from typing import Any, Optional, Tuple

from distributed.distributed_primtives import ContextComms
from model import Attention
import distributed.process_group_manager as pgm

from parallel.base_parallel import BaseParallel


class ContextParallel(BaseParallel):
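    """Context-parallel wrapper: replaces every Attention module in the model
    with a RingAttention module initialized from its weights."""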

    def __init__(self, model, config):
        super().__init__(model, config)

        for name, module in model.named_modules():
            if isinstance(module, Attention) and not isinstance(module, RingAttention):
                parent_name, child_name = name.rsplit('.', 1)
                parent_module = model.get_submodule(parent_name)
                setattr(parent_module, child_name, RingAttention(module))
                del module


class RingAttention(nn.Module):
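    """Drop-in replacement for Attention that runs the score computation as
    ring attention (RingAttentionFunc) over the context-parallel group.

    Projection weights are copied from the original module and its rotary
    embedding is reused.
    """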

    def __init__(self, original_mha):
        super().__init__()

        self.hidden_size = original_mha.hidden_size
        self.num_heads = original_mha.num_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = original_mha.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = original_mha.is_causal

        # Copy the weights from the original Attention
        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.q_proj.weight.data.copy_(original_mha.q_proj.weight.data)
        self.k_proj.weight.data.copy_(original_mha.k_proj.weight.data)
        self.v_proj.weight.data.copy_(original_mha.v_proj.weight.data)
        self.o_proj.weight.data.copy_(original_mha.o_proj.weight.data)

        self.rotary = original_mha.rotary

    def forward(self, input_ids, position_ids):
        batch_size, seq_len, _ = input_ids.shape

        q = self.q_proj(input_ids)
        k = self.k_proj(input_ids)
        v = self.v_proj(input_ids)

        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if self.rotary is not None:
            cos, sin = self.rotary(v, position_ids)
            q, k = self.rotary.apply_rotary_pos_emb(q, k, cos, sin)

        k = self._repeat_kv(k, self.num_key_value_groups)
        v = self._repeat_kv(v, self.num_key_value_groups)

        sm_scale = 1.0 / (q.size(-1) ** 0.5)
        output = RingAttentionFunc.apply(q, k, v, sm_scale, self.is_causal)

        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        output = self.o_proj(output)
        return output

    def _repeat_kv(self, x, n_rep):
        batch, num_key_value_heads, seq_len, head_dim = x.shape
        if n_rep == 1:
            return x
        x = x[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seq_len, head_dim)
        return x.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)


class RingAttentionFunc(torch.autograd.Function):
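    """autograd.Function implementing ring attention.

    Forward: K/V blocks rotate around the ring while each rank attends with
    its local Q block, folding partial results into (out, lse).
    Backward: K/V blocks rotate again and the partial dK/dV gradients are
    passed around on a second communicator.
    """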

    @staticmethod
    def forward(ctx, q, k, v, sm_scale, is_causal):
        comm = ContextComms("comm")
        # NOTE: Find a better way to save these tensors without cloning
        k_og = k.clone()
        v_og = v.clone()
        out, lse = None, None
        next_k, next_v = None, None
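
        # Ring loop: kick off the exchange of the next K/V block, attend to
        # the current block (skipping blocks that are fully masked under
        # causal attention), then fold the partial result into (out, lse).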
        for step in range(comm.world_size):
            if step + 1 != comm.world_size:
                next_k = comm.send_recv(k)
                next_v = comm.send_recv(v)
                comm.commit()

            if not is_causal or step <= comm.rank:
                block_out, block_lse = ring_attention_forward(
                    q, k, v, sm_scale, is_causal and step == 0
                )
                out, lse = update_out_and_lse(out, lse, block_out, block_lse)

            if step + 1 != comm.world_size:
                comm.wait()
                k = next_k
                v = next_v

        out = out.to(q.dtype)
        ctx.save_for_backward(q, k_og, v_og, out, lse.squeeze(-1))
        ctx.sm_scale = sm_scale
        ctx.is_causal = is_causal
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        q, k, v, out, softmax_lse = ctx.saved_tensors
        sm_scale = ctx.sm_scale
        is_causal = ctx.is_causal

        kv_comm = ContextComms("kv_comm")
        d_kv_comm = ContextComms("d_kv_comm")
        dq, dk, dv = None, None, None
        next_dk, next_dv = None, None

        block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
        block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
        block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

        next_dk, next_dv = None, None
        next_k, next_v = None, None
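
        # Two rings run side by side: kv_comm rotates the K/V blocks exactly
        # as in the forward pass, while d_kv_comm passes the accumulated
        # dK/dV along with them, so after the final wait each rank is left
        # holding the gradient of its own K/V shard (next_dk / next_dv).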
        for step in range(kv_comm.world_size):
            if step + 1 != kv_comm.world_size:
                next_k = kv_comm.send_recv(k)
                next_v = kv_comm.send_recv(v)
                kv_comm.commit()

            if step <= kv_comm.rank or not is_causal:
                bwd_causal = is_causal and step == 0

                block_dq_buffer, block_dk_buffer, block_dv_buffer = ring_attention_backward(
                    dout, q, k, v, out, softmax_lse, sm_scale, bwd_causal
                )

                if dq is None:
                    dq = block_dq_buffer.to(torch.float32)
                    dk = block_dk_buffer.to(torch.float32)
                    dv = block_dv_buffer.to(torch.float32)
                else:
                    dq += block_dq_buffer
                    d_kv_comm.wait()
                    dk = block_dk_buffer + next_dk
                    dv = block_dv_buffer + next_dv
            elif step != 0:
                d_kv_comm.wait()
                dk = next_dk
                dv = next_dv

            if step + 1 != kv_comm.world_size:
                kv_comm.wait()
                k = next_k
                v = next_v

            next_dk = d_kv_comm.send_recv(dk)
            next_dv = d_kv_comm.send_recv(dv)
            d_kv_comm.commit()

        d_kv_comm.wait()

        return dq, next_dk, next_dv, None, None


def ring_attention_forward(q, k, v, sm_scale, is_causal):
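    """Attention of the local Q block against a single K/V block.

    Returns the block output and the row-wise log-sum-exp of the scores,
    which update_out_and_lse needs to merge this block into the running
    result.
    """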
    batch_size, nheads, seqlen, d = q.shape
    S = torch.matmul(q, k.transpose(-2, -1)) * sm_scale

    if is_causal:
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, device=q.device, dtype=torch.bool), diagonal=1)
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(1).expand(batch_size, nheads, seqlen, seqlen)
        S.masked_fill_(causal_mask, float('-inf'))

    # Numerically stable softmax; keep the log-sum-exp so partial outputs
    # can be combined across ring steps
    S_max = torch.max(S, dim=-1, keepdim=True)[0]
    exp_S = torch.exp(S - S_max)
    exp_sum = torch.sum(exp_S, dim=-1, keepdim=True)
    log_sum_exp = torch.log(exp_sum) + S_max
    P = exp_S / exp_sum
    O = torch.matmul(P, v)
    return O, log_sum_exp.squeeze(-1)


def ring_attention_backward(dO, Q, K, V, O, softmax_lse, sm_scale, is_causal):
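    """Gradients of one attention block.

    The probabilities P are recomputed from the saved log-sum-exp rather
    than stored, then dQ, dK and dV follow the standard attention backward
    formulas.
    """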
    batch_size, nheads, seqlen, d = Q.shape

    # Recreate S and P from softmax_lse
    S = torch.matmul(Q, K.transpose(-2, -1)) * sm_scale
    if is_causal:
        causal_mask = torch.triu(torch.ones(seqlen, seqlen, device=Q.device, dtype=torch.bool), diagonal=1)
        S = S.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), float('-inf'))

    P = torch.exp(S - softmax_lse.unsqueeze(-1))

    # Step 1: Compute dV
    dV = torch.matmul(P.transpose(-2, -1), dO)
    # Step 2: Compute dP
    dP = torch.matmul(dO, V.transpose(-2, -1))
    # Step 3: Compute D
    D = torch.sum(dO * O, dim=-1, keepdim=True)
    # Step 4: Compute dS
    dS = P * (dP - D)
    # Apply the causal mask to dS if is_causal is True
    if is_causal:
        dS = dS.masked_fill(causal_mask.unsqueeze(0).unsqueeze(1), 0)
    # Step 5: Compute dQ
    dQ = torch.matmul(dS, K) * sm_scale
    # Step 6: Compute dK
    dK = torch.matmul(dS.transpose(-2, -1), Q) * sm_scale
    return dQ, dK, dV


def update_out_and_lse(
    out: Optional[torch.Tensor],
    lse: Optional[torch.Tensor],
    block_out: torch.Tensor,
    block_lse: torch.Tensor,
    slice_: Optional[Any] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
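    """Fold a new block's (block_out, block_lse) into the running (out, lse).

    The rescaling uses both log-sum-exp values so that, after all ring steps,
    the accumulated output equals a softmax over the full sequence.
    """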

    def _update(current_out, current_lse):
        # new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
        # torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out
        # For additional context and discussion, please refer to:
        # https://github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795
        current_out = current_out - F.sigmoid(block_lse - current_lse) * (current_out - block_out)
        current_lse = current_lse - F.logsigmoid(current_lse - block_lse)
        return current_out, current_lse

    block_out = block_out.to(torch.float32)
    block_lse = block_lse.unsqueeze(dim=-1)

    if out is None:
        if slice_ is not None:
            raise RuntimeError("first update_out_and_lse should not pass slice_ args")
        return block_out, block_lse

    if slice_ is not None:
        out[slice_], lse[slice_] = _update(out[slice_], lse[slice_])
    else:
        out, lse = _update(out, lse)

    return out, lse