diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index 2e5e3b0c..dff737ff 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -244,7 +244,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm.h b/include/cutlass/gemm/device/gemm.h index e1d0092c..b398688f 100644 --- a/include/cutlass/gemm/device/gemm.h +++ b/include/cutlass/gemm/device/gemm.h @@ -482,7 +482,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -673,7 +673,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -699,7 +699,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_array.h b/include/cutlass/gemm/device/gemm_array.h index 12bc300f..be7be25d 100644 --- a/include/cutlass/gemm/device/gemm_array.h +++ b/include/cutlass/gemm/device/gemm_array.h @@ -473,7 +473,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -700,7 +700,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -726,7 +726,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_batched.h b/include/cutlass/gemm/device/gemm_batched.h index 8f09b4a7..e1093270 100644 --- a/include/cutlass/gemm/device/gemm_batched.h +++ b/include/cutlass/gemm/device/gemm_batched.h @@ -451,7 +451,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -666,7 +666,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -692,7 +692,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_complex.h b/include/cutlass/gemm/device/gemm_complex.h index 70e0b46a..4b0fcaa9 100644 --- a/include/cutlass/gemm/device/gemm_complex.h +++ b/include/cutlass/gemm/device/gemm_complex.h @@ -465,7 +465,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); @@ -674,7 +674,7 @@ public: /// Initializes GEMM state from arguments. Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { - return underlying_operator_.initialize(to_underlying_arguments(args), workspace); + return underlying_operator_.initialize(to_underlying_arguments(args), workspace, stream); } /// Lightweight update given a subset of arguments @@ -700,7 +700,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_sparse.h b/include/cutlass/gemm/device/gemm_sparse.h index 04e2dd66..b37f585a 100644 --- a/include/cutlass/gemm/device/gemm_sparse.h +++ b/include/cutlass/gemm/device/gemm_sparse.h @@ -498,7 +498,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream); diff --git a/include/cutlass/gemm/device/gemm_universal_base.h b/include/cutlass/gemm/device/gemm_universal_base.h index 74c519a4..f15b3589 100644 --- a/include/cutlass/gemm/device/gemm_universal_base.h +++ b/include/cutlass/gemm/device/gemm_universal_base.h @@ -347,7 +347,7 @@ public: return Status::kErrorWorkspaceNull; } - params_.update(args, workspace); + params_.update(args, workspace, stream); return Status::kSuccess; } diff --git a/include/cutlass/reduction/device/reduce_split_k.h b/include/cutlass/reduction/device/reduce_split_k.h index 4c044a4c..f8558643 100644 --- a/include/cutlass/reduction/device/reduce_split_k.h +++ b/include/cutlass/reduction/device/reduce_split_k.h @@ -197,7 +197,7 @@ public: void *workspace = nullptr, cudaStream_t stream = nullptr) { - Status status = initialize(args, workspace); + Status status = initialize(args, workspace, stream); if (status == Status::kSuccess) { status = run(stream);