[TPU][Bugfix] Fix tpu type api (#8035)
This commit is contained in:
parent
058344f89a
commit
2684efc467
@ -124,7 +124,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
|||||||
|
|
||||||
self.megacore_mode = None
|
self.megacore_mode = None
|
||||||
tpu_env = torch_xla.tpu.get_tpu_env()
|
tpu_env = torch_xla.tpu.get_tpu_env()
|
||||||
tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
|
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
|
||||||
|
or tpu_env.get("TYPE", None)
|
||||||
|
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
|
||||||
|
assert tpu_type is not None
|
||||||
tpu_type = tpu_type.lower()
|
tpu_type = tpu_type.lower()
|
||||||
|
|
||||||
if "lite" not in tpu_type:
|
if "lite" not in tpu_type:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user