[TPU][Bugfix] Fix tpu type api (#8035)
This commit is contained in:
parent
058344f89a
commit
2684efc467
@ -124,7 +124,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
self.megacore_mode = None
|
||||
tpu_env = torch_xla.tpu.get_tpu_env()
|
||||
tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
|
||||
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
|
||||
or tpu_env.get("TYPE", None)
|
||||
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
|
||||
assert tpu_type is not None
|
||||
tpu_type = tpu_type.lower()
|
||||
|
||||
if "lite" not in tpu_type:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user