Remove contiguous checks
This commit is contained in:
parent
3f1b4d38e7
commit
5d5bfbb619
@ -287,10 +287,6 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
TORCH_CHECK(q.is_contiguous(), "Input tensor must be contiguous");
|
||||
TORCH_CHECK(k.is_contiguous(), "Input tensor must be contiguous");
|
||||
TORCH_CHECK(v.is_contiguous(), "Input tensor must be contiguous");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = sizes[0];
|
||||
|
||||
Loading…
Reference in New Issue
Block a user