Fix CUDA_PERROR_EXIT and print failing expression (#446)

`CUDA_PERROR_EXIT ` can lead to incorrect usage (see e.g. [this description](https://www.cs.technion.ac.il/users/yechiel/c++-faq/macros-with-if.html)) because it contains an incomplete `if` expression. Consider:

```
if (condition)
    CUDA_PERROR_EXIT(cudaFree(x))
else
    free(x);
```

The author of the code forgot to add a semicolon after the macro. In that case, the `else` will bind to the `if` inside the macro definition, leading to code that the author did not intend or expect. It the author does use a semicolon, the code will not compile, which is awkward.

The change adds a `do while` around the `if`, which always requires a semicolon.

This PR also adds the text of the failing expression to the printed error message.
This commit is contained in:
Andrei Alexandrescu 2022-04-24 16:29:43 -04:00 committed by GitHub
parent 310ed81ac3
commit d7b499deff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -80,6 +80,7 @@
* \return The CUDA error.
*/
__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error,
const char* expression,
const char* filename,
int line) {
(void)filename;
@ -87,10 +88,10 @@ __host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error,
if (error) {
#if !defined(__CUDA_ARCH__)
fprintf(
stderr, "CUDA error %d [%s, %d]: %s\n", error, filename, line, cudaGetErrorString(error));
stderr, "CUDA error %d [%s, %d] in expression '%s': %s\n", error, filename, line, expression, cudaGetErrorString(error));
fflush(stderr);
#else
printf("CUDA error %d [%s, %d]\n", error, filename, line);
printf("CUDA error %d [%s, %d] in expression '%s'\n", error, filename, line, expression);
#endif
}
return error;
@ -100,7 +101,7 @@ __host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error,
* \brief Perror macro
*/
#ifndef CUDA_PERROR
#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), __FILE__, __LINE__)
#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)
#endif
/**
@ -108,9 +109,9 @@ __host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error,
*/
#ifndef CUDA_PERROR_EXIT
#define CUDA_PERROR_EXIT(e) \
if (cuda_perror_impl((cudaError_t)(e), __FILE__, __LINE__)) { \
do { if (cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)) { \
exit(1); \
}
} } while (0)
#endif
/**