diff --git a/cutlass/util/debug.h b/cutlass/util/debug.h index 81650932..3a4b2fd0 100644 --- a/cutlass/util/debug.h +++ b/cutlass/util/debug.h @@ -44,10 +44,26 @@ namespace cutlass { * Formats and prints the given message to stdout */ #if !defined(CUDA_LOG) - #if !defined(__CUDA_ARCH__) - #define CUDA_LOG(format, ...) printf(format,__VA_ARGS__) - #else - #define CUDA_LOG(format, ...) printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, __VA_ARGS__); + #if defined(__clang__) && defined(__CUDA__) +static __device__ void cuda_log_location() { + printf("[block (%d,%d,%d), thread (%d,%d,%d)]: ", blockIdx.x, blockIdx.y, + blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z); +} +static __host__ void cuda_log_location() {} + #define CUDA_LOG(format, ...) \ + do { \ + cuda_log_location(); \ + printf(format, __VA_ARGS__); \ + } while (0) + #else // NVCC + #if !defined(__CUDA_ARCH__) + #define CUDA_LOG(format, ...) printf(format, __VA_ARGS__) + #else + #define CUDA_LOG(format, ...) \ + printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ + blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, \ + threadIdx.y, threadIdx.z, __VA_ARGS__); + #endif #endif #endif