[FusedDense] Kick off input all_gather before weight dtype conversion
commit 65b4064b2a
parent 71befc19e1
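The change itself: when sequence parallelism is enabled, the input all_gather is now launched asynchronously before the weight and bias are converted to the autocast dtype, so the communication overlaps with that conversion instead of running after it. A minimal sketch of the overlap pattern, assuming an already-initialized torch.distributed process group and a PyTorch version that provides all_gather_into_tensor; the function name gathered_linear and the shapes are illustrative, not part of this commit:

    import torch
    import torch.nn.functional as F
    import torch.distributed as dist

    def gathered_linear(x, weight, bias, process_group):
        # Convert only the (small) per-rank input first, so the gather runs in the final dtype.
        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        # Kick off the all_gather immediately (non-blocking).
        world_size = dist.get_world_size(process_group)
        total_x = torch.empty(world_size * x.shape[0], *x.shape[1:],
                              dtype=x.dtype, device=x.device)
        handle = dist.all_gather_into_tensor(total_x, x, group=process_group, async_op=True)
        # Overlap: the weight/bias dtype conversion runs while the gather is in flight.
        if torch.is_autocast_enabled():
            dtype = torch.get_autocast_gpu_dtype()
            weight = weight.to(dtype=dtype)
            bias = bias.to(dtype=dtype) if bias is not None else None
        handle.wait()  # block only when the gathered input is actually consumed
        return F.linear(total_x, weight, bias)

The key point is that handle.wait() is deferred until the gathered tensor is actually needed by the matmul.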
@@ -472,13 +472,15 @@ class ParallelMHA(nn.Module):
     """

     def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
-                 softmax_scale=None, causal=False, rotary_emb_dim=0, rotary_emb_scale_base=0,
+                 softmax_scale=None, causal=False, layer_idx=None, rotary_emb_dim=0,
+                 rotary_emb_scale_base=0,
                  use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         self.process_group = process_group
         self.embed_dim = embed_dim
         self.causal = causal
+        self.layer_idx = layer_idx
         self.rotary_emb_dim = rotary_emb_dim
         self.use_flash_attn = use_flash_attn
         self.checkpointing = checkpointing
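The ParallelMHA hunk above only threads a new layer_idx argument through the constructor. A hedged usage sketch, assuming the module lives at flash_attn.modules.mha and that torch.distributed has already been initialized (neither is shown in this diff):

    # Assumed import path; the diff itself does not show it.
    from flash_attn.modules.mha import ParallelMHA
    import torch.distributed as dist

    process_group = dist.group.WORLD  # assumes dist.init_process_group(...) was already called
    mha = ParallelMHA(
        embed_dim=1024,
        num_heads=16,
        process_group=process_group,
        layer_idx=3,          # new keyword added in this commit
        rotary_emb_dim=64,
        use_flash_attn=True,
    )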
@@ -32,24 +32,28 @@ class FusedDenseFunc(torch.autograd.Function):
         ctx.process_group = process_group

         if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight = [a.to(dtype=dtype) for a in [x, weight]]
-            bias = bias.to(dtype=dtype) if bias is not None else None
-
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
+        if process_group is not None:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+
+        if torch.is_autocast_enabled():
+            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
+            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
         weight = weight.contiguous()
+        if process_group is not None:
+            handle_x.wait()
+        batch_shape = total_x.shape[:-1]
+        batch_dim = batch_shape.numel()
+        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
+        output = F.linear(total_x, weight, bias)
         if ctx.compute_weight_gradient:
             ctx.save_for_backward(x, weight)
         else:
             ctx.save_for_backward(weight)
-        batch_shape, n = x.shape[:-1], x.shape[-1]
-        batch_dim = batch_shape.numel()
-        assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
-        if process_group is not None:
-            total_x, _ = all_gather_raw(x, process_group)
-        else:
-            total_x = x
-        output = F.linear(total_x, weight, bias)
         return output if not return_residual else (output, x)

     @staticmethod
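The forward above depends on all_gather_raw returning both the gathered tensor and a communication handle when async_op=True. That helper is not part of this diff; a rough sketch of what such a helper could look like, built on standard torch.distributed primitives rather than the repository's actual implementation:

    import torch
    import torch.distributed as dist

    def all_gather_raw(input_, process_group, async_op=False):
        # Gather along the first dimension; returns the output buffer and,
        # when async_op=True, a handle that must be waited on before use
        # (the handle is None when async_op=False).
        world_size = dist.get_world_size(process_group)
        output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
                             dtype=input_.dtype, device=input_.device)
        handle = dist.all_gather_into_tensor(output, input_.contiguous(),
                                             group=process_group, async_op=async_op)
        return output, handle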
@@ -188,32 +192,42 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
         2: recompute gelu_in and gelu_out in the bwd
         """
         assert -1 <= heuristic <= 4
-        if torch.is_autocast_enabled():
-            dtype = torch.get_autocast_gpu_dtype()
-            x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
-            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
-            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
         if not save_pre_act:
             checkpoint_lvl = 2
         assert checkpoint_lvl in [0, 1, 2]
         ctx.return_residual = return_residual
         ctx.process_group = process_group
+        ctx.checkpoint_lvl = checkpoint_lvl
+        ctx.heuristic = heuristic
+
+        if torch.is_autocast_enabled():
+            x = x.to(dtype=torch.get_autocast_gpu_dtype())
         x = x.contiguous()
+        if process_group is not None:
+            # We want to kick off the all_gather early, before weight dtype conversion
+            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
+        else:
+            total_x = x
+
+        if torch.is_autocast_enabled():
+            dtype = torch.get_autocast_gpu_dtype()
+            weight1, weight2 = [a.to(dtype=dtype) for a in [weight1, weight2]]
+            bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
+            bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
         weight1 = weight1.contiguous()
         bias1 = bias1.contiguous() if bias1 is not None else None
         weight2 = weight2.contiguous()
         bias2 = bias2.contiguous() if bias2 is not None else None
         if process_group is not None:
-            total_x, _ = all_gather_raw(x, process_group)
-        else:
-            total_x = x
+            handle_x.wait()
         batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
         batch_dim = batch_shape.numel()
         assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
         if heuristic == -1:
             gelu_in = F.linear(total_x, weight1, bias1)
             output1 = F.gelu(gelu_in, approximate='tanh')
-            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1) # This is before adding bias1
+            # This is before adding bias1
+            # gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1)
             # with torch.jit.fuser('fuser2'):
             #     output1 = bias_gelu(gelu_in, bias1)
         else:
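For reference, the heuristic == -1 branch in the hunk above is simply the unfused path: a GEMM, a tanh-approximated GeLU, then a second GEMM. A standalone sketch of that reference computation (shapes and tensor names are illustrative):

    import torch
    import torch.nn.functional as F

    def mlp_reference(x, weight1, bias1, weight2, bias2):
        # Unfused reference: GEMM -> tanh-approx GeLU -> GEMM,
        # matching the heuristic == -1 branch of the fused function.
        gelu_in = F.linear(x, weight1, bias1)
        output1 = F.gelu(gelu_in, approximate='tanh')
        return F.linear(output1, weight2, bias2)

    # illustrative shapes
    x = torch.randn(2, 8, 1024)
    w1 = torch.randn(4096, 1024); b1 = torch.randn(4096)
    w2 = torch.randn(1024, 4096); b2 = torch.randn(1024)
    out = mlp_reference(x, w1, b1, w2, b2)  # shape (2, 8, 1024)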
@@ -223,8 +237,6 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
             if save_pre_act:
                 gelu_in = rest[0]
         output2 = F.linear(output1, weight2, bias2)
-        ctx.checkpoint_lvl = checkpoint_lvl
-        ctx.heuristic = heuristic
         if checkpoint_lvl == 0:
             ctx.save_for_backward(x, weight1, weight2, gelu_in, output1)
         elif checkpoint_lvl == 1: