Update logic for checking TPUs availability (#6767)
* Update logic for checking TPUs availability
* fix flake8
* add fix
parent: a72a7992a2
commit: 13f67ad313
pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -14,6 +14,7 @@
 import io
 import os
 import re
+import time
 from typing import Any, Dict, Iterable, List, Optional, Union
 
 import torch
@@ -23,11 +24,11 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
 from pytorch_lightning.plugins.training_type.utils import on_colab_kaggle
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn, _OMEGACONF_AVAILABLE
+from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, _TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import rank_zero_only, ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything
-from pytorch_lightning.utilities.apply_func import apply_to_collection
 
 if _TPU_AVAILABLE:
     import torch_xla.core.xla_model as xm
@@ -39,8 +40,7 @@ else:
     xm, xla_pl, xmp, ParallelLoader, rendezvous = [None] * 5
 
 if _OMEGACONF_AVAILABLE:
-    from omegaconf import OmegaConf
-    from omegaconf import DictConfig, ListConfig
+    from omegaconf import DictConfig, ListConfig, OmegaConf
 
 
 class TPUSpawnPlugin(DDPSpawnPlugin):
@@ -118,6 +118,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         self.__save_end_of_training_weights(self.lightning_module)
         self.transfer_distrib_spawn_state_on_fit_end(results)
 
+        if self.global_rank == 0:
+            time.sleep(2)
+
         self.barrier("end-process")
 
     def __save_end_of_training_weights(self, model: LightningModule) -> None:
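With this change, rank 0 now pauses for two seconds in new_process before the final barrier, giving the non-zero ranks time to finish transfer_distrib_spawn_state_on_fit_end before every process synchronises. A minimal sketch of the pattern, assuming the plugin's barrier() resolves to an XLA rendezvous; the function below is illustrative, not Lightning's API:

    import time

    import torch_xla.core.xla_model as xm


    def end_of_process_sync(global_rank: int) -> None:
        # Illustrative sketch of the synchronisation added above.
        if global_rank == 0:
            # Give the non-zero ranks a short grace period to finish
            # handing their results back before everyone meets up.
            time.sleep(2)
        # Every rank blocks here until all ranks have arrived.
        xm.rendezvous("end-process")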
pytorch_lightning/utilities/xla_device.py
@@ -17,13 +17,10 @@ import queue as q
 import traceback
 from multiprocessing import Process, Queue
 
-import torch.multiprocessing as mp
-
 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
-    import torch_xla.distributed.xla_multiprocessing as xmp
 
 #: define waiting time got checking TPU available in sec
 TPU_CHECK_TIMEOUT = 25
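The check remains wrapped in pl_multi_process, which runs it in a separate process and falls back to False if no answer arrives within TPU_CHECK_TIMEOUT seconds, so a wedged XLA runtime cannot hang Trainer start-up. A rough sketch of that kind of timeout wrapper, assuming the default 'fork' start method on Linux; this is not the actual Lightning decorator, only the general pattern:

    import functools
    from multiprocessing import Process, Queue

    TPU_CHECK_TIMEOUT = 25  # seconds to wait before giving up on the probe


    def run_in_subprocess_with_timeout(func):
        """Illustrative stand-in for a pl_multi_process-style decorator."""

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            queue: Queue = Queue()

            def target() -> None:
                try:
                    queue.put(func(*args, **kwargs))
                except Exception:
                    queue.put(False)

            proc = Process(target=target)  # relies on the 'fork' start method
            proc.start()
            proc.join(TPU_CHECK_TIMEOUT)
            if proc.is_alive():
                proc.terminate()  # the probe hung; treat it as "no TPU"
                return False
            return queue.get() if not queue.empty() else False

        return wrapper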
@@ -64,23 +61,13 @@ class XLADeviceUtils:
     @pl_multi_process
     def _is_device_tpu() -> bool:
         """
-        Check if device is TPU
+        Check if TPU devices are available
 
         Return:
-            A boolean value indicating if the xla device is a TPU device or not
+            A boolean value indicating if TPU devices are available
         """
 
-        def _fn(_: int, mp_queue):
-            try:
-                device = xm.xla_device()
-                mp_queue.put(device.type == 'xla')
-            except Exception:
-                mp_queue.put(False)
-
-        smp = mp.get_context("spawn")
-        queue = smp.SimpleQueue()
-        xmp.spawn(_fn, args=(queue, ), nprocs=1)
-        return queue.get()
+        return len(xm.get_xla_supported_devices("TPU")) > 0
 
     @staticmethod
     def xla_available() -> bool:
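The old probe spawned an XLA child process via xmp.spawn just to inspect xm.xla_device().type; the new check asks the runtime to enumerate TPU devices directly. A small usage sketch of the call it relies on; the exact device strings shown are illustrative:

    import torch_xla.core.xla_model as xm

    # Ask the XLA runtime for all devices of kind "TPU".
    # On a TPU host this is a non-empty list of XLA device strings
    # (e.g. something like ['xla:0', 'xla:1', ...]); elsewhere it is empty.
    tpu_devices = xm.get_xla_supported_devices("TPU")
    print(len(tpu_devices) > 0)  # mirrors the new _is_device_tpu() result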