Set batch_strides on Params::update (#883)
This commit is contained in:
parent
2670b973dd
commit
6116706c96
@ -508,8 +508,7 @@ public:
|
||||
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed to
|
||||
/// remain the same.
|
||||
/// Lightweight update given a subset of arguments.
|
||||
Status update(Arguments const &args) {
|
||||
|
||||
return underlying_operator_.update(to_underlying_arguments(args));
|
||||
|
@ -323,8 +323,7 @@ public:
|
||||
}
|
||||
|
||||
|
||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed to
|
||||
/// remain the same.
|
||||
/// Lightweight update given a subset of arguments.
|
||||
Status update(Arguments const &args)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("GemmUniversalBase()::update()");
|
||||
|
@ -384,8 +384,7 @@ public:
|
||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
|
||||
{}
|
||||
|
||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
||||
/// to remain the same.
|
||||
/// Lightweight update given a subset of arguments.
|
||||
void update(Arguments const &args)
|
||||
{
|
||||
ptr_A = const_cast<void *>(args.ptr_A);
|
||||
@ -397,6 +396,15 @@ public:
|
||||
ptr_C = const_cast<void *>(args.ptr_C);
|
||||
ptr_D = args.ptr_D;
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_var = args.batch_stride_var;
|
||||
batch_stride_mean = args.batch_stride_mean;
|
||||
batch_stride_gamma = args.batch_stride_gamma;
|
||||
batch_stride_beta = args.batch_stride_beta;
|
||||
this->batch_stride_D = args.batch_stride_D;
|
||||
|
||||
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
|
||||
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
|
||||
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
|
||||
|
@ -339,8 +339,7 @@ public:
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
||||
/// to remain the same.
|
||||
/// Lightweight update given a subset of arguments.
|
||||
void update(Arguments const &args)
|
||||
{
|
||||
ptr_A_real = const_cast<void *>(args.ptr_A_real);
|
||||
@ -355,6 +354,15 @@ public:
|
||||
ptr_D_real = const_cast<void *>(args.ptr_D_real);
|
||||
ptr_D_imag = const_cast<void *>(args.ptr_D_imag);
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
this->batch_stride_D = args.batch_stride_D;
|
||||
batch_stride_A_imag = args.batch_stride_A_imag;
|
||||
batch_stride_B_imag = args.batch_stride_B_imag;
|
||||
batch_stride_C_imag = args.batch_stride_C_imag;
|
||||
batch_stride_D_imag = args.batch_stride_D_imag;
|
||||
|
||||
output_op = args.epilogue;
|
||||
}
|
||||
};
|
||||
|
@ -304,8 +304,7 @@ public:
|
||||
ptr_D_imag(args.ptr_D_imag)
|
||||
{}
|
||||
|
||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
||||
/// to remain the same.
|
||||
/// Lightweight update given a subset of arguments.
|
||||
void update(Arguments const &args)
|
||||
{
|
||||
ptr_M = args.ptr_M;
|
||||
|
@ -320,8 +320,7 @@ public:
|
||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
|
||||
{}
|
||||
|
||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
||||
/// to remain the same.
|
||||
/// Lightweight update given a subset of arguments.
|
||||
void update(Arguments const &args)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
|
||||
@ -332,6 +331,11 @@ public:
|
||||
ptr_C = const_cast<void *>(args.ptr_C);
|
||||
ptr_D = args.ptr_D;
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
this->batch_stride_D = args.batch_stride_D;
|
||||
|
||||
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
|
||||
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
|
||||
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
|
||||
|
@ -454,8 +454,7 @@ public:
|
||||
}
|
||||
|
||||
|
||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
||||
/// to remain the same.
|
||||
/// Lightweight update given a subset of arguments.
|
||||
void update(Arguments const &args)
|
||||
{
|
||||
CUTLASS_TRACE_HOST("GemmUniversalStreamK::Params::update()");
|
||||
|
@ -310,8 +310,7 @@ public:
|
||||
CUTLASS_TRACE_HOST(" ldt: " << args.ldt);
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
||||
/// to remain the same.
|
||||
/// Lightweight update given a subset of arguments.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const &args)
|
||||
{
|
||||
@ -325,6 +324,14 @@ public:
|
||||
ldr = args.ldr;
|
||||
ptr_Tensor = args.ptr_Tensor;
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C1 = args.batch_stride_C1;
|
||||
batch_stride_C2 = args.batch_stride_C2;
|
||||
batch_stride_Vector = args.batch_stride_Vector;
|
||||
batch_stride_Tensor = args.batch_stride_Tensor;
|
||||
this->batch_stride_D = args.batch_stride_D;
|
||||
|
||||
output_op = args.epilogue;
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()");
|
||||
@ -1025,8 +1032,7 @@ public:
|
||||
CUTLASS_TRACE_HOST(" ldt: " << args.ldt);
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
||||
/// to remain the same.
|
||||
/// Lightweight update given a subset of arguments.
|
||||
CUTLASS_HOST_DEVICE
|
||||
void update(Arguments const &args)
|
||||
{
|
||||
@ -1039,6 +1045,13 @@ public:
|
||||
ldr = args.ldr;
|
||||
ptr_Tensor = args.ptr_Tensor;
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_Vector = args.batch_stride_Vector;
|
||||
batch_stride_Tensor = args.batch_stride_Tensor;
|
||||
this->batch_stride_D = args.batch_stride_D;
|
||||
|
||||
output_op = args.epilogue;
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()");
|
||||
|
@ -295,8 +295,7 @@ public:
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
||||
/// to remain the same.
|
||||
/// Lightweight update given a subset of arguments.
|
||||
void update(Arguments const &args)
|
||||
{
|
||||
ptr_A = const_cast<void *>(args.ptr_A);
|
||||
@ -305,6 +304,12 @@ public:
|
||||
ptr_D = args.ptr_D;
|
||||
ptr_gemm_k_reduction = args.ptr_gemm_k_reduction;
|
||||
|
||||
batch_stride_A = args.batch_stride_A;
|
||||
batch_stride_B = args.batch_stride_B;
|
||||
batch_stride_C = args.batch_stride_C;
|
||||
batch_stride_gemm_k_reduction = args.batch_stride_gemm_k_reduction;
|
||||
this->batch_stride_D = args.batch_stride_D;
|
||||
|
||||
output_op = args.epilogue;
|
||||
|
||||
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
|
||||
|
Loading…
Reference in New Issue
Block a user