[Bugfix] Mapping physical device indices for e2e test utils (#8290)
parent 5ec9c0fb3c
commit 40c396533d
@@ -356,12 +356,23 @@ def error_on_warning():
     yield


+def get_physical_device_indices(devices):
+    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+    if visible_devices is None:
+        return devices
+
+    visible_indices = [int(x) for x in visible_devices.split(",")]
+    index_mapping = {i: physical for i, physical in enumerate(visible_indices)}
+    return [index_mapping[i] for i in devices if i in index_mapping]
+
+
 @_nvml()
 def wait_for_gpu_memory_to_clear(devices: List[int],
                                  threshold_bytes: int,
                                  timeout_s: float = 120) -> None:
     # Use nvml instead of pytorch to reduce measurement error from torch cuda
     # context.
+    devices = get_physical_device_indices(devices)
     start_time = time.time()
     while True:
         output: Dict[int, str] = {}
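For context on the fix: torch enumerates GPUs logically (0..n-1 within whatever CUDA_VISIBLE_DEVICES exposes), while nvml always addresses GPUs by physical index, so the test utility has to translate before querying memory. A minimal standalone sketch of the same mapping logic, runnable outside the test suite:

import os

def get_physical_device_indices(devices):
    # Translate logical CUDA device indices (as seen by torch under
    # CUDA_VISIBLE_DEVICES) into the physical indices nvml expects.
    visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible_devices is None:
        # Nothing is masked; logical and physical indices coincide.
        return devices
    visible_indices = [int(x) for x in visible_devices.split(",")]
    index_mapping = {i: physical for i, physical in enumerate(visible_indices)}
    # Logical indices outside the visible set are silently dropped.
    return [index_mapping[i] for i in devices if i in index_mapping]

# With CUDA_VISIBLE_DEVICES="2,3", logical devices [0, 1] are physical GPUs [2, 3].
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3"
assert get_physical_device_indices([0, 1]) == [2, 3]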
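A hedged usage sketch, assuming the helper lives in vLLM's tests/utils.py as in this diff; the import path, device list, and threshold are illustrative, not taken from the commit:

from tests.utils import wait_for_gpu_memory_to_clear

# Poll via nvml until GPUs 0 and 1 (logical indices, remapped internally
# by get_physical_device_indices) each report less than threshold_bytes
# of used memory, subject to the 60 s timeout.
wait_for_gpu_memory_to_clear(
    devices=[0, 1],
    threshold_bytes=2 * 2**30,  # 2 GiB, illustrative
    timeout_s=60,
)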