From 6116706c96d4106382f3f94a5fb8673311606b35 Mon Sep 17 00:00:00 2001 From: Jack Kosaian Date: Mon, 20 Mar 2023 17:07:47 -0400 Subject: [PATCH] Set batch_strides on Params::update (#883) --- .../gemm/device/gemm_universal_adapter.h | 3 +-- .../cutlass/gemm/device/gemm_universal_base.h | 3 +-- .../kernel/gemm_layernorm_mainloop_fusion.h | 12 +++++++++-- .../cutlass/gemm/kernel/gemm_planar_complex.h | 12 +++++++++-- .../gemm/kernel/gemm_planar_complex_array.h | 3 +-- include/cutlass/gemm/kernel/gemm_universal.h | 8 +++++-- .../gemm/kernel/gemm_universal_streamk.h | 3 +-- .../gemm/kernel/gemm_with_fused_epilogue.h | 21 +++++++++++++++---- .../gemm/kernel/gemm_with_k_reduction.h | 9 ++++++-- 9 files changed, 54 insertions(+), 20 deletions(-) diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 66884fb2..922fcc50 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -508,8 +508,7 @@ public: return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } - /// Lightweight update given a subset of arguments. Problem geometry is assumed to - /// remain the same. + /// Lightweight update given a subset of arguments. Status update(Arguments const &args) { return underlying_operator_.update(to_underlying_arguments(args)); diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 6fb1753e..a09afb4c 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -323,8 +323,7 @@ public: } - /// Lightweight update given a subset of arguments. Problem geometry is assumed to - /// remain the same. + /// Lightweight update given a subset of arguments. Status update(Arguments const &args) { CUTLASS_TRACE_HOST("GemmUniversalBase()::update()"); diff --git a/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h b/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h index 94e2f1dd..c2daadbf 100644 --- a/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h +++ b/include/cutlass/gemm/kernel/gemm_layernorm_mainloop_fusion.h @@ -384,8 +384,7 @@ public: ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) {} - /// Lightweight update given a subset of arguments. Problem geometry is assumed - /// to remain the same. + /// Lightweight update given a subset of arguments. void update(Arguments const &args) { ptr_A = const_cast(args.ptr_A); @@ -397,6 +396,15 @@ public: ptr_C = const_cast(args.ptr_C); ptr_D = args.ptr_D; + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_var = args.batch_stride_var; + batch_stride_mean = args.batch_stride_mean; + batch_stride_gamma = args.batch_stride_gamma; + batch_stride_beta = args.batch_stride_beta; + this->batch_stride_D = args.batch_stride_D; + ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex.h b/include/cutlass/gemm/kernel/gemm_planar_complex.h index 7dbc5923..92243a97 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex.h @@ -339,8 +339,7 @@ public: return workspace_bytes; } - /// Lightweight update given a subset of arguments. Problem geometry is assumed - /// to remain the same. + /// Lightweight update given a subset of arguments. void update(Arguments const &args) { ptr_A_real = const_cast(args.ptr_A_real); @@ -355,6 +354,15 @@ public: ptr_D_real = const_cast(args.ptr_D_real); ptr_D_imag = const_cast(args.ptr_D_imag); + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + this->batch_stride_D = args.batch_stride_D; + batch_stride_A_imag = args.batch_stride_A_imag; + batch_stride_B_imag = args.batch_stride_B_imag; + batch_stride_C_imag = args.batch_stride_C_imag; + batch_stride_D_imag = args.batch_stride_D_imag; + output_op = args.epilogue; } }; diff --git a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h index 21b80114..713946f0 100644 --- a/include/cutlass/gemm/kernel/gemm_planar_complex_array.h +++ b/include/cutlass/gemm/kernel/gemm_planar_complex_array.h @@ -304,8 +304,7 @@ public: ptr_D_imag(args.ptr_D_imag) {} - /// Lightweight update given a subset of arguments. Problem geometry is assumed - /// to remain the same. + /// Lightweight update given a subset of arguments. void update(Arguments const &args) { ptr_M = args.ptr_M; diff --git a/include/cutlass/gemm/kernel/gemm_universal.h b/include/cutlass/gemm/kernel/gemm_universal.h index fc62c01b..3dbd422d 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.h +++ b/include/cutlass/gemm/kernel/gemm_universal.h @@ -320,8 +320,7 @@ public: ptr_scatter_D_indices(const_cast(args.ptr_scatter_D_indices)) {} - /// Lightweight update given a subset of arguments. Problem geometry is assumed - /// to remain the same. + /// Lightweight update given a subset of arguments. void update(Arguments const &args) { CUTLASS_TRACE_HOST("GemmUniversal::Params::update()"); @@ -332,6 +331,11 @@ public: ptr_C = const_cast(args.ptr_C); ptr_D = args.ptr_D; + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + this->batch_stride_D = args.batch_stride_D; + ptr_gather_A_indices = const_cast(args.ptr_gather_A_indices); ptr_gather_B_indices = const_cast(args.ptr_gather_B_indices); ptr_scatter_D_indices = const_cast(args.ptr_scatter_D_indices); diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h index eaa2a594..57d4fe53 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -454,8 +454,7 @@ public: } - /// Lightweight update given a subset of arguments. Problem geometry is assumed - /// to remain the same. + /// Lightweight update given a subset of arguments. void update(Arguments const &args) { CUTLASS_TRACE_HOST("GemmUniversalStreamK::Params::update()"); diff --git a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h index 8f67bd45..f41e8130 100644 --- a/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h +++ b/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h @@ -310,8 +310,7 @@ public: CUTLASS_TRACE_HOST(" ldt: " << args.ldt); } - /// Lightweight update given a subset of arguments. Problem geometry is assumed - /// to remain the same. + /// Lightweight update given a subset of arguments. CUTLASS_HOST_DEVICE void update(Arguments const &args) { @@ -325,6 +324,14 @@ public: ldr = args.ldr; ptr_Tensor = args.ptr_Tensor; + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C1 = args.batch_stride_C1; + batch_stride_C2 = args.batch_stride_C2; + batch_stride_Vector = args.batch_stride_Vector; + batch_stride_Tensor = args.batch_stride_Tensor; + this->batch_stride_D = args.batch_stride_D; + output_op = args.epilogue; CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); @@ -1025,8 +1032,7 @@ public: CUTLASS_TRACE_HOST(" ldt: " << args.ldt); } - /// Lightweight update given a subset of arguments. Problem geometry is assumed - /// to remain the same. + /// Lightweight update given a subset of arguments. CUTLASS_HOST_DEVICE void update(Arguments const &args) { @@ -1039,6 +1045,13 @@ public: ldr = args.ldr; ptr_Tensor = args.ptr_Tensor; + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_Vector = args.batch_stride_Vector; + batch_stride_Tensor = args.batch_stride_Tensor; + this->batch_stride_D = args.batch_stride_D; + output_op = args.epilogue; CUTLASS_TRACE_HOST("GemmWithFusedEpilogue::Params::update()"); diff --git a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h index 8e00e184..c9195e3a 100644 --- a/include/cutlass/gemm/kernel/gemm_with_k_reduction.h +++ b/include/cutlass/gemm/kernel/gemm_with_k_reduction.h @@ -295,8 +295,7 @@ public: return workspace_bytes; } - /// Lightweight update given a subset of arguments. Problem geometry is assumed - /// to remain the same. + /// Lightweight update given a subset of arguments. void update(Arguments const &args) { ptr_A = const_cast(args.ptr_A); @@ -305,6 +304,12 @@ public: ptr_D = args.ptr_D; ptr_gemm_k_reduction = args.ptr_gemm_k_reduction; + batch_stride_A = args.batch_stride_A; + batch_stride_B = args.batch_stride_B; + batch_stride_C = args.batch_stride_C; + batch_stride_gemm_k_reduction = args.batch_stride_gemm_k_reduction; + this->batch_stride_D = args.batch_stride_D; + output_op = args.epilogue; CUTLASS_TRACE_HOST("GemmUniversal::Params::update()");