change unused class member to local var (#646)
This commit is contained in:
parent
48a9ea223a
commit
cd37e82492
@ -119,9 +119,6 @@ private:
|
||||
|
||||
Index stride_permute_;
|
||||
|
||||
Index col_permute_;
|
||||
Index row_permute_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
@ -143,7 +140,7 @@ public:
|
||||
/// Computes the address offset after Permute Op in Bytes
|
||||
CUTLASS_HOST_DEVICE
|
||||
LongIndex operator()(MatrixCoord offset_init) {
|
||||
// Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i,j,k,l], the dimension of X
|
||||
// Permute as torch.permute(X1, [0, 2, 1, 3]) -> 4D Tensor indices as [i, j, k, l], the dimension of X
|
||||
// is [D0, D1, D2, D3], after permutation the dim of X1 is [D0, D2, D1, D3].
|
||||
assert(extent_.row() % D1 == 0);
|
||||
assert(extent_.column() % D2 == 0);
|
||||
@ -159,10 +156,10 @@ public:
|
||||
int i = row_init / D1;
|
||||
|
||||
// After the Permute Op
|
||||
col_permute_ = l + j * D3;
|
||||
row_permute_ = k + i * D2;
|
||||
Index col_permute = l + j * D3;
|
||||
Index row_permute = k + i * D2;
|
||||
|
||||
return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_);
|
||||
return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute);
|
||||
}
|
||||
|
||||
/// Return D1
|
||||
@ -198,9 +195,6 @@ private:
|
||||
|
||||
Index stride_permute_;
|
||||
|
||||
Index col_permute_;
|
||||
Index row_permute_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
@ -240,10 +234,10 @@ public:
|
||||
int i = BMM_batch_idx / D1;
|
||||
|
||||
// After the Permute Op
|
||||
col_permute_ = l + j * D3;
|
||||
row_permute_ = k + i * D2;
|
||||
Index col_permute = l + j * D3;
|
||||
Index row_permute = k + i * D2;
|
||||
|
||||
return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_);
|
||||
return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute);
|
||||
}
|
||||
|
||||
/// Return D1
|
||||
@ -273,9 +267,6 @@ private:
|
||||
|
||||
Index stride_permute_;
|
||||
|
||||
Index col_permute_;
|
||||
Index row_permute_;
|
||||
|
||||
public:
|
||||
//
|
||||
// Methods
|
||||
@ -313,10 +304,10 @@ public:
|
||||
int i = row_init / T1;
|
||||
|
||||
// After the Permute Op
|
||||
col_permute_ = m + j * T4 + l * T1 * T4;
|
||||
row_permute_ = i + k * T0;
|
||||
Index col_permute = m + j * T4 + l * T1 * T4;
|
||||
Index row_permute = i + k * T0;
|
||||
|
||||
return LongIndex(row_permute_) * LongIndex(stride_permute_) + LongIndex(col_permute_);
|
||||
return LongIndex(row_permute) * LongIndex(stride_permute_) + LongIndex(col_permute);
|
||||
}
|
||||
};
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user