[Bugfix] Remove xformers requirement for Pixtral (#9597)
Signed-off-by: mgoin <michael@neuralmagic.com>
Parent: 59449095ab
Commit: c91ed47c43
@@ -14,8 +14,6 @@ from transformers.models.pixtral.image_processing_pixtral import (
     _num_image_tokens)
 from transformers.models.pixtral.modeling_pixtral import (
     PixtralRotaryEmbedding, apply_rotary_pos_emb, position_ids_in_meshgrid)
-from xformers.ops.fmha import memory_efficient_attention
-from xformers.ops.fmha.attn_bias import BlockDiagonalMask
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
@@ -38,6 +36,12 @@ from vllm.utils import is_list_of
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import init_vllm_registered_model
 
+try:
+    from xformers import ops as xops
+    USE_XFORMERS_OPS = True
+except ImportError:
+    USE_XFORMERS_OPS = False
+
 
 def get_max_pixtral_image_tokens(ctx: InputContext):
     tokenizer = cached_get_tokenizer(
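Aside: the guarded import added in this hunk is the standard optional-dependency pattern, probing for xformers once at import time and recording the result in a module-level flag. A minimal self-contained sketch of the same idea; the attention_backend helper is illustrative, not part of the diff:

    # Sketch of the optional-import guard; attention_backend() is a
    # hypothetical helper, not code from this commit.
    try:
        from xformers import ops as xops  # memory-efficient attention kernels
        USE_XFORMERS_OPS = True
    except ImportError:
        xops = None
        USE_XFORMERS_OPS = False


    def attention_backend() -> str:
        # Report which attention path the model code would take.
        return "xformers" if USE_XFORMERS_OPS else "torch-sdpa"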
@@ -416,7 +420,7 @@ class Attention(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        mask: BlockDiagonalMask,
+        mask: torch.Tensor,
         freqs_cis: torch.Tensor,
     ) -> torch.Tensor:
         batch, patches, _ = x.shape
@@ -427,7 +431,7 @@ class Attention(nn.Module):
         v = v.reshape(batch, patches, self.n_heads, self.head_dim)
 
         q, k = apply_rotary_emb_vit(q, k, freqs_cis=freqs_cis)
-        out = memory_efficient_attention(q, k, v, attn_bias=mask)
+        out = xops.memory_efficient_attention(q, k, v, attn_bias=mask)
         out = out.reshape(batch, patches, self.n_heads * self.head_dim)
         return self.wo(out)
 
@@ -444,7 +448,7 @@ class TransformerBlock(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        mask: BlockDiagonalMask,
+        mask: torch.Tensor,
         freqs_cis: torch.Tensor,
     ) -> torch.Tensor:
         r = self.attention.forward(self.attention_norm(x),
@@ -467,7 +471,7 @@ class Transformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        mask: BlockDiagonalMask,
+        mask: torch.Tensor,
         freqs_cis: Optional[torch.Tensor],
     ) -> torch.Tensor:
         for layer in self.layers:
@@ -562,8 +566,12 @@ class VisionTransformer(nn.Module):
         freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
 
         # pass through Transformer with a block diagonal mask delimiting images
-        mask = BlockDiagonalMask.from_seqlens(
-            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
+        if USE_XFORMERS_OPS:
+            mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
+                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
+        else:
+            raise ImportError("Xformers is required for Pixtral inference "
+                              "with the Mistral format")
         out = self.transformer(patch_embeds, mask=mask, freqs_cis=freqs_cis)
 
         # remove batch dimension of the single sequence
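For readers unfamiliar with it, BlockDiagonalMask.from_seqlens builds a block-diagonal attention bias so that patches belonging to different images cannot attend to one another. A rough dense stand-in, written in plain torch purely for illustration (not how xformers represents the mask internally):

    import torch

    # Illustrative dense equivalent of BlockDiagonalMask.from_seqlens(seqlens):
    # 0 inside each image's block of patches, -inf everywhere else.
    def dense_block_diagonal_bias(seqlens: list[int]) -> torch.Tensor:
        total = sum(seqlens)
        bias = torch.full((total, total), float("-inf"))
        start = 0
        for n in seqlens:
            bias[start:start + n, start:start + n] = 0.0
            start += n
        return bias

    # Example: two images contributing 4 and 6 patches.
    print(dense_block_diagonal_bias([4, 6]).shape)  # torch.Size([10, 10])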
@@ -828,7 +836,7 @@ class PixtralHFAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: BlockDiagonalMask,
+        attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         batch, patches, _ = hidden_states.size()
@@ -843,12 +851,23 @@ class PixtralHFAttention(nn.Module):
         cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=0)
 
-        # Transpose q and k back for attention
-        q = q.transpose(1, 2).contiguous()
-        k = k.transpose(1, 2).contiguous()
-        v = v.reshape(batch, patches, self.n_heads, self.head_dim)
+        if USE_XFORMERS_OPS:
+            # Transpose q and k back for attention
+            q = q.transpose(1, 2).contiguous()
+            k = k.transpose(1, 2).contiguous()
+            v = v.reshape(batch, patches, self.n_heads, self.head_dim)
+
+            out = xops.memory_efficient_attention(q,
+                                                  k,
+                                                  v,
+                                                  attn_bias=attention_mask)
+        else:
+            v = v.reshape(batch, patches, self.n_heads,
+                          self.head_dim).transpose(1, 2)
+            out = nn.functional.scaled_dot_product_attention(
+                q, k, v, attn_mask=attention_mask)
+            out = out.transpose(1, 2)
 
-        out = memory_efficient_attention(q, k, v, attn_bias=attention_mask)
         out = out.reshape(batch, patches, self.n_heads * self.head_dim)
 
         return self.o_proj(out)
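The two branches above differ mainly in tensor layout: xformers' memory_efficient_attention consumes [batch, seq, heads, head_dim], while torch's scaled_dot_product_attention consumes [batch, heads, seq, head_dim] and therefore needs the extra transpose on the way out. A small layout sketch under that assumption (torch only; the xformers call is left as a comment):

    import torch
    import torch.nn.functional as F

    batch, heads, seq, dim = 2, 4, 16, 32
    q = torch.randn(batch, heads, seq, dim)
    k = torch.randn(batch, heads, seq, dim)
    v = torch.randn(batch, heads, seq, dim)

    # SDPA path: [B, H, S, D] in, [B, H, S, D] out, then transpose back.
    out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)
    # xformers path (requires xformers; not run here):
    #   out = xops.memory_efficient_attention(
    #       q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    print(out.reshape(batch, seq, heads * dim).shape)  # torch.Size([2, 16, 128])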
@@ -877,7 +896,7 @@ class PixtralHFTransformerBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        attention_mask: BlockDiagonalMask,
+        attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         r = self.attention.forward(self.attention_norm(hidden_states),
@@ -916,7 +935,7 @@ class PixtralHFTransformer(nn.Module):
     def forward(
         self,
         x: torch.Tensor,
-        attention_mask: BlockDiagonalMask,
+        attention_mask: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
         for layer in self.layers:
@@ -1000,11 +1019,19 @@ class PixtralHFVisionModel(nn.Module):
             patch_embeds_list,
             max_width=self.config.image_size // self.config.patch_size).to(
                 self.device)
 
         position_embedding = self.patch_positional_embedding(
             patch_embeds, position_ids)
-        attention_mask = BlockDiagonalMask.from_seqlens(
-            [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
+
+        if USE_XFORMERS_OPS:
+            attention_mask = xops.fmha.attn_bias.BlockDiagonalMask.from_seqlens(
+                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], )
+        else:
+            from transformers.models.pixtral.modeling_pixtral import (
+                generate_block_attention_mask)
+            attention_mask = generate_block_attention_mask(
+                [p.shape[-2] * p.shape[-1] for p in patch_embeds_list],
+                patch_embeds)
 
         out = self.transformer(patch_embeds, attention_mask,
                                position_embedding)
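As a usage note, the fallback branch produces a plain additive tensor mask, which scaled_dot_product_attention accepts directly as attn_mask. A self-contained sketch of that path; the hand-rolled mask below is illustrative only and is not the transformers generate_block_attention_mask implementation:

    import torch
    import torch.nn.functional as F

    # Additive block mask: 0 within an image's patches, a large negative
    # value across images, then fed straight to SDPA.
    seqlens = [4, 6]  # patches per image
    total = sum(seqlens)
    mask = torch.full((total, total), torch.finfo(torch.float32).min)
    start = 0
    for n in seqlens:
        mask[start:start + n, start:start + n] = 0.0
        start += n

    heads, dim = 2, 8
    q = torch.randn(1, heads, total, dim)
    k = torch.randn(1, heads, total, dim)
    v = torch.randn(1, heads, total, dim)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    print(out.shape)  # torch.Size([1, 2, 10, 8])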