[LayerNorm] Make sure memory addresses are aligned to 16 bytes
parent 3a9bfd076f
commit d2f4324f4c

@@ -7,9 +7,17 @@ from torch.nn import init
 import dropout_layer_norm


+def maybe_align(x, alignment_in_bytes=16):
+    """Assume that x already has last dim divisible by alignment_in_bytes
+    """
+    # TD [2023-07-04] I'm not 100% sure that clone will align the memory
+    # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440
+    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()
+
+
 def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p,
                                     epsilon, residual_in_fp32=False, is_rms_norm=False):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
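The 16-byte figure presumably reflects 128-bit vectorized loads/stores in the CUDA kernels; the diff itself only states the alignment assumption. Below is a minimal sketch of how the new maybe_align helper behaves on a deliberately misaligned view. The fp16 dtype, the sizes, and the slicing are illustrative assumptions, and clone() landing on aligned storage is the usual outcome rather than a documented guarantee, which is exactly the uncertainty the comment above flags.

import torch

def maybe_align(x, alignment_in_bytes=16):
    # Same one-liner as in the diff: keep x if its storage pointer is already
    # aligned, otherwise fall back to a fresh copy.
    return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone()

base = torch.randn(272, dtype=torch.float16)   # fresh allocations are normally well aligned
x = base[1:257]                                # contiguous 1-D slice, data_ptr() shifted by 2 bytes
print(x.data_ptr() % 16)                       # typically 2, i.e. misaligned
y = maybe_align(x, 16)
print(y is x, y.data_ptr() % 16)               # typically (False, 0): a clone into new storage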
@@ -26,7 +34,7 @@ def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscal

 def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale,
                                      dropout_p, has_residual, is_rms_norm=False):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.
@@ -54,7 +62,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro
 def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset,
                                            out_subset, dropout_p, epsilon, rowscale_const,
                                            out_numrows, residual_in_fp32=False, is_rms_norm=False):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     """
     hidden_size = gamma.numel()
     x0mat = x0.view((-1, hidden_size))
@@ -73,7 +81,7 @@ def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale,
 def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale,
                                             x0_subset, out_subset, dropout_p, rowscale_const,
                                             x0_numrows, has_residual, is_rms_norm=False):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + residual was not returned in the fwd).
     x0 must not be None if we have colscale.
@@ -103,7 +111,7 @@ def _dropout_add_layer_norm_parallel_residual_forward(
     x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p,
     epsilon, residual_in_fp32=False, is_rms_norm=False
 ):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     """
     hidden_size = gamma0.numel()
     x0mat = x0.view((-1, hidden_size))
@@ -122,7 +130,7 @@ def _dropout_add_layer_norm_parallel_residual_backward(
     dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1,
     dropout_p, has_x1, has_residual, is_rms_norm=False
 ):
-    """ Assume that arguments are contiguous
+    """ Assume that arguments are contiguous and aligned to 16 bytes
     dx == None means that it was a post-norm architecture
     (x = drop(x0) + residual was not returned in the fwd).
     """
@@ -143,12 +151,12 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
                 residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
-        x0 = x0.contiguous()
-        residual = residual.contiguous() if residual is not None else None
-        gamma = gamma.contiguous()
-        beta = beta.contiguous() if beta is not None else None
-        rowscale = rowscale.contiguous() if rowscale is not None else None
-        colscale = colscale.contiguous() if colscale is not None else None
+        x0 = maybe_align(x0.contiguous(), 16)
+        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
+        gamma = maybe_align(gamma.contiguous(), 16)
+        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
+        rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None
+        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward(
             x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon,
             residual_in_fp32, is_rms_norm
@@ -174,8 +182,8 @@ class DropoutAddLayerNormFn(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dz, *args):
         # assert dz.is_contiguous()
-        dz = dz.contiguous()  # this happens!
-        dx = args[0].contiguous() if ctx.prenorm else None
+        dz = maybe_align(dz.contiguous(), 16)  # this happens!
+        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
         x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
@@ -196,11 +204,11 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
     def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
                 rowscale_const, out_numrows, residual_in_fp32=False,
                 prenorm=False, is_rms_norm=False, return_dmask=False):
-        x0 = x0.contiguous()
-        residual = residual.contiguous() if residual is not None else None
-        gamma = gamma.contiguous()
-        beta = beta.contiguous() if beta is not None else None
-        colscale = colscale.contiguous() if colscale is not None else None
+        x0 = maybe_align(x0.contiguous(), 16)
+        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
+        gamma = maybe_align(gamma.contiguous(), 16)
+        beta = maybe_align(beta.contiguous(), 16) if beta is not None else None
+        colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None
         zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward(
             x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon,
             rowscale_const, out_numrows, residual_in_fp32, is_rms_norm
@@ -231,8 +239,8 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
     @staticmethod
     def backward(ctx, dz, *args):
         # assert dz.is_contiguous()
-        dz = dz.contiguous()  # this happens!
-        dx = args[0].contiguous() if ctx.prenorm else None
+        dz = maybe_align(dz.contiguous(), 16)  # this happens!
+        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
         x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors
         # x0 is None if colscale is None
         dropout_p = ctx.dropout_p
@@ -252,13 +260,13 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
                 residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False):
-        x0 = x0.contiguous()
-        x1 = x1.contiguous() if x1 is not None else None
-        residual = residual.contiguous() if residual is not None else None
-        gamma0 = gamma0.contiguous()
-        beta0 = beta0.contiguous() if beta0 is not None else None
-        gamma1 = gamma1.contiguous() if gamma1 is not None else None
-        beta1 = beta1.contiguous() if beta1 is not None else None
+        x0 = maybe_align(x0.contiguous(), 16)
+        x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None
+        residual = maybe_align(residual.contiguous(), 16) if residual is not None else None
+        gamma0 = maybe_align(gamma0.contiguous(), 16)
+        beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None
+        gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None
+        beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None
         z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward(
             x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon,
             residual_in_fp32, is_rms_norm
@@ -284,9 +292,9 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function):

     @staticmethod
     def backward(ctx, dz0, dz1, *args):
-        dz0 = dz0.contiguous()  # this happens!
-        dz1 = dz1.contiguous() if dz1 is not None else None
-        dx = args[0].contiguous() if ctx.prenorm else None
+        dz0 = maybe_align(dz0.contiguous(), 16)  # this happens!
+        dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None
+        dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None
         x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors
         dropout_p = ctx.dropout_p
         has_x1 = ctx.has_x1
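Every forward/backward entry point above repeats the same guard, arg = maybe_align(arg.contiguous(), 16), with a None pass-through. A condensed sketch of that idiom follows; the _align_args helper is hypothetical and not part of the diff, purely an illustration of the repeated pattern.

def _align_args(*tensors, alignment_in_bytes=16):
    # Hypothetical wrapper around the idiom used in the autograd Functions
    # above: make each tensor contiguous, then re-align it if its pointer is
    # off; None arguments pass through unchanged.
    return tuple(
        maybe_align(t.contiguous(), alignment_in_bytes) if t is not None else None
        for t in tensors
    )

# e.g. inside DropoutAddLayerNormFn.forward:
#   x0, residual, gamma, beta, rowscale, colscale = _align_args(
#       x0, residual, gamma, beta, rowscale, colscale)

Calling .contiguous() first matters here: if it has to copy, the fresh allocation is typically already aligned and maybe_align becomes a no-op, so the extra clone is only paid when a contiguous but misaligned tensor slips through.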
@@ -99,7 +99,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w
         model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
-                                     model.epsilon, rowscale=rowscale, layerscale=colscale,
+                                     model.eps, rowscale=rowscale, layerscale=colscale,
                                      residual_in_fp32=residual_in_fp32, return_dropout_mask=True)
     assert out.dtype == input_dtype
     print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}')
@@ -251,7 +251,7 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_
         model_ref.bias.copy_(model_pt.bias)
     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p,
-                                               model.epsilon, rowscale=rowscale,
+                                               model.eps, rowscale=rowscale,
                                                layerscale=colscale, prenorm=True,
                                                residual_in_fp32=residual_in_fp32,
                                                return_dropout_mask=True)
@@ -412,7 +412,7 @@ def test_dropout_layer_norm_subset_training(

     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, dmask = dropout_add_layer_norm_subset(
-        x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
+        x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
         x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
         out_numrows = out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32,
         return_dropout_mask=True)
@@ -532,7 +532,7 @@ def test_dropout_layer_norm_subset_prenorm_training(

     residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
     out, residual, dmask = dropout_add_layer_norm_subset(
-        x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale,
+        x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale,
         x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale,
         out_numrows = out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32,
         return_dropout_mask=True)
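Coming back to the open question in the maybe_align comment, whether clone() really yields 16-byte-aligned memory, here is a quick empirical spot check one could run. It is not a proof: allocator behavior is an implementation detail that may differ across builds and devices, and the sizes below are arbitrary assumptions.

import torch

def fraction_aligned_after_clone(device, n_trials=1000, alignment=16):
    # Clone deliberately misaligned fp16 slices and count how often the copy
    # lands on an aligned address. Purely an empirical spot check.
    aligned = 0
    for i in range(n_trials):
        base = torch.randn(64 + 16 * (i % 7), dtype=torch.float16, device=device)
        x = base[1:49]                      # 2-byte offset from base -> usually misaligned
        aligned += int(x.clone().data_ptr() % alignment == 0)
    return aligned / n_trials

print('cpu :', fraction_aligned_after_clone('cpu'))
if torch.cuda.is_available():
    print('cuda:', fraction_aligned_after_clone('cuda'))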