diff --git a/flash_attn/ops/layer_norm.py b/flash_attn/ops/layer_norm.py index c5e4a27..c42ea90 100644 --- a/flash_attn/ops/layer_norm.py +++ b/flash_attn/ops/layer_norm.py @@ -7,9 +7,17 @@ from torch.nn import init import dropout_layer_norm +def maybe_align(x, alignment_in_bytes=16): + """Assume that x already has last dim divisible by alignment_in_bytes + """ + # TD [2023-07-04] I'm not 100% sure that clone will align the memory + # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 + return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() + + def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32=False, is_rms_norm=False): - """ Assume that arguments are contiguous + """ Assume that arguments are contiguous and aligned to 16 bytes """ hidden_size = gamma.numel() x0mat = x0.view((-1, hidden_size)) @@ -26,7 +34,7 @@ def _dropout_add_layer_norm_forward(x0, residual, gamma, beta, rowscale, colscal def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, rowscale, colscale, dropout_p, has_residual, is_rms_norm=False): - """ Assume that arguments are contiguous + """ Assume that arguments are contiguous and aligned to 16 bytes dx == None means that it was a post-norm architecture (x = drop(x0) + residual was not returned in the fwd). x0 must not be None if we have colscale. @@ -54,7 +62,7 @@ def _dropout_add_layer_norm_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, ro def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32=False, is_rms_norm=False): - """ Assume that arguments are contiguous + """ Assume that arguments are contiguous and aligned to 16 bytes """ hidden_size = gamma.numel() x0mat = x0.view((-1, hidden_size)) @@ -73,7 +81,7 @@ def _dropout_add_layer_norm_subset_forward(x0, residual, gamma, beta, colscale, def _dropout_add_layer_norm_subset_backward(dz, dx, x, x0, dmask, mu, rsigma, gamma, colscale, x0_subset, out_subset, dropout_p, rowscale_const, x0_numrows, has_residual, is_rms_norm=False): - """ Assume that arguments are contiguous + """ Assume that arguments are contiguous and aligned to 16 bytes dx == None means that it was a post-norm architecture (x = drop(x0) + residual was not returned in the fwd). x0 must not be None if we have colscale. @@ -103,7 +111,7 @@ def _dropout_add_layer_norm_parallel_residual_forward( x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, residual_in_fp32=False, is_rms_norm=False ): - """ Assume that arguments are contiguous + """ Assume that arguments are contiguous and aligned to 16 bytes """ hidden_size = gamma0.numel() x0mat = x0.view((-1, hidden_size)) @@ -122,7 +130,7 @@ def _dropout_add_layer_norm_parallel_residual_backward( dz0, dz1, dx, x, dmask0, dmask1, mu, rsigma, gamma0, gamma1, dropout_p, has_x1, has_residual, is_rms_norm=False ): - """ Assume that arguments are contiguous + """ Assume that arguments are contiguous and aligned to 16 bytes dx == None means that it was a post-norm architecture (x = drop(x0) + residual was not returned in the fwd). """ @@ -143,12 +151,12 @@ class DropoutAddLayerNormFn(torch.autograd.Function): @staticmethod def forward(ctx, x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False): - x0 = x0.contiguous() - residual = residual.contiguous() if residual is not None else None - gamma = gamma.contiguous() - beta = beta.contiguous() if beta is not None else None - rowscale = rowscale.contiguous() if rowscale is not None else None - colscale = colscale.contiguous() if colscale is not None else None + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( x0, residual, gamma, beta, rowscale, colscale, dropout_p, epsilon, residual_in_fp32, is_rms_norm @@ -174,8 +182,8 @@ class DropoutAddLayerNormFn(torch.autograd.Function): @staticmethod def backward(ctx, dz, *args): # assert dz.is_contiguous() - dz = dz.contiguous() # this happens! - dx = args[0].contiguous() if ctx.prenorm else None + dz = maybe_align(dz.contiguous(), 16) # this happens! + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors # x0 is None if colscale is None dropout_p = ctx.dropout_p @@ -196,11 +204,11 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function): def forward(ctx, x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False): - x0 = x0.contiguous() - residual = residual.contiguous() if residual is not None else None - gamma = gamma.contiguous() - beta = beta.contiguous() if beta is not None else None - colscale = colscale.contiguous() if colscale is not None else None + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( x0, residual, gamma, beta, colscale, x0_subset, out_subset, dropout_p, epsilon, rowscale_const, out_numrows, residual_in_fp32, is_rms_norm @@ -231,8 +239,8 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function): @staticmethod def backward(ctx, dz, *args): # assert dz.is_contiguous() - dz = dz.contiguous() # this happens! - dx = args[0].contiguous() if ctx.prenorm else None + dz = maybe_align(dz.contiguous(), 16) # this happens! + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors # x0 is None if colscale is None dropout_p = ctx.dropout_p @@ -252,13 +260,13 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): @staticmethod def forward(ctx, x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, residual_in_fp32=False, prenorm=False, is_rms_norm=False, return_dmask=False): - x0 = x0.contiguous() - x1 = x1.contiguous() if x1 is not None else None - residual = residual.contiguous() if residual is not None else None - gamma0 = gamma0.contiguous() - beta0 = beta0.contiguous() if beta0 is not None else None - gamma1 = gamma1.contiguous() if gamma1 is not None else None - beta1 = beta1.contiguous() if beta1 is not None else None + x0 = maybe_align(x0.contiguous(), 16) + x1 = maybe_align(x1.contiguous(), 16) if x1 is not None else None + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma0 = maybe_align(gamma0.contiguous(), 16) + beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None + gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None + beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None z0mat, z1mat, xmat, dmask0, dmask1, mu, rsigma = _dropout_add_layer_norm_parallel_residual_forward( x0, x1, residual, gamma0, beta0, gamma1, beta1, dropout_p, epsilon, residual_in_fp32, is_rms_norm @@ -284,9 +292,9 @@ class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): @staticmethod def backward(ctx, dz0, dz1, *args): - dz0 = dz0.contiguous() # this happens! - dz1 = dz1.contiguous() if dz1 is not None else None - dx = args[0].contiguous() if ctx.prenorm else None + dz0 = maybe_align(dz0.contiguous(), 16) # this happens! + dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors dropout_p = ctx.dropout_p has_x1 = ctx.has_x1 diff --git a/tests/ops/test_dropout_layer_norm.py b/tests/ops/test_dropout_layer_norm.py index 5b7a013..f72a2a8 100644 --- a/tests/ops/test_dropout_layer_norm.py +++ b/tests/ops/test_dropout_layer_norm.py @@ -99,7 +99,7 @@ def test_dropout_layer_norm_training(hidden_size, input_dtype, residual_dtype, w model_ref.bias.copy_(model_pt.bias) residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 out, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p, - model.epsilon, rowscale=rowscale, layerscale=colscale, + model.eps, rowscale=rowscale, layerscale=colscale, residual_in_fp32=residual_in_fp32, return_dropout_mask=True) assert out.dtype == input_dtype print(f'Actual dropout fraction: {1 - dmask.float().mean().item()}') @@ -251,7 +251,7 @@ def test_dropout_layer_norm_prenorm_training(hidden_size, input_dtype, residual_ model_ref.bias.copy_(model_pt.bias) residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 out, residual, dmask = our_layer_norm_func(x0, res, model.weight, model.bias, model.p, - model.epsilon, rowscale=rowscale, + model.eps, rowscale=rowscale, layerscale=colscale, prenorm=True, residual_in_fp32=residual_in_fp32, return_dropout_mask=True) @@ -412,7 +412,7 @@ def test_dropout_layer_norm_subset_training( residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 out, dmask = dropout_add_layer_norm_subset( - x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale, + x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale, x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale, out_numrows = out_numrows, prenorm=False, residual_in_fp32=residual_in_fp32, return_dropout_mask=True) @@ -532,7 +532,7 @@ def test_dropout_layer_norm_subset_prenorm_training( residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32 out, residual, dmask = dropout_add_layer_norm_subset( - x0, res, model.weight, model.bias, model.p, model.epsilon, layerscale=colscale, + x0, res, model.weight, model.bias, model.p, model.eps, layerscale=colscale, x0_subset=x0_subset, out_subset=out_subset, rowscale_const=drop_path_scale, out_numrows = out_numrows, prenorm=True, residual_in_fp32=residual_in_fp32, return_dropout_mask=True)