[FusedDense] Kick off input all_gather before weight dtype conversion
parent 71befc19e1
commit 65b4064b2a
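The change, sketched below with plain torch.distributed calls rather than the repo's all_gather_raw helper: convert the input dtype, launch its all_gather asynchronously, do the weight/bias dtype conversion and contiguity fixes while the communication is in flight, and only wait on the handle right before the matmul. This is a minimal sketch, assuming an initialized process group and an input gathered along its first (sequence) dimension; the name overlapped_linear and the explicit output buffer are illustrative, not the library's code.

import torch
import torch.distributed as dist
import torch.nn.functional as F


def overlapped_linear(x, weight, bias, process_group):
    # Hypothetical sketch of the ordering introduced by this commit.
    if torch.is_autocast_enabled():
        x = x.to(dtype=torch.get_autocast_gpu_dtype())
    x = x.contiguous()
    if process_group is not None:
        world_size = dist.get_world_size(process_group)
        total_x = torch.empty(world_size * x.shape[0], *x.shape[1:],
                              dtype=x.dtype, device=x.device)
        # Kick off the all_gather early, before the weight dtype conversion below.
        handle = dist.all_gather_into_tensor(total_x, x, group=process_group, async_op=True)
    else:
        total_x = x
    if torch.is_autocast_enabled():
        weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
        bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
    weight = weight.contiguous()
    if process_group is not None:
        handle.wait()  # the gather overlapped with the conversions above
    return F.linear(total_x, weight, bias)

Compared to the previous ordering (convert x, weight and bias together, then do a blocking gather just before F.linear), only the gather launch moves; the matmul itself is unchanged.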
@@ -472,13 +472,15 @@ class ParallelMHA(nn.Module):
     """

     def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, rotary_emb_dim=0, rotary_emb_scale_base=0,
+                 softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
+                 rotary_emb_scale_base=0,
                  use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
         self.embed_dim = embed_dim
         self.causal = causal
+        self.layer_idx = layer_idx
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
         self.checkpointing = checkpointing
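The only behavioural change in this hunk is the new layer_idx keyword, which is stored as self.layer_idx. Below is a hypothetical construction call against the signature shown above; the import path, argument values and process-group choice are assumptions, and a default process group is expected to be initialized (e.g. via torchrun).

import torch
import torch.distributed as dist
from flash_attn.modules.mha import ParallelMHA  # assumed import path

# Assumes a default process group has already been initialized.
tp_group = dist.group.WORLD
mha = ParallelMHA(embed_dim=1024, num_heads=16, process_group=tp_group,
                  causal=True, layer_idx=3,  # new keyword, stored as self.layer_idx
                  rotary_emb_dim=64, use_flash_attn=True,
                  device='cuda', dtype=torch.float16)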
@@ -32,24 +32,28 @@ class FusedDenseFunc(torch.autograd.Function):
         ctx.process_group = process_group

         if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
-            bias = bias.to(dtype=dtype) if bias is not None else None
-
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
+        if process_group is not None:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+
+        if torch.is_autocast_enabled():
+            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
+            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
         weight = weight.contiguous()
+        if process_group is not None:
+            handle_x.wait()
+        batch_shape = total_x.shape[:-1]
+        batch_dim = batch_shape.numel()
+        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
         else:
             ctx.save_for_backward(weight)
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        if process_group is not None:
-            total_x, _ = all_gather_raw(x, process_group)
-        else:
-            total_x = x
-        output = F.linear(total_x, weight, bias)
         return output if not return_residual else (output, x)

     @staticmethod
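The rewritten forward relies on all_gather_raw returning both the gathered tensor and the communication handle when async_op=True, so that handle_x.wait() can be deferred until just before F.linear. The helper itself is not part of this diff; below is a minimal sketch of what such a helper could look like, assuming a gather along the first dimension via torch.distributed.all_gather_into_tensor (the _sketch suffix marks it as hypothetical).

import torch
import torch.distributed as dist


def all_gather_raw_sketch(input_, process_group, async_op=False):
    # Hypothetical stand-in for all_gather_raw: gather along dim 0, hand back the work handle.
    world_size = dist.get_world_size(process_group)
    output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
                         dtype=input_.dtype, device=input_.device)
    handle = dist.all_gather_into_tensor(output, input_.contiguous(),
                                         group=process_group, async_op=async_op)
    # With async_op=True the caller overlaps other work and calls handle.wait() later;
    # with async_op=False the call blocks and handle is None.
    return output, handle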
@@ -188,32 +192,42 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         2: recompute gelu_in and gelu_out in the bwd
         """
         assert -1 <= heuristic <= 4
-        if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
-            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
-            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.checkpoint_lvl = checkpoint_lvl
+        ctx.heuristic = heuristic
+
+        if torch.is_autocast_enabled():
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
+        if process_group is not None:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_gpu_dtype()
+            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
+            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
+            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
         weight1 = weight1.contiguous()
         bias1 = bias1.contiguous() if bias1 is not None else None
         weight2 = weight2.contiguous()
         bias2 = bias2.contiguous() if bias2 is not None else None
         if process_group is not None:
-            total_x, _ = all_gather_raw(x, process_group)
-        else:
-            total_x = x
+            handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
         assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
         if heuristic == -1:
             gelu_in = F.linear(total_x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
-            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1) # This is before adding bias1
+            # This is before adding bias1
+            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)
             # with torch.jit.fuser('fuser2'):
             #     output1 = bias_gelu(gelu_in, bias1)
         else:
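The commented-out lines above refer to a bias_gelu fused under torch.jit.fuser('fuser2'), with gelu_in computed without bias1 so the bias-add can be fused into the GELU. That function is not shown in this diff; below is a minimal sketch of a tanh-approximation bias+GELU in that spirit, where the name bias_gelu_sketch and the scripting choice are assumptions.

import torch


@torch.jit.script
def bias_gelu_sketch(y: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Fused bias-add + tanh-approximate GELU, written so the JIT fuser can emit one kernel.
    x = y + bias
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

The live path instead folds bias1 into F.linear and applies F.gelu(gelu_in, approximate='tanh').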
@@ -223,8 +237,6 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             if save_pre_act:
                 gelu_in = rest[0]
         output2 = F.linear(output1, weight2, bias2)
-        ctx.checkpoint_lvl = checkpoint_lvl
-        ctx.heuristic = heuristic
         if checkpoint_lvl == 0:
             ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
         elif checkpoint_lvl == 1:
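For context on the ctx.save_for_backward calls being reshuffled here: in a torch.autograd.Function, tensors stashed with save_for_backward during forward come back as ctx.saved_tensors in backward, which is how checkpoint_lvl controls what is stored versus recomputed. A generic sketch of that standard API, unrelated to the fused kernels themselves:

import torch


class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)      # stash inputs needed by backward
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors      # retrieved here instead of recomputed
        return 2 * x * grad_out


x = torch.randn(4, requires_grad=True)
Square.apply(x).sum().backward()      # x.grad == 2 * x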