standarize fp8 generator (#1078)

This commit is contained in:
ANIKET SHIVAM 2023-09-07 11:36:33 -07:00 committed by GitHub
parent 88c0d7c726
commit 34bbadd3ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4614,14 +4614,14 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version):
# some schedules disabled to save on library size # some schedules disabled to save on library size
if CudaToolkitVersionSatisfies(cuda_version, 12, 1): if CudaToolkitVersionSatisfies(cuda_version, 12, 1):
schedules = [ schedules = [
[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], #[KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto],
[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
# [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized], [KernelScheduleType.TmaWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedPingpongFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized], [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized],
# [KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized] [KernelScheduleType.TmaWarpSpecializedFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]
] ]
stream_k_schedules = [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized], stream_k_schedules = [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized],
[KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]] [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.NoSmemWarpSpecialized]]
else: else:
schedules = [ schedules = [