Set batch_strides on Params::update (#883)

This commit is contained in:
Jack Kosaian 2023-03-20 17:07:47 -04:00 committed by GitHub
parent 2670b973dd
commit 6116706c96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 54 additions and 20 deletions

View File

@ -508,8 +508,7 @@ public:
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
}
/// Lightweight update given a subset of arguments. Problem geometry is assumed to
/// remain the same.
/// Lightweight update given a subset of arguments.
Status update(Arguments const &args) {
return underlying_operator_.update(to_underlying_arguments(args));

View File

@ -323,8 +323,7 @@ public:
}
/// Lightweight update given a subset of arguments. Problem geometry is assumed to
/// remain the same.
/// Lightweight update given a subset of arguments.
Status update(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversalBase()::update()");

View File

@ -384,8 +384,7 @@ public:
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
{}
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
/// Lightweight update given a subset of arguments.
void update(Arguments const &args)
{
ptr_A = const_cast<void *>(args.ptr_A);
@ -397,6 +396,15 @@ public:
ptr_C = const_cast<void *>(args.ptr_C);
ptr_D = args.ptr_D;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_var = args.batch_stride_var;
batch_stride_mean = args.batch_stride_mean;
batch_stride_gamma = args.batch_stride_gamma;
batch_stride_beta = args.batch_stride_beta;
this->batch_stride_D = args.batch_stride_D;
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);

View File

@ -339,8 +339,7 @@ public:
return workspace_bytes;
}
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
/// Lightweight update given a subset of arguments.
void update(Arguments const &args)
{
ptr_A_real = const_cast<void *>(args.ptr_A_real);
@ -355,6 +354,15 @@ public:
ptr_D_real = const_cast<void *>(args.ptr_D_real);
ptr_D_imag = const_cast<void *>(args.ptr_D_imag);
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
this->batch_stride_D = args.batch_stride_D;
batch_stride_A_imag = args.batch_stride_A_imag;
batch_stride_B_imag = args.batch_stride_B_imag;
batch_stride_C_imag = args.batch_stride_C_imag;
batch_stride_D_imag = args.batch_stride_D_imag;
output_op = args.epilogue;
}
};

View File

@ -304,8 +304,7 @@ public:
ptr_D_imag(args.ptr_D_imag)
{}
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
/// Lightweight update given a subset of arguments.
void update(Arguments const &args)
{
ptr_M = args.ptr_M;

View File

@ -320,8 +320,7 @@ public:
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
{}
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
/// Lightweight update given a subset of arguments.
void update(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
@ -332,6 +331,11 @@ public:
ptr_C = const_cast<void *>(args.ptr_C);
ptr_D = args.ptr_D;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
this->batch_stride_D = args.batch_stride_D;
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);

View File

@ -454,8 +454,7 @@ public:
}
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
/// Lightweight update given a subset of arguments.
void update(Arguments const &args)
{
CUTLASS_TRACE_HOST("GemmUniversalStreamK::Params::update()");

View File

@ -310,8 +310,7 @@ public:
CUTLASS_TRACE_HOST(" ldt: " << args.ldt);
}
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
/// Lightweight update given a subset of arguments.
CUTLASS_HOST_DEVICE
void update(Arguments const &args)
{
@ -325,6 +324,14 @@ public:
ldr = args.ldr;
ptr_Tensor = args.ptr_Tensor;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C1 = args.batch_stride_C1;
batch_stride_C2 = args.batch_stride_C2;
batch_stride_Vector = args.batch_stride_Vector;
batch_stride_Tensor = args.batch_stride_Tensor;
this->batch_stride_D = args.batch_stride_D;
output_op = args.epilogue;
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()");
@ -1025,8 +1032,7 @@ public:
CUTLASS_TRACE_HOST(" ldt: " << args.ldt);
}
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
/// Lightweight update given a subset of arguments.
CUTLASS_HOST_DEVICE
void update(Arguments const &args)
{
@ -1039,6 +1045,13 @@ public:
ldr = args.ldr;
ptr_Tensor = args.ptr_Tensor;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_Vector = args.batch_stride_Vector;
batch_stride_Tensor = args.batch_stride_Tensor;
this->batch_stride_D = args.batch_stride_D;
output_op = args.epilogue;
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()");

View File

@ -295,8 +295,7 @@ public:
return workspace_bytes;
}
/// Lightweight update given a subset of arguments. Problem geometry is assumed
/// to remain the same.
/// Lightweight update given a subset of arguments.
void update(Arguments const &args)
{
ptr_A = const_cast<void *>(args.ptr_A);
@ -305,6 +304,12 @@ public:
ptr_D = args.ptr_D;
ptr_gemm_k_reduction = args.ptr_gemm_k_reduction;
batch_stride_A = args.batch_stride_A;
batch_stride_B = args.batch_stride_B;
batch_stride_C = args.batch_stride_C;
batch_stride_gemm_k_reduction = args.batch_stride_gemm_k_reduction;
this->batch_stride_D = args.batch_stride_D;
output_op = args.epilogue;
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");