Set batch_strides on Params::update (#883)
This commit is contained in:
parent
2670b973dd
commit
6116706c96
@ -508,8 +508,7 @@ public:
|
|||||||
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
|
return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed to
|
/// Lightweight update given a subset of arguments.
|
||||||
/// remain the same.
|
|
||||||
Status update(Arguments const &args) {
|
Status update(Arguments const &args) {
|
||||||
|
|
||||||
return underlying_operator_.update(to_underlying_arguments(args));
|
return underlying_operator_.update(to_underlying_arguments(args));
|
||||||
|
@ -323,8 +323,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed to
|
/// Lightweight update given a subset of arguments.
|
||||||
/// remain the same.
|
|
||||||
Status update(Arguments const &args)
|
Status update(Arguments const &args)
|
||||||
{
|
{
|
||||||
CUTLASS_TRACE_HOST("GemmUniversalBase()::update()");
|
CUTLASS_TRACE_HOST("GemmUniversalBase()::update()");
|
||||||
|
@ -384,8 +384,7 @@ public:
|
|||||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
|
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
|
||||||
{}
|
{}
|
||||||
|
|
||||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
/// Lightweight update given a subset of arguments.
|
||||||
/// to remain the same.
|
|
||||||
void update(Arguments const &args)
|
void update(Arguments const &args)
|
||||||
{
|
{
|
||||||
ptr_A = const_cast<void *>(args.ptr_A);
|
ptr_A = const_cast<void *>(args.ptr_A);
|
||||||
@ -397,6 +396,15 @@ public:
|
|||||||
ptr_C = const_cast<void *>(args.ptr_C);
|
ptr_C = const_cast<void *>(args.ptr_C);
|
||||||
ptr_D = args.ptr_D;
|
ptr_D = args.ptr_D;
|
||||||
|
|
||||||
|
batch_stride_A = args.batch_stride_A;
|
||||||
|
batch_stride_B = args.batch_stride_B;
|
||||||
|
batch_stride_C = args.batch_stride_C;
|
||||||
|
batch_stride_var = args.batch_stride_var;
|
||||||
|
batch_stride_mean = args.batch_stride_mean;
|
||||||
|
batch_stride_gamma = args.batch_stride_gamma;
|
||||||
|
batch_stride_beta = args.batch_stride_beta;
|
||||||
|
this->batch_stride_D = args.batch_stride_D;
|
||||||
|
|
||||||
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
|
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
|
||||||
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
|
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
|
||||||
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
|
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
|
||||||
|
@ -339,8 +339,7 @@ public:
|
|||||||
return workspace_bytes;
|
return workspace_bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
/// Lightweight update given a subset of arguments.
|
||||||
/// to remain the same.
|
|
||||||
void update(Arguments const &args)
|
void update(Arguments const &args)
|
||||||
{
|
{
|
||||||
ptr_A_real = const_cast<void *>(args.ptr_A_real);
|
ptr_A_real = const_cast<void *>(args.ptr_A_real);
|
||||||
@ -355,6 +354,15 @@ public:
|
|||||||
ptr_D_real = const_cast<void *>(args.ptr_D_real);
|
ptr_D_real = const_cast<void *>(args.ptr_D_real);
|
||||||
ptr_D_imag = const_cast<void *>(args.ptr_D_imag);
|
ptr_D_imag = const_cast<void *>(args.ptr_D_imag);
|
||||||
|
|
||||||
|
batch_stride_A = args.batch_stride_A;
|
||||||
|
batch_stride_B = args.batch_stride_B;
|
||||||
|
batch_stride_C = args.batch_stride_C;
|
||||||
|
this->batch_stride_D = args.batch_stride_D;
|
||||||
|
batch_stride_A_imag = args.batch_stride_A_imag;
|
||||||
|
batch_stride_B_imag = args.batch_stride_B_imag;
|
||||||
|
batch_stride_C_imag = args.batch_stride_C_imag;
|
||||||
|
batch_stride_D_imag = args.batch_stride_D_imag;
|
||||||
|
|
||||||
output_op = args.epilogue;
|
output_op = args.epilogue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -304,8 +304,7 @@ public:
|
|||||||
ptr_D_imag(args.ptr_D_imag)
|
ptr_D_imag(args.ptr_D_imag)
|
||||||
{}
|
{}
|
||||||
|
|
||||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
/// Lightweight update given a subset of arguments.
|
||||||
/// to remain the same.
|
|
||||||
void update(Arguments const &args)
|
void update(Arguments const &args)
|
||||||
{
|
{
|
||||||
ptr_M = args.ptr_M;
|
ptr_M = args.ptr_M;
|
||||||
|
@ -320,8 +320,7 @@ public:
|
|||||||
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
|
ptr_scatter_D_indices(const_cast<int *>(args.ptr_scatter_D_indices))
|
||||||
{}
|
{}
|
||||||
|
|
||||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
/// Lightweight update given a subset of arguments.
|
||||||
/// to remain the same.
|
|
||||||
void update(Arguments const &args)
|
void update(Arguments const &args)
|
||||||
{
|
{
|
||||||
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
|
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
|
||||||
@ -332,6 +331,11 @@ public:
|
|||||||
ptr_C = const_cast<void *>(args.ptr_C);
|
ptr_C = const_cast<void *>(args.ptr_C);
|
||||||
ptr_D = args.ptr_D;
|
ptr_D = args.ptr_D;
|
||||||
|
|
||||||
|
batch_stride_A = args.batch_stride_A;
|
||||||
|
batch_stride_B = args.batch_stride_B;
|
||||||
|
batch_stride_C = args.batch_stride_C;
|
||||||
|
this->batch_stride_D = args.batch_stride_D;
|
||||||
|
|
||||||
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
|
ptr_gather_A_indices = const_cast<int *>(args.ptr_gather_A_indices);
|
||||||
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
|
ptr_gather_B_indices = const_cast<int *>(args.ptr_gather_B_indices);
|
||||||
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
|
ptr_scatter_D_indices = const_cast<int *>(args.ptr_scatter_D_indices);
|
||||||
|
@ -454,8 +454,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
/// Lightweight update given a subset of arguments.
|
||||||
/// to remain the same.
|
|
||||||
void update(Arguments const &args)
|
void update(Arguments const &args)
|
||||||
{
|
{
|
||||||
CUTLASS_TRACE_HOST("GemmUniversalStreamK::Params::update()");
|
CUTLASS_TRACE_HOST("GemmUniversalStreamK::Params::update()");
|
||||||
|
@ -310,8 +310,7 @@ public:
|
|||||||
CUTLASS_TRACE_HOST(" ldt: " << args.ldt);
|
CUTLASS_TRACE_HOST(" ldt: " << args.ldt);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
/// Lightweight update given a subset of arguments.
|
||||||
/// to remain the same.
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void update(Arguments const &args)
|
void update(Arguments const &args)
|
||||||
{
|
{
|
||||||
@ -325,6 +324,14 @@ public:
|
|||||||
ldr = args.ldr;
|
ldr = args.ldr;
|
||||||
ptr_Tensor = args.ptr_Tensor;
|
ptr_Tensor = args.ptr_Tensor;
|
||||||
|
|
||||||
|
batch_stride_A = args.batch_stride_A;
|
||||||
|
batch_stride_B = args.batch_stride_B;
|
||||||
|
batch_stride_C1 = args.batch_stride_C1;
|
||||||
|
batch_stride_C2 = args.batch_stride_C2;
|
||||||
|
batch_stride_Vector = args.batch_stride_Vector;
|
||||||
|
batch_stride_Tensor = args.batch_stride_Tensor;
|
||||||
|
this->batch_stride_D = args.batch_stride_D;
|
||||||
|
|
||||||
output_op = args.epilogue;
|
output_op = args.epilogue;
|
||||||
|
|
||||||
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()");
|
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()");
|
||||||
@ -1025,8 +1032,7 @@ public:
|
|||||||
CUTLASS_TRACE_HOST(" ldt: " << args.ldt);
|
CUTLASS_TRACE_HOST(" ldt: " << args.ldt);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
/// Lightweight update given a subset of arguments.
|
||||||
/// to remain the same.
|
|
||||||
CUTLASS_HOST_DEVICE
|
CUTLASS_HOST_DEVICE
|
||||||
void update(Arguments const &args)
|
void update(Arguments const &args)
|
||||||
{
|
{
|
||||||
@ -1039,6 +1045,13 @@ public:
|
|||||||
ldr = args.ldr;
|
ldr = args.ldr;
|
||||||
ptr_Tensor = args.ptr_Tensor;
|
ptr_Tensor = args.ptr_Tensor;
|
||||||
|
|
||||||
|
batch_stride_A = args.batch_stride_A;
|
||||||
|
batch_stride_B = args.batch_stride_B;
|
||||||
|
batch_stride_C = args.batch_stride_C;
|
||||||
|
batch_stride_Vector = args.batch_stride_Vector;
|
||||||
|
batch_stride_Tensor = args.batch_stride_Tensor;
|
||||||
|
this->batch_stride_D = args.batch_stride_D;
|
||||||
|
|
||||||
output_op = args.epilogue;
|
output_op = args.epilogue;
|
||||||
|
|
||||||
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()");
|
CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()");
|
||||||
|
@ -295,8 +295,7 @@ public:
|
|||||||
return workspace_bytes;
|
return workspace_bytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lightweight update given a subset of arguments. Problem geometry is assumed
|
/// Lightweight update given a subset of arguments.
|
||||||
/// to remain the same.
|
|
||||||
void update(Arguments const &args)
|
void update(Arguments const &args)
|
||||||
{
|
{
|
||||||
ptr_A = const_cast<void *>(args.ptr_A);
|
ptr_A = const_cast<void *>(args.ptr_A);
|
||||||
@ -305,6 +304,12 @@ public:
|
|||||||
ptr_D = args.ptr_D;
|
ptr_D = args.ptr_D;
|
||||||
ptr_gemm_k_reduction = args.ptr_gemm_k_reduction;
|
ptr_gemm_k_reduction = args.ptr_gemm_k_reduction;
|
||||||
|
|
||||||
|
batch_stride_A = args.batch_stride_A;
|
||||||
|
batch_stride_B = args.batch_stride_B;
|
||||||
|
batch_stride_C = args.batch_stride_C;
|
||||||
|
batch_stride_gemm_k_reduction = args.batch_stride_gemm_k_reduction;
|
||||||
|
this->batch_stride_D = args.batch_stride_D;
|
||||||
|
|
||||||
output_op = args.epilogue;
|
output_op = args.epilogue;
|
||||||
|
|
||||||
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
|
CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");
|
||||||
|
Loading…
Reference in New Issue
Block a user