[Bugfix] Add verbose error if scipy is missing for blocksparse attention (#5695)
This commit is contained in:
parent
f1e15da6fe
commit
e58294ddf2
@ -6,7 +6,14 @@ from functools import lru_cache
|
||||
|
||||
import torch
|
||||
import triton
|
||||
from scipy import sparse
|
||||
|
||||
try:
|
||||
from scipy import sparse
|
||||
except ImportError as err:
|
||||
raise ImportError("Please install scipy via "
|
||||
"`pip install scipy` to use "
|
||||
"BlockSparseAttention in "
|
||||
"models such as Phi-3.") from err
|
||||
|
||||
|
||||
def dense_to_crow_col(x: torch.Tensor):
|
||||
@ -77,11 +84,11 @@ def _get_sparse_attn_mask_homo_head(
|
||||
):
|
||||
"""
|
||||
:return: a tuple of 3:
|
||||
- tuple of crow_indices, col_indices representation
|
||||
- tuple of crow_indices, col_indices representation
|
||||
of CSR format.
|
||||
- block dense mask
|
||||
- all token dense mask (be aware that it can be
|
||||
OOM if it is too big) if `return_dense==True`,
|
||||
- all token dense mask (be aware that it can be
|
||||
OOM if it is too big) if `return_dense==True`,
|
||||
otherwise, None
|
||||
"""
|
||||
with torch.no_grad():
|
||||
@ -148,10 +155,10 @@ def get_sparse_attn_mask(
|
||||
:param dense_mask_type: "binary" (0 for skip token, 1 for others)
|
||||
or "bias" (-inf for skip token, 0 or others)
|
||||
:return: a tuple of 3:
|
||||
- tuple of crow_indices, col_indices representation
|
||||
- tuple of crow_indices, col_indices representation
|
||||
of CSR format.
|
||||
- block dense mask
|
||||
- all token dense mask (be aware that it can be OOM if it
|
||||
- all token dense mask (be aware that it can be OOM if it
|
||||
is too big) if `return_dense==True`, otherwise, None
|
||||
"""
|
||||
assert dense_mask_type in ("binary", "bias")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user