Implement TensorParallel for FusedDense and FusedDenseGeluDense
This commit is contained in:
parent
dff68c2b22
commit
226a1b721d
@ -2,6 +2,7 @@
|
|||||||
// We make it work for bfloat16
|
// We make it work for bfloat16
|
||||||
#include <torch/extension.h>
|
#include <torch/extension.h>
|
||||||
#include <torch/torch.h>
|
#include <torch/torch.h>
|
||||||
|
#include <c10/cuda/CUDAGuard.h>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
@ -50,6 +51,10 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
|
|||||||
CHECK_SHAPE(input, batch_size, in_features);
|
CHECK_SHAPE(input, batch_size, in_features);
|
||||||
CHECK_SHAPE(d_output, batch_size, out_features);
|
CHECK_SHAPE(d_output, batch_size, out_features);
|
||||||
|
|
||||||
|
// Otherwise the kernel will be launched from cuda:0 device
|
||||||
|
// Cast to char to avoid compiler warning about narrowing
|
||||||
|
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
|
||||||
|
|
||||||
// create output/workspace tensor
|
// create output/workspace tensor
|
||||||
auto opts = input.options();
|
auto opts = input.options();
|
||||||
auto d_weight = at::empty({out_features, in_features}, opts);
|
auto d_weight = at::empty({out_features, in_features}, opts);
|
||||||
@ -104,6 +109,10 @@ std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
|
|||||||
CHECK_SHAPE(bias, out_features);
|
CHECK_SHAPE(bias, out_features);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Otherwise the kernel will be launched from cuda:0 device
|
||||||
|
// Cast to char to avoid compiler warning about narrowing
|
||||||
|
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
|
||||||
|
|
||||||
// create output/workspace tensor
|
// create output/workspace tensor
|
||||||
auto opts = input.options();
|
auto opts = input.options();
|
||||||
auto output = at::empty({batch_size, out_features}, opts);
|
auto output = at::empty({batch_size, out_features}, opts);
|
||||||
@ -153,6 +162,10 @@ std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
|
|||||||
CHECK_SHAPE(d_output, batch_size, out_features);
|
CHECK_SHAPE(d_output, batch_size, out_features);
|
||||||
CHECK_SHAPE(gelu_in, batch_size, in_features);
|
CHECK_SHAPE(gelu_in, batch_size, in_features);
|
||||||
|
|
||||||
|
// Otherwise the kernel will be launched from cuda:0 device
|
||||||
|
// Cast to char to avoid compiler warning about narrowing
|
||||||
|
at::cuda::CUDAGuard device_guard{(char)weight.get_device()};
|
||||||
|
|
||||||
// create output/workspace tensor
|
// create output/workspace tensor
|
||||||
auto opts = weight.options();
|
auto opts = weight.options();
|
||||||
auto d_bias = at::empty({in_features}, opts);
|
auto d_bias = at::empty({in_features}, opts);
|
||||||
|
|||||||
@ -5,9 +5,9 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from flash_attn.ops.fused_dense import FusedDenseGeluDense
|
from flash_attn.ops.fused_dense import FusedDenseGeluDense, ParallelFusedDenseGeluDense
|
||||||
except ImportError:
|
except ImportError:
|
||||||
FusedDenseGeluDense = None
|
FusedDenseGeluDense, ParallelFusedDenseGeluDense = None, None
|
||||||
|
|
||||||
|
|
||||||
class Mlp(nn.Module):
|
class Mlp(nn.Module):
|
||||||
|
|||||||
@ -1,35 +1,55 @@
|
|||||||
# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
|
# Copyright (c) 2022, Tri Dao.
|
||||||
|
# Inspired by https://github.com/NVIDIA/apex/blob/master/apex/fused_dense/fused_dense.py
|
||||||
# We make it work with pytorch amp and with bfloat16.
|
# We make it work with pytorch amp and with bfloat16.
|
||||||
|
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||||
|
|
||||||
# import fused_dense_cuda # from apex
|
# import fused_dense_cuda # from apex
|
||||||
import fused_dense_lib as fused_dense_cuda
|
import fused_dense_lib as fused_dense_cuda
|
||||||
|
|
||||||
from flash_attn.ops.gelu_activation import gelu_bwd
|
from flash_attn.ops.gelu_activation import gelu_bwd
|
||||||
|
from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, reduce_scatter
|
||||||
|
|
||||||
|
|
||||||
class FusedDenseFunc(torch.autograd.Function):
|
class FusedDenseFunc(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@custom_fwd
|
@custom_fwd
|
||||||
def forward(ctx, x, weight, bias, return_residual=False):
|
def forward(ctx, x, weight, bias, return_residual=False, process_group=None):
|
||||||
|
"""
|
||||||
|
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
|
||||||
|
we do an all_gather_raw of x before doing the matmul.
|
||||||
|
"""
|
||||||
if torch.is_autocast_enabled():
|
if torch.is_autocast_enabled():
|
||||||
dtype = torch.get_autocast_gpu_dtype()
|
dtype = torch.get_autocast_gpu_dtype()
|
||||||
x, weight = [a.to(dtype=dtype) for a in [x, weight]]
|
x, weight = [a.to(dtype=dtype) for a in [x, weight]]
|
||||||
bias = bias.to(dtype=dtype) if bias is not None else None
|
bias = bias.to(dtype=dtype) if bias is not None else None
|
||||||
|
|
||||||
ctx.return_residual = return_residual
|
ctx.return_residual = return_residual
|
||||||
|
ctx.process_group = process_group
|
||||||
|
ctx.compute_weight_gradient = weight.requires_grad
|
||||||
|
|
||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
weight = weight.contiguous()
|
weight = weight.contiguous()
|
||||||
|
if ctx.compute_weight_gradient:
|
||||||
ctx.save_for_backward(x, weight)
|
ctx.save_for_backward(x, weight)
|
||||||
|
else:
|
||||||
|
ctx.save_for_backward(weight)
|
||||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
batch_shape, n = x.shape[:-1], x.shape[-1]
|
||||||
batch_dim = batch_shape.numel()
|
batch_dim = batch_shape.numel()
|
||||||
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
|
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
|
||||||
output = F.linear(x, weight, bias)
|
if process_group is not None:
|
||||||
|
total_x, _ = all_gather_raw(x, process_group)
|
||||||
|
else:
|
||||||
|
total_x = x
|
||||||
|
output = F.linear(total_x, weight, bias)
|
||||||
return output if not return_residual else (output, x)
|
return output if not return_residual else (output, x)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -39,37 +59,56 @@ class FusedDenseFunc(torch.autograd.Function):
|
|||||||
if ctx.return_residual:
|
if ctx.return_residual:
|
||||||
grad_input, = args
|
grad_input, = args
|
||||||
grad_input = grad_input.contiguous()
|
grad_input = grad_input.contiguous()
|
||||||
|
process_group = ctx.process_group
|
||||||
|
if ctx.compute_weight_gradient:
|
||||||
x, weight = ctx.saved_tensors
|
x, weight = ctx.saved_tensors
|
||||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
if process_group is not None:
|
||||||
|
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
||||||
|
else:
|
||||||
|
total_x = x
|
||||||
|
else:
|
||||||
|
weight, = ctx.saved_tensors
|
||||||
|
total_x = None
|
||||||
|
batch_shape = grad_output.shape[:-1]
|
||||||
batch_dim = batch_shape.numel()
|
batch_dim = batch_shape.numel()
|
||||||
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||||
if ctx.needs_input_grad[1]:
|
|
||||||
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
|
|
||||||
x.reshape(batch_dim, n), grad_output, ctx.needs_input_grad[2]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
grad_weight = None
|
|
||||||
grad_bias = grad_output if ctx.needs_input_grad[2] else None
|
|
||||||
if ctx.needs_input_grad[0]:
|
if ctx.needs_input_grad[0]:
|
||||||
if not ctx.return_residual:
|
if not ctx.return_residual:
|
||||||
grad_input = F.linear(grad_output, weight.t())
|
grad_input = F.linear(grad_output, weight.t())
|
||||||
else:
|
else:
|
||||||
grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_output, weight)
|
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
|
||||||
grad_input = grad_input.reshape_as(x)
|
grad_output, weight)
|
||||||
|
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
|
||||||
|
if process_group is not None:
|
||||||
|
grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
|
||||||
|
async_op=True)
|
||||||
else:
|
else:
|
||||||
grad_input = None
|
grad_input = None
|
||||||
return grad_input, grad_weight, grad_bias, None
|
if ctx.needs_input_grad[1]:
|
||||||
|
assert ctx.compute_weight_gradient
|
||||||
|
if process_group is not None:
|
||||||
|
handle_x.wait()
|
||||||
|
grad_weight, grad_bias = fused_dense_cuda.linear_bias_wgrad(
|
||||||
|
total_x.reshape(batch_dim, total_x.shape[-1]), grad_output, ctx.needs_input_grad[2]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
grad_weight = None
|
||||||
|
grad_bias = grad_output if ctx.needs_input_grad[2] else None
|
||||||
|
if process_group is not None and ctx.needs_input_grad[0]:
|
||||||
|
handle_grad_input.wait()
|
||||||
|
return grad_input, grad_weight, grad_bias, None, None
|
||||||
|
|
||||||
|
|
||||||
def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
|
def fused_dense_func(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None,
|
||||||
return_residual: bool = False):
|
return_residual: bool = False, process_group: Optional[ProcessGroup] = None):
|
||||||
batch_dim = x.shape[:-1].numel()
|
batch_dim = x.shape[:-1].numel()
|
||||||
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
|
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
|
||||||
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
|
or (x.dtype == torch.float32 and torch.is_autocast_enabled()))
|
||||||
if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
|
if (x.is_cuda and weight.is_cuda and (bias is None or bias.is_cuda) and batch_dim <= 64 * 1024
|
||||||
and dtype_eligible):
|
and dtype_eligible):
|
||||||
return FusedDenseFunc.apply(x, weight, bias, return_residual)
|
return FusedDenseFunc.apply(x, weight, bias, return_residual, process_group)
|
||||||
else:
|
else:
|
||||||
|
assert process_group is None
|
||||||
out = F.linear(x, weight, bias)
|
out = F.linear(x, weight, bias)
|
||||||
return out if not return_residual else (out, x)
|
return out if not return_residual else (out, x)
|
||||||
|
|
||||||
@ -81,17 +120,69 @@ class FusedDense(nn.Linear):
|
|||||||
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)
|
||||||
self.return_residual = return_residual
|
self.return_residual = return_residual
|
||||||
|
|
||||||
|
def forward(self, x, process_group=None):
|
||||||
|
"""
|
||||||
|
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
|
||||||
|
we do an all_gather of x before doing the matmul.
|
||||||
|
"""
|
||||||
|
return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual,
|
||||||
|
process_group=process_group)
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnParallelLinear(nn.Linear):
|
||||||
|
|
||||||
|
def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
|
||||||
|
bias: bool = True, device=None, dtype=None) -> None:
|
||||||
|
world_size = torch.distributed.get_world_size(process_group)
|
||||||
|
if out_features % world_size != 0:
|
||||||
|
raise ValueError(f'out_features ({out_features}) must be divisible by '
|
||||||
|
f'world_size ({world_size})')
|
||||||
|
super().__init__(in_features, out_features // world_size, bias=bias,
|
||||||
|
device=device, dtype=dtype)
|
||||||
|
self.process_group = process_group
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return fused_dense_func(x, self.weight, self.bias, return_residual=self.return_residual)
|
"""
|
||||||
|
We're doing Tensor Parallel with sequence parallelism: we do an all_gather of
|
||||||
|
x before doing the matmul.
|
||||||
|
"""
|
||||||
|
return fused_dense_func(x, self.weight, self.bias, process_group=self.process_group)
|
||||||
|
|
||||||
|
|
||||||
|
class RowParallelLinear(nn.Linear):
|
||||||
|
|
||||||
|
def __init__(self, in_features: int, out_features: int, process_group: ProcessGroup,
|
||||||
|
bias: bool = True, device=None, dtype=None) -> None:
|
||||||
|
world_size = torch.distributed.get_world_size(process_group)
|
||||||
|
rank = torch.distributed.get_rank(process_group)
|
||||||
|
if in_features % world_size != 0:
|
||||||
|
raise ValueError(f'in_features ({in_features}) must be divisible by '
|
||||||
|
f'world_size ({world_size})')
|
||||||
|
# Only rank 0 will have bias
|
||||||
|
super().__init__(in_features // world_size, out_features, bias=bias and rank == 0,
|
||||||
|
device=device, dtype=dtype)
|
||||||
|
self.process_group = process_group
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
|
||||||
|
a reduce_scatter of the result.
|
||||||
|
"""
|
||||||
|
out = fused_dense_func(x, self.weight, self.bias)
|
||||||
|
return reduce_scatter(out, self.process_group)
|
||||||
|
|
||||||
|
|
||||||
class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@custom_fwd
|
@custom_fwd
|
||||||
def forward(ctx, x, weight1, bias1, weight2, bias2, save_gelu_in=True, return_residual=False,
|
def forward(ctx, x, weight1, bias1, weight2, bias2, save_pre_act=True, return_residual=False,
|
||||||
checkpoint_lvl=0, heuristic=0):
|
checkpoint_lvl=0, heuristic=0, process_group=None):
|
||||||
"""checkpoint_lvl:
|
"""
|
||||||
|
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
|
||||||
|
we do an all_gather of x before doing the matmul.
|
||||||
|
|
||||||
|
checkpoint_lvl:
|
||||||
0: no recomputation in the bwd
|
0: no recomputation in the bwd
|
||||||
1: recompute gelu_out in the bwd
|
1: recompute gelu_out in the bwd
|
||||||
2: recompute gelu_in and gelu_out in the bwd
|
2: recompute gelu_in and gelu_out in the bwd
|
||||||
@ -102,28 +193,34 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
|||||||
x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
|
x, weight1, weight2 = [a.to(dtype=dtype) for a in [x, weight1, weight2]]
|
||||||
bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
|
bias1 = bias1.to(dtype=dtype) if bias1 is not None else None
|
||||||
bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
|
bias2 = bias2.to(dtype=dtype) if bias2 is not None else None
|
||||||
if not save_gelu_in:
|
if not save_pre_act:
|
||||||
checkpoint_lvl = 2
|
checkpoint_lvl = 2
|
||||||
assert checkpoint_lvl in [0, 1, 2]
|
assert checkpoint_lvl in [0, 1, 2]
|
||||||
ctx.return_residual = return_residual
|
ctx.return_residual = return_residual
|
||||||
|
ctx.process_group = process_group
|
||||||
x = x.contiguous()
|
x = x.contiguous()
|
||||||
weight1 = weight1.contiguous()
|
weight1 = weight1.contiguous()
|
||||||
bias1 = bias1.contiguous() if bias1 is not None else None
|
bias1 = bias1.contiguous() if bias1 is not None else None
|
||||||
weight2 = weight2.contiguous()
|
weight2 = weight2.contiguous()
|
||||||
bias2 = bias2.contiguous() if bias2 is not None else None
|
bias2 = bias2.contiguous() if bias2 is not None else None
|
||||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
if process_group is not None:
|
||||||
|
total_x, _ = all_gather_raw(x, process_group)
|
||||||
|
else:
|
||||||
|
total_x = x
|
||||||
|
batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
|
||||||
batch_dim = batch_shape.numel()
|
batch_dim = batch_shape.numel()
|
||||||
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
|
assert batch_dim <= 64 * 1024, 'fused_dense only supports dimension at most 64k'
|
||||||
if heuristic == -1:
|
if heuristic == -1:
|
||||||
gelu_in = F.linear(x, weight1, bias1)
|
gelu_in = F.linear(total_x, weight1, bias1)
|
||||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||||
# gelu_in = F.linear(x.reshape(batch_dim, n), weight1) # This is before adding bias1
|
# gelu_in = F.linear(total_x.reshape(batch_dim, n), weight1) # This is before adding bias1
|
||||||
# with torch.jit.fuser('fuser2'):
|
# with torch.jit.fuser('fuser2'):
|
||||||
# output1 = bias_gelu(gelu_in, bias1)
|
# output1 = bias_gelu(gelu_in, bias1)
|
||||||
else:
|
else:
|
||||||
output1, *rest = fused_dense_cuda.linear_gelu_forward(x.reshape(batch_dim, n), weight1,
|
output1, *rest = fused_dense_cuda.linear_gelu_forward(
|
||||||
bias1, save_gelu_in, heuristic)
|
total_x.reshape(batch_dim, n), weight1, bias1, save_pre_act, heuristic
|
||||||
if save_gelu_in:
|
)
|
||||||
|
if save_pre_act:
|
||||||
gelu_in = rest[0]
|
gelu_in = rest[0]
|
||||||
output2 = F.linear(output1, weight2, bias2)
|
output2 = F.linear(output1, weight2, bias2)
|
||||||
ctx.checkpoint_lvl = checkpoint_lvl
|
ctx.checkpoint_lvl = checkpoint_lvl
|
||||||
@ -145,9 +242,15 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
|||||||
if ctx.return_residual:
|
if ctx.return_residual:
|
||||||
grad_input, = args
|
grad_input, = args
|
||||||
grad_input = grad_input.contiguous()
|
grad_input = grad_input.contiguous()
|
||||||
|
process_group = ctx.process_group
|
||||||
x, weight1, weight2, *rest = ctx.saved_tensors
|
x, weight1, weight2, *rest = ctx.saved_tensors
|
||||||
batch_shape, n = x.shape[:-1], x.shape[-1]
|
if process_group is None:
|
||||||
|
total_x = x
|
||||||
|
batch_shape = grad_output.shape[:-1]
|
||||||
batch_dim = batch_shape.numel()
|
batch_dim = batch_shape.numel()
|
||||||
|
if checkpoint_lvl in [0, 1]:
|
||||||
|
if process_group is not None:
|
||||||
|
total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
|
||||||
if checkpoint_lvl == 0:
|
if checkpoint_lvl == 0:
|
||||||
gelu_in, output1 = rest
|
gelu_in, output1 = rest
|
||||||
elif checkpoint_lvl == 1:
|
elif checkpoint_lvl == 1:
|
||||||
@ -155,12 +258,15 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
|||||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||||
elif checkpoint_lvl == 2:
|
elif checkpoint_lvl == 2:
|
||||||
bias1, = rest
|
bias1, = rest
|
||||||
|
if process_group is not None:
|
||||||
|
total_x, _ = all_gather_raw(x, process_group)
|
||||||
if ctx.heuristic == -1:
|
if ctx.heuristic == -1:
|
||||||
gelu_in = F.linear(x, weight1, bias1)
|
gelu_in = F.linear(total_x, weight1, bias1)
|
||||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||||
else:
|
else:
|
||||||
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
|
output1, gelu_in = fused_dense_cuda.linear_gelu_forward(
|
||||||
x.reshape(batch_dim, n), weight1, bias1, True, ctx.heuristic
|
total_x.reshape(batch_dim, total_x.shape[-1]), weight1, bias1, True,
|
||||||
|
ctx.heuristic
|
||||||
)
|
)
|
||||||
|
|
||||||
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
|
||||||
@ -178,13 +284,6 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
|||||||
grad_output1 = F.linear(grad_output, weight2.t())
|
grad_output1 = F.linear(grad_output, weight2.t())
|
||||||
with torch.jit.fuser('fuser2'):
|
with torch.jit.fuser('fuser2'):
|
||||||
grad_gelu = gelu_bwd(grad_output1, gelu_in)
|
grad_gelu = gelu_bwd(grad_output1, gelu_in)
|
||||||
if ctx.needs_input_grad[1]:
|
|
||||||
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
|
|
||||||
x.reshape(batch_dim, n), grad_gelu, ctx.needs_input_grad[2]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
grad_weight1 = None
|
|
||||||
grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
|
|
||||||
else:
|
else:
|
||||||
# The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
|
# The cublasLt epilogue has to compute both gelu grad and bias grad, we can't
|
||||||
# just compute gelu grad
|
# just compute gelu grad
|
||||||
@ -193,26 +292,49 @@ class FusedDenseGeluDenseFunc(torch.autograd.Function):
|
|||||||
)
|
)
|
||||||
if not ctx.needs_input_grad[2]:
|
if not ctx.needs_input_grad[2]:
|
||||||
grad_bias1 = None
|
grad_bias1 = None
|
||||||
if ctx.needs_input_grad[1]:
|
|
||||||
grad_weight1 = F.linear(grad_gelu.t(), x.reshape(batch_dim, n).t())
|
|
||||||
else:
|
|
||||||
grad_weight1 = None
|
|
||||||
if ctx.needs_input_grad[0]:
|
if ctx.needs_input_grad[0]:
|
||||||
if not ctx.return_residual:
|
if not ctx.return_residual:
|
||||||
grad_input = F.linear(grad_gelu, weight1.t())
|
grad_input = F.linear(grad_gelu, weight1.t())
|
||||||
else:
|
else:
|
||||||
grad_input = torch.addmm(grad_input.reshape(batch_dim, n), grad_gelu, weight1)
|
grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
|
||||||
grad_input = grad_input.reshape_as(x)
|
grad_gelu, weight1)
|
||||||
|
grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
|
||||||
|
if process_group is not None:
|
||||||
|
grad_input, handle_grad_input = reduce_scatter_raw(grad_input, process_group,
|
||||||
|
async_op=True)
|
||||||
else:
|
else:
|
||||||
grad_input = None
|
grad_input = None
|
||||||
return grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2, None, None, None, None
|
if ctx.heuristic == -1:
|
||||||
|
if ctx.needs_input_grad[1]:
|
||||||
|
if process_group is not None:
|
||||||
|
handle_x.wait()
|
||||||
|
grad_weight1, grad_bias1 = fused_dense_cuda.linear_bias_wgrad(
|
||||||
|
total_x.reshape(batch_dim, total_x.shape[-1]), grad_gelu,
|
||||||
|
ctx.needs_input_grad[2]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
grad_weight1 = None
|
||||||
|
grad_bias1 = grad_gelu if ctx.needs_input_grad[2] else None
|
||||||
|
else:
|
||||||
|
if ctx.needs_input_grad[1]:
|
||||||
|
if process_group is not None:
|
||||||
|
handle_x.wait()
|
||||||
|
grad_weight1 = F.linear(grad_gelu.t(),
|
||||||
|
total_x.reshape(batch_dim, total_x.shape[-1]).t())
|
||||||
|
else:
|
||||||
|
grad_weight1 = None
|
||||||
|
if process_group is not None and ctx.needs_input_grad[0]:
|
||||||
|
handle_grad_input.wait()
|
||||||
|
return (grad_input, grad_weight1, grad_bias1, grad_weight2, grad_bias2,
|
||||||
|
None, None, None, None, None)
|
||||||
|
|
||||||
|
|
||||||
def fused_dense_gelu_dense_func(
|
def fused_dense_gelu_dense_func(
|
||||||
x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
|
x: Tensor, weight1: Tensor, weight2: Tensor, bias1: Optional[Tensor] = None,
|
||||||
bias2: Optional[Tensor] = None,
|
bias2: Optional[Tensor] = None,
|
||||||
save_gelu_in: bool = True, return_residual: bool = False,
|
save_pre_act: bool = True, return_residual: bool = False,
|
||||||
checkpoint_lvl: int = 0, heuristic: int = 0
|
checkpoint_lvl: int = 0, heuristic: int = 0,
|
||||||
|
process_group: Optional[ProcessGroup] = None
|
||||||
):
|
):
|
||||||
batch_dim = x.shape[:-1].numel()
|
batch_dim = x.shape[:-1].numel()
|
||||||
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
|
dtype_eligible = (x.dtype in [torch.float16, torch.bfloat16]
|
||||||
@ -222,9 +344,10 @@ def fused_dense_gelu_dense_func(
|
|||||||
and dtype_eligible):
|
and dtype_eligible):
|
||||||
return FusedDenseGeluDenseFunc.apply(
|
return FusedDenseGeluDenseFunc.apply(
|
||||||
x, weight1, bias1, weight2, bias2,
|
x, weight1, bias1, weight2, bias2,
|
||||||
save_gelu_in, return_residual, checkpoint_lvl, heuristic
|
save_pre_act, return_residual, checkpoint_lvl, heuristic, process_group
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
assert process_group is None
|
||||||
gelu_in = F.linear(x, weight1, bias1)
|
gelu_in = F.linear(x, weight1, bias1)
|
||||||
output1 = F.gelu(gelu_in, approximate='tanh')
|
output1 = F.gelu(gelu_in, approximate='tanh')
|
||||||
output2 = F.linear(output1, weight2, bias2)
|
output2 = F.linear(output1, weight2, bias2)
|
||||||
@ -237,6 +360,10 @@ class FusedDenseGeluDense(nn.Module):
|
|||||||
bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
|
bias2=True, return_residual=False, checkpoint_lvl=0, heuristic=0,
|
||||||
device=None, dtype=None):
|
device=None, dtype=None):
|
||||||
"""
|
"""
|
||||||
|
If process_group is not None, we're doing Tensor Parallel with sequence parallelism:
|
||||||
|
we do an all_gather of x before doing the matmul, gelu, then matmul.
|
||||||
|
Finally we do a reduce_scatter of the output.
|
||||||
|
|
||||||
checkpoint_lvl (increasing lvl means slower but more memory saving):
|
checkpoint_lvl (increasing lvl means slower but more memory saving):
|
||||||
0: no recomputation in the bwd
|
0: no recomputation in the bwd
|
||||||
1: recompute gelu_out in the bwd
|
1: recompute gelu_out in the bwd
|
||||||
@ -261,9 +388,58 @@ class FusedDenseGeluDense(nn.Module):
|
|||||||
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
|
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs)
|
||||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, process_group=None):
|
||||||
return fused_dense_gelu_dense_func(
|
out = fused_dense_gelu_dense_func(
|
||||||
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
|
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
|
||||||
save_gelu_in=self.training, return_residual=self.return_residual,
|
save_pre_act=self.training, return_residual=self.return_residual,
|
||||||
checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic
|
checkpoint_lvl=self.checkpoint_lvl, heuristic=self.heuristic,
|
||||||
|
process_group=process_group
|
||||||
)
|
)
|
||||||
|
if self.return_residual:
|
||||||
|
out, x = out
|
||||||
|
if process_group is not None:
|
||||||
|
out = reduce_scatter(out, process_group)
|
||||||
|
return out if not self.return_residual else (out, x)
|
||||||
|
|
||||||
|
|
||||||
|
class ParallelFusedDenseGeluDense(nn.Module):
|
||||||
|
|
||||||
|
def __init__(self, in_features, hidden_features, out_features=None,
|
||||||
|
process_group: ProcessGroup = None, bias1=True, bias2=True,
|
||||||
|
checkpoint_lvl=0, heuristic=0, device=None, dtype=None):
|
||||||
|
"""
|
||||||
|
process_group is required. We're doing Tensor Parallel with sequence parallelism:
|
||||||
|
we do an all_gather of x before doing the matmul, gelu, then matmul.
|
||||||
|
Finally we do a reduce_scatter of the output.
|
||||||
|
|
||||||
|
checkpoint_lvl (increasing lvl means slower but more memory saving):
|
||||||
|
0: no recomputation in the bwd
|
||||||
|
1: recompute gelu_out in the bwd
|
||||||
|
2: recompute gelu_in and gelu_out in the bwd
|
||||||
|
heuristic:
|
||||||
|
-1: don't fuse gemm + gelu (separate kernel)
|
||||||
|
0..4: use this heuristic for the algo section in the fused gemm + gelu
|
||||||
|
For CUDA >= 11.8, you'd want heuristic=0 for both fp16 and bf16 for best perf.
|
||||||
|
For CUDA <= 11.7, you'd want heuristic=1 for fp16 and heuristic=-1 for bf16.
|
||||||
|
"""
|
||||||
|
assert checkpoint_lvl in [0, 1, 2]
|
||||||
|
assert process_group is not None
|
||||||
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||||
|
super().__init__()
|
||||||
|
if out_features is None:
|
||||||
|
out_features = in_features
|
||||||
|
self.process_group = process_group
|
||||||
|
self.checkpoint_lvl = checkpoint_lvl
|
||||||
|
self.heuristic = heuristic
|
||||||
|
self.fc1 = ColumnParallelLinear(in_features, hidden_features, process_group,
|
||||||
|
bias=bias1, **factory_kwargs)
|
||||||
|
self.fc2 = RowParallelLinear(hidden_features, out_features, process_group,
|
||||||
|
bias=bias2, **factory_kwargs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
out = fused_dense_gelu_dense_func(
|
||||||
|
x, self.fc1.weight, self.fc2.weight, self.fc1.bias, self.fc2.bias,
|
||||||
|
save_pre_act=self.training, checkpoint_lvl=self.checkpoint_lvl,
|
||||||
|
heuristic=self.heuristic, process_group=self.process_group
|
||||||
|
)
|
||||||
|
return reduce_scatter(out, self.process_group)
|
||||||
|
|||||||
74
flash_attn/utils/distributed.py
Normal file
74
flash_attn/utils/distributed.py
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.distributed import ProcessGroup
|
||||||
|
|
||||||
|
# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
|
||||||
|
# `_all_gather_base` and `_reduce_scatter_base`. They require the most recent
|
||||||
|
# version of PyTorch. The following 4 lines are for backward compatibility with
|
||||||
|
# older PyTorch.
|
||||||
|
if "all_gather_into_tensor" not in dir(torch.distributed):
|
||||||
|
torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
|
||||||
|
if "reduce_scatter_tensor" not in dir(torch.distributed):
|
||||||
|
torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
|
||||||
|
|
||||||
|
|
||||||
|
# Raw operation, oes does support autograd, but does support async
|
||||||
|
def all_gather_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
||||||
|
world_size = torch.distributed.get_world_size(process_group)
|
||||||
|
output = torch.empty(world_size * input_.shape[0], *input_.shape[1:],
|
||||||
|
dtype=input_.dtype, device=input_.device)
|
||||||
|
handle = torch.distributed.all_gather_into_tensor(output, input_.contiguous(),
|
||||||
|
group=process_group, async_op=async_op)
|
||||||
|
return output, handle
|
||||||
|
|
||||||
|
|
||||||
|
# Raw operation, oes does support autograd, but does support async
|
||||||
|
def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False):
|
||||||
|
world_size = torch.distributed.get_world_size(process_group)
|
||||||
|
assert input_.shape[0] % world_size == 0
|
||||||
|
output = torch.empty(input_.shape[0] // world_size, *input_.shape[1:],
|
||||||
|
dtype=input_.dtype, device=input_.device)
|
||||||
|
handle = torch.distributed.reduce_scatter_tensor(output, input_.contiguous(),
|
||||||
|
group=process_group,
|
||||||
|
async_op=async_op)
|
||||||
|
return output, handle
|
||||||
|
|
||||||
|
|
||||||
|
class AllGatherFunc(torch.autograd.Function):
|
||||||
|
"""Gather the input from sequence parallel region and concatenate."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
||||||
|
ctx.process_group = process_group
|
||||||
|
output, _ = all_gather_raw(input_, process_group)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output: Tensor):
|
||||||
|
grad_input, _ = reduce_scatter_raw(grad_output, ctx.process_group)
|
||||||
|
return grad_input, None
|
||||||
|
|
||||||
|
|
||||||
|
# Supports autograd, but does not support async
|
||||||
|
all_gather = AllGatherFunc.apply
|
||||||
|
|
||||||
|
|
||||||
|
class ReduceScatterFunc(torch.autograd.Function):
|
||||||
|
"""Reduce scatter the input from the sequence parallel region and concatenate."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_: Tensor, process_group: ProcessGroup) -> Tensor:
|
||||||
|
ctx.process_group = process_group
|
||||||
|
output, _ = reduce_scatter_raw(input_, process_group)
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def backward(ctx, grad_output: Tensor):
|
||||||
|
grad_input, _ = all_gather_raw(grad_output, ctx.process_group)
|
||||||
|
return grad_input, None
|
||||||
|
|
||||||
|
|
||||||
|
# Supports autograd, but does not support async
|
||||||
|
reduce_scatter = ReduceScatterFunc.apply
|
||||||
180
tests/ops/test_fused_dense_parallel.py
Normal file
180
tests/ops/test_fused_dense_parallel.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
# Run test with:
|
||||||
|
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/ops/test_fused_dense_parallel.py
|
||||||
|
|
||||||
|
import math
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from apex.transformer import parallel_state
|
||||||
|
from apex.transformer import tensor_parallel
|
||||||
|
|
||||||
|
from flash_attn.ops.fused_dense import FusedDense, FusedDenseGeluDense
|
||||||
|
from flash_attn.ops.fused_dense import ColumnParallelLinear, ParallelFusedDenseGeluDense
|
||||||
|
|
||||||
|
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
|
||||||
|
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
|
||||||
|
# @pytest.mark.parametrize('world_size', [8])
|
||||||
|
@pytest.mark.parametrize('has_bias', [True, False])
|
||||||
|
# @pytest.mark.parametrize('has_bias', [True])
|
||||||
|
@pytest.mark.parametrize('out_features', [1024, 4096])
|
||||||
|
# @pytest.mark.parametrize('out_features', [1024])
|
||||||
|
@pytest.mark.parametrize('in_features', [1024, 4096])
|
||||||
|
# @pytest.mark.parametrize('in_features', [4096])
|
||||||
|
def test_fused_linear_bias(in_features, out_features, has_bias, world_size, dtype):
|
||||||
|
assert out_features % world_size == 0
|
||||||
|
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||||
|
device = f'cuda:{torch.distributed.get_rank()}'
|
||||||
|
assert world_size <= torch.distributed.get_world_size()
|
||||||
|
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
|
||||||
|
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||||
|
# set seed
|
||||||
|
torch.random.manual_seed(0)
|
||||||
|
batch_size = 8
|
||||||
|
seqlen = 512
|
||||||
|
assert batch_size * seqlen % world_size == 0
|
||||||
|
x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
|
||||||
|
requires_grad=True)
|
||||||
|
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
|
||||||
|
|
||||||
|
model_pt = torch.nn.Linear(in_features, out_features, bias=has_bias, device=device, dtype=dtype)
|
||||||
|
partition_out_features = out_features // world_size
|
||||||
|
model = ColumnParallelLinear(in_features, out_features,
|
||||||
|
parallel_state.get_tensor_model_parallel_group(), bias=has_bias,
|
||||||
|
device=device, dtype=dtype)
|
||||||
|
with torch.no_grad():
|
||||||
|
model.weight.copy_(
|
||||||
|
model_pt.weight[rank * partition_out_features:(rank + 1) * partition_out_features]
|
||||||
|
)
|
||||||
|
if has_bias:
|
||||||
|
model.bias.copy_(
|
||||||
|
model_pt.bias[rank * partition_out_features:(rank + 1) * partition_out_features]
|
||||||
|
)
|
||||||
|
|
||||||
|
out = model(x)
|
||||||
|
out_pt = model_pt(x_pt)
|
||||||
|
assert torch.allclose(
|
||||||
|
out, out_pt[:, rank * partition_out_features:(rank + 1) * partition_out_features],
|
||||||
|
rtol=rtol, atol=atol
|
||||||
|
)
|
||||||
|
|
||||||
|
# If we don't divide by batch_size, the gradient gets a bit too large.
|
||||||
|
g = torch.randn_like(out_pt) / 32
|
||||||
|
out_pt.backward(g)
|
||||||
|
out.backward(g[:, rank * partition_out_features:(rank + 1) * partition_out_features])
|
||||||
|
parallel_state.destroy_model_parallel()
|
||||||
|
|
||||||
|
partition_batch_dim = batch_size * seqlen // world_size
|
||||||
|
assert torch.allclose(
|
||||||
|
x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
|
||||||
|
rtol=rtol, atol=atol
|
||||||
|
)
|
||||||
|
# The error for d_weight and d_bias is quite a bit higher
|
||||||
|
assert torch.allclose(
|
||||||
|
model.weight.grad,
|
||||||
|
model_pt.weight.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
|
||||||
|
rtol=rtol, atol=atol * 10
|
||||||
|
)
|
||||||
|
if has_bias:
|
||||||
|
assert torch.allclose(
|
||||||
|
model.bias.grad,
|
||||||
|
model_pt.bias.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
|
||||||
|
rtol=rtol, atol=atol * 5
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
|
||||||
|
# @pytest.mark.parametrize('dtype', [torch.bfloat16])
|
||||||
|
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
|
||||||
|
# @pytest.mark.parametrize('world_size', [2])
|
||||||
|
@pytest.mark.parametrize('has_bias2', [True, False])
|
||||||
|
# @pytest.mark.parametrize('has_bias2', [True])
|
||||||
|
@pytest.mark.parametrize('out_features', [1024, 4096])
|
||||||
|
# @pytest.mark.parametrize('out_features', [1024])
|
||||||
|
@pytest.mark.parametrize('in_features', [1024, 4096])
|
||||||
|
# @pytest.mark.parametrize('in_features', [1024])
|
||||||
|
def test_fused_dense_gelu_dense(in_features, out_features, has_bias2, world_size, dtype):
|
||||||
|
assert out_features % world_size == 0
|
||||||
|
rtol, atol = (3e-3, 3e-2) if dtype == torch.bfloat16 else (3e-3, 3e-3)
|
||||||
|
if not torch.distributed.is_initialized():
|
||||||
|
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||||
|
device = f'cuda:{torch.distributed.get_rank()}'
|
||||||
|
assert world_size <= torch.distributed.get_world_size()
|
||||||
|
parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
|
||||||
|
rank = parallel_state.get_tensor_model_parallel_rank()
|
||||||
|
# set seed
|
||||||
|
torch.random.manual_seed(0)
|
||||||
|
batch_size = 8
|
||||||
|
seqlen = 512
|
||||||
|
assert batch_size * seqlen % world_size == 0
|
||||||
|
x_pt = torch.randn(batch_size * seqlen, in_features, device=device, dtype=dtype,
|
||||||
|
requires_grad=True)
|
||||||
|
# We need to generate g here so that all processes get the same gradient,
|
||||||
|
# as rank 0 will have an extra bias that changes the RNG.
|
||||||
|
# If we don't divide by batch_size, the gradient gets a bit too large.
|
||||||
|
g = torch.randn_like(x_pt) / 32
|
||||||
|
x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
|
||||||
|
|
||||||
|
model_pt_fc1 = torch.nn.Linear(in_features, out_features, device=device, dtype=dtype)
|
||||||
|
model_pt_fc2 = torch.nn.Linear(out_features, in_features, bias=has_bias2, device=device,
|
||||||
|
dtype=dtype)
|
||||||
|
partition_out_features = out_features // world_size
|
||||||
|
partition_in_features = in_features // world_size
|
||||||
|
model = ParallelFusedDenseGeluDense(in_features, out_features, in_features,
|
||||||
|
process_group=parallel_state.get_tensor_model_parallel_group(),
|
||||||
|
bias2=has_bias2 and rank == 0, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
model.fc1.weight.copy_(
|
||||||
|
model_pt_fc1.weight[rank * partition_out_features:(rank + 1) * partition_out_features]
|
||||||
|
)
|
||||||
|
model.fc1.bias.copy_(
|
||||||
|
model_pt_fc1.bias[rank * partition_out_features:(rank + 1) * partition_out_features]
|
||||||
|
)
|
||||||
|
model.fc2.weight.copy_(
|
||||||
|
model_pt_fc2.weight[:, rank * partition_out_features:(rank + 1) * partition_out_features]
|
||||||
|
)
|
||||||
|
if has_bias2 and rank == 0:
|
||||||
|
model.fc2.bias.copy_(model_pt_fc2.bias)
|
||||||
|
|
||||||
|
out = model(x)
|
||||||
|
out_pt = model_pt_fc2(F.gelu(model_pt_fc1(x_pt), approximate='tanh'))
|
||||||
|
partition_batch_dim = batch_size * seqlen // world_size
|
||||||
|
assert torch.allclose(
|
||||||
|
out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
|
||||||
|
rtol=rtol, atol=atol
|
||||||
|
)
|
||||||
|
|
||||||
|
out_pt.backward(g)
|
||||||
|
out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
|
||||||
|
parallel_state.destroy_model_parallel()
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
|
||||||
|
rtol=rtol, atol=atol
|
||||||
|
)
|
||||||
|
# The error for d_weight and d_bias is quite a bit higher
|
||||||
|
assert torch.allclose(
|
||||||
|
model.fc1.weight.grad,
|
||||||
|
model_pt_fc1.weight.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
|
||||||
|
rtol=rtol, atol=atol * 10
|
||||||
|
)
|
||||||
|
assert torch.allclose(
|
||||||
|
model.fc1.bias.grad,
|
||||||
|
model_pt_fc1.bias.grad[rank * partition_out_features:(rank + 1) * partition_out_features],
|
||||||
|
rtol=rtol, atol=atol * 5
|
||||||
|
)
|
||||||
|
assert torch.allclose(
|
||||||
|
model.fc2.weight.grad,
|
||||||
|
model_pt_fc2.weight.grad[:, rank * partition_out_features:(rank + 1) * partition_out_features],
|
||||||
|
rtol=rtol, atol=atol * 10
|
||||||
|
)
|
||||||
|
if has_bias2 and rank == 0:
|
||||||
|
assert torch.allclose(model.fc2.bias.grad, model_pt_fc2.bias.grad, rtol=rtol, atol=atol * 5)
|
||||||
Loading…
Reference in New Issue
Block a user