Use lru_cache for some environment detection utils (#3508)
This commit is contained in:
parent
63e8b28a99
commit
20478c4d3a
@ -11,7 +11,7 @@ from packaging.version import parse, Version
|
|||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import asyncio
|
import asyncio
|
||||||
from functools import partial
|
from functools import partial, lru_cache
|
||||||
from typing import (
|
from typing import (
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
@ -120,6 +120,7 @@ def is_hip() -> bool:
|
|||||||
return torch.version.hip is not None
|
return torch.version.hip is not None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
def is_neuron() -> bool:
|
def is_neuron() -> bool:
|
||||||
try:
|
try:
|
||||||
import transformers_neuronx
|
import transformers_neuronx
|
||||||
@ -128,6 +129,7 @@ def is_neuron() -> bool:
|
|||||||
return transformers_neuronx is not None
|
return transformers_neuronx is not None
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
||||||
"""Returns the maximum shared memory per thread block in bytes."""
|
"""Returns the maximum shared memory per thread block in bytes."""
|
||||||
# NOTE: This import statement should be executed lazily since
|
# NOTE: This import statement should be executed lazily since
|
||||||
@ -151,6 +153,7 @@ def random_uuid() -> str:
|
|||||||
return str(uuid.uuid4().hex)
|
return str(uuid.uuid4().hex)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
def in_wsl() -> bool:
|
def in_wsl() -> bool:
|
||||||
# Reference: https://github.com/microsoft/WSL/issues/4071
|
# Reference: https://github.com/microsoft/WSL/issues/4071
|
||||||
return "microsoft" in " ".join(uname()).lower()
|
return "microsoft" in " ".join(uname()).lower()
|
||||||
@ -225,6 +228,7 @@ def set_cuda_visible_devices(device_ids: List[int]) -> None:
|
|||||||
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
|
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=None)
|
||||||
def get_nvcc_cuda_version() -> Optional[Version]:
|
def get_nvcc_cuda_version() -> Optional[Version]:
|
||||||
cuda_home = os.environ.get('CUDA_HOME')
|
cuda_home = os.environ.get('CUDA_HOME')
|
||||||
if not cuda_home:
|
if not cuda_home:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user