diff --git a/include/cutlass/layout/permute.h b/include/cutlass/layout/permute.h index 9642ebc5..c7b01530 100644 --- a/include/cutlass/layout/permute.h +++ b/include/cutlass/layout/permute.h @@ -119,9 +119,6 @@ private: Index stride_permute_; - Index col_permute_; - Index row_permute_; - public: // // Methods @@ -143,7 +140,7 @@ public: /// Computes the address offset after Permute Op in Bytes CUTLASS_HOST_DEVICE LongIndex operator()(MatrixCoord offset_init) { - // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X + // Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i, j, k, l], the dimension of X // is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3]. assert(extent_.row() % D1 == 0); assert(extent_.column() % D2 == 0); @@ -159,10 +156,10 @@ public: int i = row_init / D1; // After the Permute Op - col_permute_ = l + j * D3; - row_permute_ = k + i * D2; + Index col_permute = l + j * D3; + Index row_permute = k + i * D2; - return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_); + return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); } /// Return D1 @@ -198,9 +195,6 @@ private: Index stride_permute_; - Index col_permute_; - Index row_permute_; - public: // // Methods @@ -240,10 +234,10 @@ public: int i = BMM_batch_idx / D1; // After the Permute Op - col_permute_ = l + j * D3; - row_permute_ = k + i * D2; + Index col_permute = l + j * D3; + Index row_permute = k + i * D2; - return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_); + return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); } /// Return D1 @@ -273,9 +267,6 @@ private: Index stride_permute_; - Index col_permute_; - Index row_permute_; - public: // // Methods @@ -313,10 +304,10 @@ public: int i = row_init / T1; // After the Permute Op - col_permute_ = m + j * T4 + l * T1 * T4; - row_permute_ = i + k * T0; + Index col_permute = m + j * T4 + l * T1 * T4; + Index row_permute = i + k * T0; - return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_); + return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute); } };