[Model][Phi3-Small] Remove scipy from blocksparse_attention (#6343)

This commit is contained in:
Michael Goin 2024-07-11 22:47:17 -04:00 committed by GitHub
parent adf32e0a0f
commit d59eb98489
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -4,16 +4,35 @@
from functools import lru_cache
import numpy as np
import torch
import triton
try:
from scipy import sparse
except ImportError as err:
raise ImportError("Please install scipy via "
"`pip install scipy` to use "
"BlockSparseAttention in "
"models such as Phi-3.") from err
class csr_matrix:
    """Simple implementation of CSR matrix conversion without scipy.

    This replaced scipy.sparse.csr_matrix() previously used.

    After construction the instance exposes:
      * ``shape``   - shape tuple of the dense input,
      * ``data``    - the truthy (non-zero) entries in row-major order,
      * ``indices`` - column index of each stored entry,
      * ``indptr``  - row pointers of length ``rows + 1``; the entries of
        row ``i`` live in ``indices[indptr[i]:indptr[i + 1]]``.

    Args:
        input_array: 2-D NumPy array to convert.

    Raises:
        ValueError: if ``input_array`` is not a NumPy array or is not 2-D.
    """

    def __init__(self, input_array):
        if not isinstance(input_array, np.ndarray):
            raise ValueError("Input must be a NumPy array")
        if input_array.ndim != 2:
            # Fail with a clear message instead of a tuple-unpacking error.
            raise ValueError("Input must be a 2-D NumPy array")
        self.shape = input_array.shape
        num_rows = self.shape[0]

        # np.nonzero scans in row-major (C) order, i.e. the same order a
        # nested "for row / for col" loop would visit the truthy entries.
        nz_rows, nz_cols = np.nonzero(input_array)

        # Fancy indexing keeps the input dtype even when nothing is selected,
        # so an all-zero matrix no longer degrades to a float64 empty array.
        self.data = input_array[nz_rows, nz_cols]

        # np.int_ is the dtype NumPy picks for a list of Python ints; pinning
        # it keeps index arrays integer-typed in the empty case as well.
        self.indices = nz_cols.astype(np.int_)

        # Row pointers are the running total of per-row entry counts.
        per_row = np.bincount(nz_rows, minlength=num_rows)
        indptr = np.zeros(num_rows + 1, dtype=np.int_)
        indptr[1:] = np.cumsum(per_row)
        self.indptr = indptr
def dense_to_crow_col(x: torch.Tensor):
@@ -26,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
assert x.dim() in (2, 3)
if x.dim() == 2:
x = x[None]
x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
cols = [torch.from_numpy(xi.indices) for xi in x]
max_cols = max(len(xi) for xi in cols)