[TPU] Update pallas.py to support trillium (#8871)
This commit is contained in:
parent
6d792d2f31
commit
8df2dc3c88
@ -130,7 +130,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
assert tpu_type is not None
|
assert tpu_type is not None
|
||||||
tpu_type = tpu_type.lower()
|
tpu_type = tpu_type.lower()
|
||||||
|
|
||||||
if "lite" not in tpu_type:
|
if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
|
||||||
if self.num_kv_heads % 2 == 0:
|
if self.num_kv_heads % 2 == 0:
|
||||||
self.megacore_mode = "kv_head"
|
self.megacore_mode = "kv_head"
|
||||||
else:
|
else:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user