diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 7a6954ce..c45f7b28 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -116,7 +116,7 @@ class PallasAttentionBackendImpl(AttentionImpl): self.megacore_mode = None tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() - if not tpu_type.endswith("lite"): + if "lite" not in tpu_type: if self.num_kv_heads % 2 == 0: self.megacore_mode = "kv_head" else: