From d59eb98489103877e9476ef5263305aa3e3f9e23 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Thu, 11 Jul 2024 22:47:17 -0400
Subject: [PATCH] [Model][Phi3-Small] Remove scipy from blocksparse_attention (#6343)

---
 .../ops/blocksparse_attention/utils.py | 35 ++++++++++++++-----
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/vllm/attention/ops/blocksparse_attention/utils.py b/vllm/attention/ops/blocksparse_attention/utils.py
index b1808970..78d75223 100644
--- a/vllm/attention/ops/blocksparse_attention/utils.py
+++ b/vllm/attention/ops/blocksparse_attention/utils.py
@@ -4,16 +4,35 @@
 
 from functools import lru_cache
 
+import numpy as np
 import torch
 import triton
 
-try:
-    from scipy import sparse
-except ImportError as err:
-    raise ImportError("Please install scipy via "
-                      "`pip install scipy` to use "
-                      "BlockSparseAttention in "
-                      "models such as Phi-3.") from err
+
+class csr_matrix:
+    """Simple implementation of CSR matrix conversion without scipy.
+    This replaced scipy.sparse.csr_matrix() previously used."""
+
+    def __init__(self, input_array):
+        if not isinstance(input_array, np.ndarray):
+            raise ValueError("Input must be a NumPy array")
+
+        self.shape = input_array.shape
+        rows, cols = self.shape
+        data = []
+        indices = []
+        indptr = [0]
+
+        for i in range(rows):
+            for j in range(cols):
+                if input_array[i, j]:
+                    data.append(input_array[i, j])
+                    indices.append(j)
+            indptr.append(len(indices))
+
+        self.data = np.array(data)
+        self.indices = np.array(indices)
+        self.indptr = np.array(indptr)
 
 
 def dense_to_crow_col(x: torch.Tensor):
@@ -26,7 +45,7 @@ def dense_to_crow_col(x: torch.Tensor):
     assert x.dim() in (2, 3)
     if x.dim() == 2:
         x = x[None]
-    x = [sparse.csr_matrix(xi.bool().cpu().numpy()) for xi in x]
+    x = [csr_matrix(xi.bool().cpu().numpy()) for xi in x]
     crows = torch.vstack([torch.from_numpy(xi.indptr) for xi in x])
     cols = [torch.from_numpy(xi.indices) for xi in x]
     max_cols = max(len(xi) for xi in cols)
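
A quick way to sanity-check the drop-in replacement: for a boolean mask, the new csr_matrix should yield the same indptr/indices arrays that scipy.sparse.csr_matrix produced before. Below is a minimal sketch, not part of the patch, assuming a vLLM build that already includes this change (and triton installed, since the patched module imports it); the scipy comparison at the end is optional and only for verification, because removing that requirement is the point of the change.

import numpy as np

# Assumes a vLLM installation that contains this patch; the module also
# imports triton at the top, so triton must be installed.
from vllm.attention.ops.blocksparse_attention.utils import csr_matrix

# A small lower-triangular boolean mask, shaped like the causal block-sparse
# masks that dense_to_crow_col() converts for Phi-3-small.
mask = np.tril(np.ones((4, 4), dtype=bool))
m = csr_matrix(mask)

# CSR layout: row i's column indices are indices[indptr[i]:indptr[i + 1]].
print(m.indptr)   # [ 0  1  3  6 10]
print(m.indices)  # [0 0 1 0 1 2 0 1 2 3]

# Optional cross-check against scipy, which is no longer a hard dependency
# after this patch but still agrees on the fields the kernel code reads.
try:
    from scipy import sparse
    ref = sparse.csr_matrix(mask)
    assert np.array_equal(m.indptr, ref.indptr)
    assert np.array_equal(m.indices, ref.indices)
except ImportError:
    pass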