[FusedDense] Kick off input all_gather before weight dtype conversion

Tri Dao 2022-12-31 22:47:34 -08:00
parent 71befc19e1
commit 65b4064b2a
2 changed files with 38 additions and 24 deletions
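The change is the same in both call sites below: previously the input, weight, and bias were all cast to the autocast dtype and only then was the input all_gather issued (synchronously), so the communication sat behind the conversions. Now only x is cast first, the all_gather is launched with async_op=True, the weight/bias dtype conversions run while the gather is in flight, and handle_x.wait() is called right before the matmul. A minimal sketch of that pattern, assuming torch.distributed.all_gather_into_tensor (PyTorch >= 1.13); the function and variable names here are placeholders, not the repo's own helpers:

    # Illustrative sketch: launch the all_gather asynchronously, overlap it with
    # the weight/bias dtype conversion, and only wait right before the matmul.
    import torch
    import torch.distributed as dist
    import torch.nn.functional as F

    def overlapped_linear(x, weight, bias, process_group):
        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        world_size = dist.get_world_size(process_group)
        total_x = torch.empty(world_size * x.shape[0], *x.shape[1:],
                              dtype=x.dtype, device=x.device)
        # 1) Kick off the all_gather first so the communication starts immediately.
        handle = dist.all_gather_into_tensor(total_x, x, group=process_group, async_op=True)
        # 2) Do work that does not depend on total_x while the gather is in flight.
        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        # 3) Block on the gather only when its result is actually needed.
        handle.wait()
        return F.linear(total_x, weight, bias)

The win is simply that the dtype conversions (and the .contiguous() call on the weight) no longer serialize behind the collective; nothing about the math changes.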

View File

@@ -472,13 +472,15 @@ class ParallelMHA(nn.Module):
     """
     def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, rotary_emb_dim=0, rotary_emb_scale_base=0,
+                 softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
+                 rotary_emb_scale_base=0,
                  use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
         self.embed_dim = embed_dim
         self.causal = causal
+        self.layer_idx = layer_idx
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
         self.checkpointing = checkpointing
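The only functional change in this file is the new layer_idx argument, which is simply stored on the module; the diff does not show how it is consumed. A hypothetical construction call, just to show where the new argument sits (assuming ParallelMHA is importable from flash_attn.modules.mha as in the repo; the distributed setup and the values below are placeholders, so this is not a runnable script on its own):

    # Hypothetical usage of the updated signature; process_group is assumed to be
    # an already-created torch.distributed group.
    import torch
    from flash_attn.modules.mha import ParallelMHA

    mha = ParallelMHA(
        embed_dim=1024, num_heads=16, process_group=process_group,
        causal=True, layer_idx=0,            # layer_idx is the argument added by this commit
        rotary_emb_dim=64, use_flash_attn=True,
        device='cuda', dtype=torch.bfloat16,
    )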

View File

@@ -32,24 +32,28 @@ class FusedDenseFunc(torch.autograd.Function):
         ctx.process_group = process_group
         if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
-            bias = bias.to(dtype=dtype) if bias is not None else None
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
+        if process_group is not None:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+        if torch.is_autocast_enabled():
+            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
+            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
         weight = weight.contiguous()
+        if process_group is not None:
+            handle_x.wait()
+        batch_shape = total_x.shape[:-1]
+        batch_dim = batch_shape.numel()
+        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
         else:
             ctx.save_for_backward(weight)
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        if process_group is not None:
-            total_x, _ = all_gather_raw(x, process_group)
-        else:
-            total_x = x
-        output = F.linear(total_x, weight, bias)
         return output if not return_residual else (output, x)

     @staticmethod
@@ -188,32 +192,42 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         2: recompute gelu_in and gelu_out in the bwd
         """
         assert -1 <= heuristic <= 4
-        if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
-            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
-            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.checkpoint_lvl = checkpoint_lvl
+        ctx.heuristic = heuristic
+        if torch.is_autocast_enabled():
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
+        if process_group is not None:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_gpu_dtype()
+            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
+            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
+            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
         weight1 = weight1.contiguous()
         bias1 = bias1.contiguous() if bias1 is not None else None
         weight2 = weight2.contiguous()
         bias2 = bias2.contiguous() if bias2 is not None else None
         if process_group is not None:
-            total_x, _ = all_gather_raw(x, process_group)
-            else:
-            total_x = x
+            handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
         assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
         if heuristic == -1:
             gelu_in = F.linear(total_x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
-            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1) # This is before adding bias1
+            # This is before adding bias1
+            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)
             # with torch.jit.fuser('fuser2'):
             #     output1 = bias_gelu(gelu_in, bias1)
         else:
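In the heuristic == -1 path above, the activation is F.gelu(gelu_in, approximate='tanh'), i.e. the same tanh approximation the commented-out bias_gelu fusion would compute. A quick self-contained check of what approximate='tanh' means, purely illustrative and unrelated to the diff itself:

    # Sanity check: F.gelu(x, approximate='tanh') matches the standard
    # tanh-approximate GELU formula 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))).
    import math
    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8, dtype=torch.float32)
    manual = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))
    assert torch.allclose(F.gelu(x, approximate='tanh'), manual, atol=1e-6)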
@@ -223,8 +237,6 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             if save_pre_act:
                 gelu_in = rest[0]
         output2 = F.linear(output1, weight2, bias2)
-        ctx.checkpoint_lvl = checkpoint_lvl
-        ctx.heuristic = heuristic
         if checkpoint_lvl == 0:
             ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
         elif checkpoint_lvl == 1:
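The checkpoint_lvl bookkeeping that moved earlier in this commit trades memory for recomputation: level 0 saves both gelu_in and output1 for the backward, while level 2 (per the docstring fragment above) saves less and recomputes in the bwd. A toy autograd.Function sketch of that save-versus-recompute idea, written from scratch for illustration and not taken from the library:

    # Toy illustration of the save-vs-recompute trade-off behind checkpoint_lvl;
    # this is NOT FusedDenseGeluDenseFunc, just the same idea in miniature.
    # Assumes x is 2D (batch, in_features) and weight is (out_features, in_features).
    import torch
    import torch.nn.functional as F

    class ToyLinearGelu(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, weight, save_pre_act=True):
            pre_act = F.linear(x, weight)                 # analogous to gelu_in
            ctx.save_pre_act = save_pre_act
            if save_pre_act:
                ctx.save_for_backward(x, weight, pre_act)  # more memory, no recompute
            else:
                ctx.save_for_backward(x, weight)           # less memory, recompute in bwd
            return F.gelu(pre_act, approximate='tanh')

        @staticmethod
        def backward(ctx, grad_out):
            if ctx.save_pre_act:
                x, weight, pre_act = ctx.saved_tensors
            else:
                x, weight = ctx.saved_tensors
                pre_act = F.linear(x, weight)              # recomputed in the bwd
            with torch.enable_grad():
                p = pre_act.detach().requires_grad_()
                grad_pre_act, = torch.autograd.grad(F.gelu(p, approximate='tanh'), p, grad_out)
            grad_x = grad_pre_act @ weight
            grad_weight = grad_pre_act.t() @ x
            return grad_x, grad_weight, None

Calling ToyLinearGelu.apply(x, weight, False) produces the same gradients as save_pre_act=True, just at the cost of one extra F.linear in the backward pass.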