[Model][Phi3-Small] Remove scipy from blocksparse_attention (#6343)
This commit is contained in:
parent
adf32e0a0f
commit
d59eb98489
@ -4,16 +4,35 @@
|
|||||||
|
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
|
||||||
try:
|
|
||||||
from scipy import sparse
|
class csr_matrix:
|
||||||
except ImportError as err:
|
"""Simple implementation of CSR matrix conversion without scipy.
|
||||||
raise ImportError("Please install scipy via "
|
This replaced scipy.sparse.csr_matrix() previously used."""
|
||||||
"`pip install scipy` to use "
|
|
||||||
"BlockSparseAttention in "
|
def __init__(self, input_array):
|
||||||
"models such as Phi-3.") from err
|
if not isinstance(input_array, np.ndarray):
|
||||||
|
raise ValueError("Input must be a NumPy array")
|
||||||
|
|
||||||
|
self.shape = input_array.shape
|
||||||
|
rows, cols = self.shape
|
||||||
|
data = []
|
||||||
|
indices = []
|
||||||
|
indptr = [0]
|
||||||
|
|
||||||
|
for i in range(rows):
|
||||||
|
for j in range(cols):
|
||||||
|
if input_array[i, j]:
|
||||||
|
data.append(input_array[i, j])
|
||||||
|
indices.append(j)
|
||||||
|
indptr.append(len(indices))
|
||||||
|
|
||||||
|
self.data = np.array(data)
|
||||||
|
self.indices = np.array(indices)
|
||||||
|
self.indptr = np.array(indptr)
|
||||||
|
|
||||||
|
|
||||||
def dense_to_crow_col(x: torch.Tensor):
|
def dense_to_crow_col(x: torch.Tensor):
|
||||||
@ -26,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
|
|||||||
assert x.dim() in (2, 3)
|
assert x.dim() in (2, 3)
|
||||||
if x.dim() == 2:
|
if x.dim() == 2:
|
||||||
x = x[None]
|
x = x[None]
|
||||||
x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
|
x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
|
||||||
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
|
crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
|
||||||
cols = [torch.from_numpy(xi.indices) for xi in x]
|
cols = [torch.from_numpy(xi.indices) for xi in x]
|
||||||
max_cols = max(len(xi) for xi in cols)
|
max_cols = max(len(xi) for xi in cols)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user