diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index b4a4b48..4aa8474 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -1,4 +1,5 @@ -// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h #pragma once @@ -13,53 +14,53 @@ /// some_function(...); /// }); /// ``` -#define BOOL_SWITCH(COND, CONST_NAME, ...) \ - [&] { \ - if (COND) { \ - constexpr bool CONST_NAME = true; \ - return __VA_ARGS__(); \ - } else { \ - constexpr bool CONST_NAME = false; \ - return __VA_ARGS__(); \ - } \ - }() +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() -#define FP16_SWITCH(COND, ...) \ - [&] { \ - if (COND) { \ - using elem_type = cutlass::half_t; \ - return __VA_ARGS__(); \ - } else { \ - using elem_type = cutlass::bfloat16_t; \ - return __VA_ARGS__(); \ - } \ - }() +#define FP16_SWITCH(COND, ...) \ + [&] { \ + if (COND) { \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + } else { \ + using elem_type = cutlass::bfloat16_t; \ + return __VA_ARGS__(); \ + } \ + }() -#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ - [&] { \ - if (HEADDIM <= 32) { \ - constexpr int kHeadDim = 32; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 64) { \ - constexpr int kHeadDim = 64; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 96) { \ - constexpr int kHeadDim = 96; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 128) { \ - constexpr int kHeadDim = 128; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 160) { \ - constexpr int kHeadDim = 160; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 192) { \ - constexpr int kHeadDim = 192; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 224) { \ - constexpr int kHeadDim = 224; \ - return __VA_ARGS__(); \ - } else if (HEADDIM <= 256) { \ - constexpr int kHeadDim = 256; \ - return __VA_ARGS__(); \ - } \ - }() +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }()