[LayerNorm] Make sure memory addresses are aligned to 16 bytes
This commit is contained in:
parent
3a9bfd076f
commit
d2f4324f4c
@ -7,9 +7,17 @@ from torch.nn import init
|
||||
import dropout_layer_norm
|
||||
|
||||
|
||||
def maybe_align(x, alignment_in_bytes=16):
|
||||
"""Assume that x already has last dim divisible by alignment_in_bytes
|
||||
"""
|
||||
# TD [2023-07-04] I'm not 100% sure that clone will align the memory
|
||||
# https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
|
||||
return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
|
||||
|
||||
|
||||
def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
|
||||
epsilon, residual_in_fp32=False, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
"""
|
||||
hidden_size = gamma.numel()
|
||||
x0mat = x0.view((-1, hidden_size))
|
||||
@ -26,7 +34,7 @@ def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscal
|
||||
|
||||
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
|
||||
dropout_p, has_residual, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
dx == None means that it was a post-norm architecture
|
||||
(x = drop(x0) + residual was not returned in the fwd).
|
||||
x0 must not be None if we have colscale.
|
||||
@ -54,7 +62,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
|
||||
def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
|
||||
out_subset, dropout_p, epsilon, rowscale_const,
|
||||
out_numrows, residual_in_fp32=False, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
"""
|
||||
hidden_size = gamma.numel()
|
||||
x0mat = x0.view((-1, hidden_size))
|
||||
@ -73,7 +81,7 @@ def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale,
|
||||
def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
|
||||
x0_subset, out_subset, dropout_p, rowscale_const,
|
||||
x0_numrows, has_residual, is_rms_norm=False):
|
||||
""" Assume that arguments are contiguous
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
dx == None means that it was a post-norm architecture
|
||||
(x = drop(x0) + residual was not returned in the fwd).
|
||||
x0 must not be None if we have colscale.
|
||||
@ -103,7 +111,7 @@ def _dropout_add_layer_norm_parallel_residual_forward(
|
||||
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
|
||||
epsilon, residual_in_fp32=False, is_rms_norm=False
|
||||
):
|
||||
""" Assume that arguments are contiguous
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
"""
|
||||
hidden_size = gamma0.numel()
|
||||
x0mat = x0.view((-1, hidden_size))
|
||||
@ -122,7 +130,7 @@ def _dropout_add_layer_norm_parallel_residual_backward(
|
||||
dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
|
||||
dropout_p, has_x1, has_residual, is_rms_norm=False
|
||||
):
|
||||
""" Assume that arguments are contiguous
|
||||
""" Assume that arguments are contiguous and aligned to 16 bytes
|
||||
dx == None means that it was a post-norm architecture
|
||||
(x = drop(x0) + residual was not returned in the fwd).
|
||||
"""
|
||||
@ -143,12 +151,12 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
|
||||
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
|
||||
x0 = x0.contiguous()
|
||||
residual = residual.contiguous() if residual is not None else None
|
||||
gamma = gamma.contiguous()
|
||||
beta = beta.contiguous() if beta is not None else None
|
||||
rowscale = rowscale.contiguous() if rowscale is not None else None
|
||||
colscale = colscale.contiguous() if colscale is not None else None
|
||||
x0 = maybe_align(x0.contiguous(), 16)
|
||||
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
||||
gamma = maybe_align(gamma.contiguous(), 16)
|
||||
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
|
||||
rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
|
||||
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
|
||||
x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
|
||||
residual_in_fp32, is_rms_norm
|
||||
@ -174,8 +182,8 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dz, *args):
|
||||
# assert dz.is_contiguous()
|
||||
dz = dz.contiguous() # this happens!
|
||||
dx = args[0].contiguous() if ctx.prenorm else None
|
||||
dz = maybe_align(dz.contiguous(), 16) # this happens!
|
||||
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
|
||||
x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
|
||||
# x0 is None if colscale is None
|
||||
dropout_p = ctx.dropout_p
|
||||
@ -196,11 +204,11 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
|
||||
def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, residual_in_fp32=False,
|
||||
prenorm=False, is_rms_norm=False, return_dmask=False):
|
||||
x0 = x0.contiguous()
|
||||
residual = residual.contiguous() if residual is not None else None
|
||||
gamma = gamma.contiguous()
|
||||
beta = beta.contiguous() if beta is not None else None
|
||||
colscale = colscale.contiguous() if colscale is not None else None
|
||||
x0 = maybe_align(x0.contiguous(), 16)
|
||||
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
||||
gamma = maybe_align(gamma.contiguous(), 16)
|
||||
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
|
||||
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
|
||||
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
|
||||
x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
|
||||
rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
|
||||
@ -231,8 +239,8 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def backward(ctx, dz, *args):
|
||||
# assert dz.is_contiguous()
|
||||
dz = dz.contiguous() # this happens!
|
||||
dx = args[0].contiguous() if ctx.prenorm else None
|
||||
dz = maybe_align(dz.contiguous(), 16) # this happens!
|
||||
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
|
||||
x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
|
||||
# x0 is None if colscale is None
|
||||
dropout_p = ctx.dropout_p
|
||||
@ -252,13 +260,13 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
|
||||
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
|
||||
x0 = x0.contiguous()
|
||||
x1 = x1.contiguous() if x1 is not None else None
|
||||
residual = residual.contiguous() if residual is not None else None
|
||||
gamma0 = gamma0.contiguous()
|
||||
beta0 = beta0.contiguous() if beta0 is not None else None
|
||||
gamma1 = gamma1.contiguous() if gamma1 is not None else None
|
||||
beta1 = beta1.contiguous() if beta1 is not None else None
|
||||
x0 = maybe_align(x0.contiguous(), 16)
|
||||
x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
|
||||
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
|
||||
gamma0 = maybe_align(gamma0.contiguous(), 16)
|
||||
beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
|
||||
gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
|
||||
beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
|
||||
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
|
||||
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
|
||||
residual_in_fp32, is_rms_norm
|
||||
@ -284,9 +292,9 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dz0, dz1, *args):
|
||||
dz0 = dz0.contiguous() # this happens!
|
||||
dz1 = dz1.contiguous() if dz1 is not None else None
|
||||
dx = args[0].contiguous() if ctx.prenorm else None
|
||||
dz0 = maybe_align(dz0.contiguous(), 16) # this happens!
|
||||
dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
|
||||
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
|
||||
x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
|
||||
dropout_p = ctx.dropout_p
|
||||
has_x1 = ctx.has_x1
|
||||
|
||||
@ -99,7 +99,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
|
||||
model_ref.bias.copy_(model_pt.bias)
|
||||
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
|
||||
out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
|
||||
model.epsilon, rowscale=rowscale, layerscale=colscale,
|
||||
model.eps, rowscale=rowscale, layerscale=colscale,
|
||||
residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
|
||||
assert out.dtype == input_dtype
|
||||
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
|
||||
@ -251,7 +251,7 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
|
||||
model_ref.bias.copy_(model_pt.bias)
|
||||
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
|
||||
out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
|
||||
model.epsilon, rowscale=rowscale,
|
||||
model.eps, rowscale=rowscale,
|
||||
layerscale=colscale, prenorm=True,
|
||||
residual_in_fp32=residual_in_fp32,
|
||||
return_dropout_mask=True)
|
||||
@ -412,7 +412,7 @@ def test_dropout_layer_norm_subset_training(
|
||||
|
||||
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
|
||||
out, dmask = dropout_add_layer_norm_subset(
|
||||
x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
|
||||
x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
|
||||
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
|
||||
out_numrows = out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
|
||||
return_dropout_mask=True)
|
||||
@ -532,7 +532,7 @@ def test_dropout_layer_norm_subset_prenorm_training(
|
||||
|
||||
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
|
||||
out, residual, dmask = dropout_add_layer_norm_subset(
|
||||
x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
|
||||
x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
|
||||
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
|
||||
out_numrows = out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
|
||||
return_dropout_mask=True)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user