change unused class member to local var (#646)

This commit is contained in:
Wenzhuo Liu 2022-09-29 11:52:35 +08:00 committed by GitHub
parent 48a9ea223a
commit cd37e82492
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -119,9 +119,6 @@ private:
Index stride_permute_;
Index col_permute_;
Index row_permute_;
public:
//
// Methods
@ -143,7 +140,7 @@ public:
/// Computes the address offset after Permute Op in Bytes
CUTLASS_HOST_DEVICE
LongIndex operator()(MatrixCoord offset_init) {
// Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X
// Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i, j, k, l], the dimension of X
// is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3].
assert(extent_.row() % D1 == 0);
assert(extent_.column() % D2 == 0);
@ -159,10 +156,10 @@ public:
int i = row_init / D1;
// After the Permute Op
col_permute_ = l + j * D3;
row_permute_ = k + i * D2;
Index col_permute = l + j * D3;
Index row_permute = k + i * D2;
return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_);
return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute);
}
/// Return D1
@ -198,9 +195,6 @@ private:
Index stride_permute_;
Index col_permute_;
Index row_permute_;
public:
//
// Methods
@ -240,10 +234,10 @@ public:
int i = BMM_batch_idx / D1;
// After the Permute Op
col_permute_ = l + j * D3;
row_permute_ = k + i * D2;
Index col_permute = l + j * D3;
Index row_permute = k + i * D2;
return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_);
return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute);
}
/// Return D1
@ -273,9 +267,6 @@ private:
Index stride_permute_;
Index col_permute_;
Index row_permute_;
public:
//
// Methods
@ -313,10 +304,10 @@ public:
int i = row_init / T1;
// After the Permute Op
col_permute_ = m + j * T4 + l * T1 * T4;
row_permute_ = i + k * T0;
Index col_permute = m + j * T4 + l * T1 * T4;
Index row_permute = i + k * T0;
return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_);
return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute);
}
};