[LayerNorm] Make sure memory addresses are aligned to 16 bytes

This commit is contained in:
Tri Dao 2023-07-04 14:52:42 -07:00
parent 3a9bfd076f
commit d2f4324f4c
2 changed files with 43 additions and 35 deletions

View File

@ -7,9 +7,17 @@ from torch.nn import init
import dropout_layer_norm
def maybe_align(x, alignment_in_bytes=16):
"""Assume that x already has last dim divisible by alignment_in_bytes
"""
# TD [2023-07-04] I'm not 100% sure that clone will align the memory
# https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
epsilon, residual_in_fp32=False, is_rms_norm=False):
""" Assume that arguments are contiguous
""" Assume that arguments are contiguous and aligned to 16 bytes
"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
@ -26,7 +34,7 @@ def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscal
def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
dropout_p, has_residual, is_rms_norm=False):
""" Assume that arguments are contiguous
""" Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
@ -54,7 +62,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
out_subset, dropout_p, epsilon, rowscale_const,
out_numrows, residual_in_fp32=False, is_rms_norm=False):
""" Assume that arguments are contiguous
""" Assume that arguments are contiguous and aligned to 16 bytes
"""
hidden_size = gamma.numel()
x0mat = x0.view((-1, hidden_size))
@ -73,7 +81,7 @@ def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale,
def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
x0_subset, out_subset, dropout_p, rowscale_const,
x0_numrows, has_residual, is_rms_norm=False):
""" Assume that arguments are contiguous
""" Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
x0 must not be None if we have colscale.
@ -103,7 +111,7 @@ def _dropout_add_layer_norm_parallel_residual_forward(
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
epsilon, residual_in_fp32=False, is_rms_norm=False
):
""" Assume that arguments are contiguous
""" Assume that arguments are contiguous and aligned to 16 bytes
"""
hidden_size = gamma0.numel()
x0mat = x0.view((-1, hidden_size))
@ -122,7 +130,7 @@ def _dropout_add_layer_norm_parallel_residual_backward(
dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
dropout_p, has_x1, has_residual, is_rms_norm=False
):
""" Assume that arguments are contiguous
""" Assume that arguments are contiguous and aligned to 16 bytes
dx == None means that it was a post-norm architecture
(x = drop(x0) + residual was not returned in the fwd).
"""
@ -143,12 +151,12 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = x0.contiguous()
residual = residual.contiguous() if residual is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous() if beta is not None else None
rowscale = rowscale.contiguous() if rowscale is not None else None
colscale = colscale.contiguous() if colscale is not None else None
x0 = maybe_align(x0.contiguous(), 16)
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma = maybe_align(gamma.contiguous(), 16)
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
residual_in_fp32, is_rms_norm
@ -174,8 +182,8 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
@staticmethod
def backward(ctx, dz, *args):
# assert dz.is_contiguous()
dz = dz.contiguous() # this happens!
dx = args[0].contiguous() if ctx.prenorm else None
dz = maybe_align(dz.contiguous(), 16) # this happens!
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
# x0 is None if colscale is None
dropout_p = ctx.dropout_p
@ -196,11 +204,11 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32=False,
prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = x0.contiguous()
residual = residual.contiguous() if residual is not None else None
gamma = gamma.contiguous()
beta = beta.contiguous() if beta is not None else None
colscale = colscale.contiguous() if colscale is not None else None
x0 = maybe_align(x0.contiguous(), 16)
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma = maybe_align(gamma.contiguous(), 16)
beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
@ -231,8 +239,8 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
@staticmethod
def backward(ctx, dz, *args):
# assert dz.is_contiguous()
dz = dz.contiguous() # this happens!
dx = args[0].contiguous() if ctx.prenorm else None
dz = maybe_align(dz.contiguous(), 16) # this happens!
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
# x0 is None if colscale is None
dropout_p = ctx.dropout_p
@ -252,13 +260,13 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
x0 = x0.contiguous()
x1 = x1.contiguous() if x1 is not None else None
residual = residual.contiguous() if residual is not None else None
gamma0 = gamma0.contiguous()
beta0 = beta0.contiguous() if beta0 is not None else None
gamma1 = gamma1.contiguous() if gamma1 is not None else None
beta1 = beta1.contiguous() if beta1 is not None else None
x0 = maybe_align(x0.contiguous(), 16)
x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
gamma0 = maybe_align(gamma0.contiguous(), 16)
beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
residual_in_fp32, is_rms_norm
@ -284,9 +292,9 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
@staticmethod
def backward(ctx, dz0, dz1, *args):
dz0 = dz0.contiguous() # this happens!
dz1 = dz1.contiguous() if dz1 is not None else None
dx = args[0].contiguous() if ctx.prenorm else None
dz0 = maybe_align(dz0.contiguous(), 16) # this happens!
dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
dropout_p = ctx.dropout_p
has_x1 = ctx.has_x1

View File

@ -99,7 +99,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale, layerscale=colscale,
model.eps, rowscale=rowscale, layerscale=colscale,
residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
assert out.dtype == input_dtype
print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
@ -251,7 +251,7 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
model_ref.bias.copy_(model_pt.bias)
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
model.epsilon, rowscale=rowscale,
model.eps, rowscale=rowscale,
layerscale=colscale, prenorm=True,
residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
@ -412,7 +412,7 @@ def test_dropout_layer_norm_subset_training(
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, dmask = dropout_add_layer_norm_subset(
x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
out_numrows = out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)
@ -532,7 +532,7 @@ def test_dropout_layer_norm_subset_prenorm_training(
residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
out, residual, dmask = dropout_add_layer_norm_subset(
x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
out_numrows = out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
return_dropout_mask=True)