From bb01f2915eb3ade94b086033d7f2a6fe7de3c067 Mon Sep 17 00:00:00 2001
From: Michael Goin
Date: Wed, 23 Oct 2024 22:03:44 -0400
Subject: [PATCH] [Bugfix][Model] Fix Mllama SDPA illegal memory access for batched multi-image (#9626)

Signed-off-by: mgoin
---
 vllm/model_executor/models/mllama.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 23e2b520..475364f3 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -795,17 +795,19 @@ class MllamaTextCrossAttention(nn.Module):
         kv_len = k.shape[0]
         q = q.transpose(0, 1).view(self.num_local_key_value_heads,
                                    self.num_key_value_groups, q_len,
-                                   self.head_dim)
+                                   self.head_dim).contiguous()
         k = k.transpose(0, 1)[:,
                               None, :, :].expand(self.num_local_key_value_heads,
                                                  self.num_key_value_groups,
-                                                 kv_len, self.head_dim)
+                                                 kv_len,
+                                                 self.head_dim).contiguous()
         v = v.transpose(0, 1)[:,
                               None, :, :].expand(self.num_local_key_value_heads,
                                                  self.num_key_value_groups,
-                                                 kv_len, self.head_dim)
+                                                 kv_len,
+                                                 self.head_dim).contiguous()
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
         output = F.scaled_dot_product_attention(q,
                                                 k,