set xxx_AVAILABLE as protected (#5082)
* set xxx_AVAILABLE as protected * docs
parent 0327f6b4c2
commit 059eaecbb4
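
The change is mechanical: the module-availability flags exposed by `pytorch_lightning.utilities` (FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE, TPU_AVAILABLE, ...) are renamed with a leading underscore to mark them as protected, and every import, `skipif` decorator and guard across the code base, docs and tests is switched to the underscored name. A minimal sketch of the pattern — the helper body below is simplified for illustration and is not the library's exact implementation; the flag names and the guard idiom are taken from this diff:

```python
import importlib.util


def _module_available(module_path: str) -> bool:
    # Simplified sketch: the real helper in pytorch_lightning.utilities walks the
    # dotted path; here we only probe whether the top-level package can be found.
    try:
        return importlib.util.find_spec(module_path.split(".")[0]) is not None
    except ModuleNotFoundError:
        return False


# After this commit the availability flags carry a leading underscore ("protected"):
_APEX_AVAILABLE = _module_available("apex.amp")
_OMEGACONF_AVAILABLE = _module_available("omegaconf")

# Call sites guard optional imports on the protected flag:
if _OMEGACONF_AVAILABLE:
    from omegaconf import OmegaConf
```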
@@ -9,14 +9,14 @@ import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
-from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
from tests.backends import DDPLauncher
from tests.base.boring_model import BoringModel, RandomDataset


@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_device():
plugin_parity_test(
accelerator='ddp_cpu',
@@ -29,7 +29,7 @@ def test_ddp_sharded_plugin_correctness_one_device():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_gpu():
plugin_parity_test(
gpus=1,
@@ -39,11 +39,11 @@ def test_ddp_sharded_plugin_correctness_one_gpu():
)


-@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
plugin_parity_test(
gpus=1,
@@ -58,7 +58,7 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu():
plugin_parity_test(
gpus=2,
@@ -69,11 +69,11 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
)


-@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
plugin_parity_test(
gpus=2,
@@ -85,11 +85,11 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
)


-@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
plugin_parity_test(
gpus=2,
@@ -101,7 +101,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
)


-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
@@ -116,7 +116,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
)


-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
@@ -135,7 +135,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
"""
Ensures same results using multiple optimizers across multiple GPUs
@@ -153,7 +153,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
"""
Ensures using multiple optimizers across multiple GPUs with manual optimization

@@ -31,7 +31,7 @@ Native torch
When using PyTorch 1.6+ Lightning uses the native amp implementation to support 16-bit.

.. testcode::
-:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVAILABLE
+:skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE

# turn on 16-bit
trainer = Trainer(precision=16)
@@ -73,7 +73,7 @@ Enable 16-bit
^^^^^^^^^^^^^

.. testcode::
-:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVAILABLE
+:skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE

# turn on 16-bit
trainer = Trainer(amp_level='O2', precision=16)
@@ -88,7 +88,7 @@ TPU 16-bit
16-bit on TPUs is much simpler. To use 16-bit with TPUs set precision to 16 when using the TPU flag

.. testcode::
-:skipif: not TPU_AVAILABLE
+:skipif: not _TPU_AVAILABLE

# DEFAULT
trainer = Trainer(tpu_cores=8, precision=32)

@@ -359,12 +359,12 @@ import os
import torch

from pytorch_lightning.utilities import (
-NATIVE_AMP_AVAILABLE,
-APEX_AVAILABLE,
-XLA_AVAILABLE,
-TPU_AVAILABLE,
+_NATIVE_AMP_AVAILABLE,
+_APEX_AVAILABLE,
+_XLA_AVAILABLE,
+_TPU_AVAILABLE,
)
-TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
+_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None


"""

@@ -135,7 +135,7 @@ Data
Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIST.

.. testcode::
-:skipif: not TORCHVISION_AVAILABLE
+:skipif: not _TORCHVISION_AVAILABLE

from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
@@ -153,7 +153,7 @@ Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIS

.. testoutput::
:hide:
-:skipif: os.path.isdir(os.path.join(os.getcwd(), 'MNIST')) or not TORCHVISION_AVAILABLE
+:skipif: os.path.isdir(os.path.join(os.getcwd(), 'MNIST')) or not _TORCHVISION_AVAILABLE

Downloading ...
Extracting ...

@@ -1180,7 +1180,7 @@ If used on TPU will use torch.bfloat16 but tensor printing
will still show torch.float32.

.. testcode::
-:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVAILABLE
+:skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE

# default used by the Trainer
trainer = Trainer(precision=32)

@@ -46,7 +46,7 @@ Example: Imagenet (computer Vision)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. testcode::
-:skipif: not TORCHVISION_AVAILABLE
+:skipif: not _TORCHVISION_AVAILABLE

import torchvision.models as models


@@ -32,9 +32,9 @@ import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
-from pytorch_lightning.utilities import BOLTS_AVAILABLE, FAIRSCALE_PIPE_AVAILABLE
+from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE

-if BOLTS_AVAILABLE:
+if _BOLTS_AVAILABLE:
import pl_bolts
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

@@ -196,8 +196,8 @@ if __name__ == "__main__":
parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser)
args = parser.parse_args()

-assert BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts"
-assert FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example."
+assert _BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts"
+assert _FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example."

cifar10_dm = instantiate_datamodule(args)

@@ -15,17 +15,17 @@ import os

import torch

-from pytorch_lightning.utilities import HOROVOD_AVAILABLE
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning import _logger as log
from pytorch_lightning import accelerators
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
-from pytorch_lightning.utilities import device_parser, rank_zero_only, TPU_AVAILABLE
+from pytorch_lightning.utilities import device_parser, rank_zero_only, _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if HOROVOD_AVAILABLE:
+if _HOROVOD_AVAILABLE:
import horovod.torch as hvd


@@ -348,7 +348,7 @@ class AcceleratorConnector:

rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}')
num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
-rank_zero_info(f'TPU available: {TPU_AVAILABLE}, using: {num_cores} TPU cores')
+rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')

if torch.cuda.is_available() and not self.trainer.on_gpu:
rank_zero_warn('GPU available but not used. Set the --gpus flag when calling the script.')
@@ -365,7 +365,7 @@ class AcceleratorConnector:

def check_horovod(self):
"""Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
-if not HOROVOD_AVAILABLE:
+if not _HOROVOD_AVAILABLE:
raise MisconfigurationException(
'Requested `accelerator="horovod"`, but Horovod is not installed.'
'Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]'

@@ -26,10 +26,10 @@ from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
-from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

-if HYDRA_AVAILABLE:
+if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

@@ -30,7 +30,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
-from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
@@ -40,7 +40,7 @@ from pytorch_lightning.utilities.distributed import (
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything

-if HYDRA_AVAILABLE:
+if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

@@ -94,7 +94,7 @@ class DDPAccelerator(Accelerator):
os.environ['LOCAL_RANK'] = '0'

# when user is using hydra find the absolute path
-path_lib = abspath if not HYDRA_AVAILABLE else to_absolute_path
+path_lib = abspath if not _HYDRA_AVAILABLE else to_absolute_path

# pull out the commands used to run the script and resolve the abs file path
command = sys.argv
@@ -135,7 +135,7 @@ class DDPAccelerator(Accelerator):
# start process
# if hydra is available and initialized, make sure to set the cwd correctly
cwd: Optional[str] = None
-if HYDRA_AVAILABLE:
+if _HYDRA_AVAILABLE:
if HydraConfig.initialized():
cwd = get_original_cwd()
proc = subprocess.Popen(command, env=env_copy, cwd=cwd)

@@ -17,9 +17,9 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
-from pytorch_lightning.utilities import HYDRA_AVAILABLE
+from pytorch_lightning.utilities import _HYDRA_AVAILABLE

-if HYDRA_AVAILABLE:
+if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

@@ -26,7 +26,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
-from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
@@ -35,7 +35,7 @@ from pytorch_lightning.utilities.distributed import (
sync_ddp_if_available,
)

-if HYDRA_AVAILABLE:
+if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

@@ -26,10 +26,10 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
-from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

-if HYDRA_AVAILABLE:
+if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

@@ -27,7 +27,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
-from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
@@ -39,7 +39,7 @@ from pytorch_lightning.utilities.distributed import (
)
from pytorch_lightning.utilities.seed import seed_everything

-if HYDRA_AVAILABLE:
+if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

@@ -20,10 +20,10 @@ from torch.optim.lr_scheduler import _LRScheduler
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
-from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only

-if HOROVOD_AVAILABLE:
+if _HOROVOD_AVAILABLE:
import horovod.torch as hvd


@@ -26,7 +26,7 @@ from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
-TPU_AVAILABLE,
+_TPU_AVAILABLE,
move_data_to_device,
rank_zero_info,
rank_zero_only,
@@ -35,7 +35,7 @@ from pytorch_lightning.utilities import (
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if TPU_AVAILABLE:
+if _TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
@@ -63,7 +63,7 @@ class TPUAccelerator(Accelerator):
rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

# TODO: Move this check to Trainer __init__ or device parser
-if not TPU_AVAILABLE:
+if not _TPU_AVAILABLE:
raise MisconfigurationException('PyTorch XLA not installed.')

# see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
@@ -179,7 +179,7 @@ class TPUAccelerator(Accelerator):
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
"""
-if not TPU_AVAILABLE:
+if not _TPU_AVAILABLE:
raise MisconfigurationException(
'Requested to transfer batch to TPU but XLA is not available.'
' Are you sure this machine has TPUs?'

@@ -26,7 +26,7 @@ import torch

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, TPU_AVAILABLE
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, _TPU_AVAILABLE


class EarlyStopping(Callback):
@@ -204,7 +204,7 @@ class EarlyStopping(Callback):
if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)

-if trainer.use_tpu and TPU_AVAILABLE:
+if trainer.use_tpu and _TPU_AVAILABLE:
current = current.cpu()

if self.monitor_op(current - self.min_delta, self.best_score):

@@ -37,12 +37,12 @@ from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
-from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn
+from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args

-if TPU_AVAILABLE:
+if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm


@@ -17,10 +17,10 @@ from weakref import proxy

from torch.optim.optimizer import Optimizer

-from pytorch_lightning.utilities import TPU_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if TPU_AVAILABLE:
+if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm


@@ -24,7 +24,7 @@ import torch
import yaml

from pytorch_lightning import _logger as log
-from pytorch_lightning.utilities import rank_zero_warn, AttributeDict, OMEGACONF_AVAILABLE
+from pytorch_lightning.utilities import rank_zero_warn, AttributeDict, _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.parsing import parse_class_init_keys
@@ -32,7 +32,7 @@ from pytorch_lightning.utilities.parsing import parse_class_init_keys
PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)

-if OMEGACONF_AVAILABLE:
+if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf

# the older shall be on the top
@@ -360,7 +360,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
hparams = dict(hparams)

# saving with OmegaConf objects
-if OMEGACONF_AVAILABLE:
+if _OMEGACONF_AVAILABLE:
if OmegaConf.is_config(hparams):
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
OmegaConf.save(hparams, fp, resolve=True)

@@ -17,10 +17,10 @@ from typing import Any
import torch
from torch import distributed as torch_distrib

-from pytorch_lightning.utilities import GROUP_AVAILABLE
+from pytorch_lightning.utilities import _GROUP_AVAILABLE

WORLD = None
-if GROUP_AVAILABLE:
+if _GROUP_AVAILABLE:
from torch.distributed import group
WORLD = group.WORLD

@@ -29,10 +29,10 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
-from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, OMEGACONF_AVAILABLE
+from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.cloud_io import get_filesystem

-if OMEGACONF_AVAILABLE:
+if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf


@@ -149,7 +149,7 @@ class TensorBoardLogger(LightningLoggerBase):
params = self._convert_params(params)

# store params to output
-if OMEGACONF_AVAILABLE and isinstance(params, Container):
+if _OMEGACONF_AVAILABLE and isinstance(params, Container):
self.hparams = OmegaConf.merge(self.hparams, params)
else:
self.hparams.update(params)

@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE

LightningShardedDataParallel = None
-if FAIRSCALE_AVAILABLE:
+if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel

class LightningShardedDataParallel(ShardedDataParallel):

@@ -19,10 +19,10 @@ from torch.optim.optimizer import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin
-from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType
+from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_warn

-if APEX_AVAILABLE:
+if _APEX_AVAILABLE:
from apex import amp


@@ -23,10 +23,10 @@ from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
-from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
+from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if FAIRSCALE_PIPE_AVAILABLE:
+if _FAIRSCALE_PIPE_AVAILABLE:
import fairscale.nn.model_parallel as mpu
from fairscale.nn import PipeRPCWrapper
from fairscale.nn.pipe import balance as pipe_balance
@@ -327,7 +327,7 @@ class DDPSequentialPlugin(RPCPlugin):
torch_distrib.barrier(group=self.data_parallel_group)

def _check_pipe_available(self):
-if not FAIRSCALE_PIPE_AVAILABLE:
+if not _FAIRSCALE_PIPE_AVAILABLE:
raise MisconfigurationException(
'PipeRPCPlugin requires FairScale and currently is only supported on PyTorch 1.6.'
)

@@ -18,9 +18,9 @@ import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
-from pytorch_lightning.utilities import RPC_AVAILABLE
+from pytorch_lightning.utilities import _RPC_AVAILABLE

-if RPC_AVAILABLE:
+if _RPC_AVAILABLE:
from torch.distributed import rpc


@@ -16,9 +16,9 @@ from typing import Union, cast
from torch.optim import Optimizer

from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
-from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE

-if NATIVE_AMP_AVAILABLE and FAIRSCALE_AVAILABLE:
+if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

@@ -17,10 +17,10 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
-from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, AMPType, rank_zero_only
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, AMPType, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if FAIRSCALE_AVAILABLE:
+if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS

from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
@@ -46,7 +46,7 @@ class DDPShardedPlugin(DDPPlugin):
return model.transfer_batch_to_device(args, model.trainer.root_gpu)

def _check_fairscale(self):
-if not FAIRSCALE_AVAILABLE:
+if not _FAIRSCALE_AVAILABLE:
raise MisconfigurationException(
'Sharded DDP Plugin requires Fairscale to be installed.'
)

@@ -13,8 +13,8 @@
# limitations under the License.

import os
-from pathlib import Path
import re
+from pathlib import Path
from typing import Union, Optional

import torch
@@ -22,16 +22,16 @@ import torch
import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
from pytorch_lightning.utilities.exceptions import MisconfigurationException

-if APEX_AVAILABLE:
+if _APEX_AVAILABLE:
from apex import amp

-if OMEGACONF_AVAILABLE:
+if _OMEGACONF_AVAILABLE:
from omegaconf import Container


@@ -297,7 +297,7 @@ class CheckpointConnector:
if hasattr(model, '_hparams_name'):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
# dump arguments
-if OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
+if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
else:

@@ -15,7 +15,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.apex import ApexPlugin
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
-from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE, AMPType, rank_zero_warn
+from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, AMPType, rank_zero_warn


class PrecisionConnector:
@@ -48,7 +48,7 @@ class PrecisionConnector:
amp_type = amp_type.lower()
assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
if amp_type == 'native':
-if not NATIVE_AMP_AVAILABLE:
+if not _NATIVE_AMP_AVAILABLE:
rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
' Consider upgrading with `pip install torch>=1.6`.'
' We will attempt to use NVIDIA Apex for this session.')
@@ -59,7 +59,7 @@ class PrecisionConnector:
self.backend = NativeAMPPlugin(self.trainer)

if amp_type == 'apex':
-if not APEX_AVAILABLE:
+if not _APEX_AVAILABLE:
rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
else:

@@ -27,14 +27,14 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import Checkpoint
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import HOROVOD_AVAILABLE, TPU_AVAILABLE, argparse_utils, rank_zero_warn
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, argparse_utils, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_utils import is_overridden

-if TPU_AVAILABLE:
+if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

-if HOROVOD_AVAILABLE:
+if _HOROVOD_AVAILABLE:
import horovod.torch as hvd


@@ -26,7 +26,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum
-from pytorch_lightning.utilities import TPU_AVAILABLE, AMPType, parsing
+from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, parsing
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
@@ -496,7 +496,7 @@ class TrainLoop:
optimizer,
opt_idx,
train_step_and_backward_closure,
-on_tpu=self.trainer.use_tpu and TPU_AVAILABLE,
+on_tpu=self.trainer.use_tpu and _TPU_AVAILABLE,
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs,
)

@@ -24,7 +24,7 @@ import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
-from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils
+from pytorch_lightning.utilities.xla_device_utils import _XLA_AVAILABLE, XLADeviceUtils


def _module_available(module_path: str) -> bool:
@@ -49,19 +49,18 @@ def _module_available(module_path: str) -> bool:
return False


-APEX_AVAILABLE = _module_available("apex.amp")
-NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
-OMEGACONF_AVAILABLE = _module_available("omegaconf")
-HYDRA_AVAILABLE = _module_available("hydra")
-HOROVOD_AVAILABLE = _module_available("horovod.torch")
-BOLTS_AVAILABLE = _module_available("pl_bolts")
+_APEX_AVAILABLE = _module_available("apex.amp")
+_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
+_OMEGACONF_AVAILABLE = _module_available("omegaconf")
+_HYDRA_AVAILABLE = _module_available("hydra")
+_HOROVOD_AVAILABLE = _module_available("horovod.torch")

-TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
-FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
-RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc')
-GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
-FAIRSCALE_PIPE_AVAILABLE = FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) == LooseVersion("1.6.0")
-BOLTS_AVAILABLE = _module_available('pl_bolts')
+_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
+_FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
+_RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc')
+_GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
+_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
+_BOLTS_AVAILABLE = _module_available('pl_bolts')

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps

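The hunk above is the core of the change: every availability constant in `pytorch_lightning.utilities` gains the underscore prefix, the duplicated `BOLTS_AVAILABLE` definition collapses into a single `_BOLTS_AVAILABLE`, and the FairScale pipe check is relaxed from `== 1.6.0` to `>= 1.6.0`. Downstream modules only need to import the new name; a hypothetical guarded test, mirroring the `skipif` decorators used throughout this diff:

```python
import pytest

# post-rename, protected flag (name taken from this diff)
from pytorch_lightning.utilities import _APEX_AVAILABLE


@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
def test_something_with_apex():  # hypothetical test, for illustration only
    ...
```
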
@@ -14,7 +14,7 @@
import torch
from typing import Union, Any, List, Optional, MutableSequence

-from pytorch_lightning.utilities import TPU_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -106,7 +106,7 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
if not _tpu_cores_valid(tpu_cores):
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")

-if tpu_cores is not None and not TPU_AVAILABLE:
+if tpu_cores is not None and not _TPU_AVAILABLE:
raise MisconfigurationException('No TPU devices were found.')

return tpu_cores

@@ -19,9 +19,9 @@ from multiprocessing import Process, Queue

import torch

-XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
+_XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None

-if XLA_AVAILABLE:
+if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm


@@ -65,7 +65,7 @@ class XLADeviceUtils:
Return:
Returns a str of the device hardware type. i.e TPU
"""
-if XLA_AVAILABLE:
+if _XLA_AVAILABLE:
return xm.xla_device_hw(device)

@staticmethod
@@ -76,7 +76,7 @@ class XLADeviceUtils:
Return:
A boolean value indicating if the xla device is a TPU device or not
"""
-if XLA_AVAILABLE:
+if _XLA_AVAILABLE:
device = xm.xla_device()
device_type = XLADeviceUtils._fetch_xla_device_type(device)
return device_type == "TPU"
@@ -89,7 +89,7 @@ class XLADeviceUtils:
Return:
A boolean value indicating if a XLA is installed
"""
-return XLA_AVAILABLE
+return _XLA_AVAILABLE

@staticmethod
def tpu_device_exists() -> bool:
@@ -99,6 +99,6 @@ class XLADeviceUtils:
Return:
A boolean value indicating if a TPU device exists on the system
"""
-if XLADeviceUtils.TPU_AVAILABLE is None and XLA_AVAILABLE:
+if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
return XLADeviceUtils.TPU_AVAILABLE

@@ -26,9 +26,9 @@ sys.path = os.getenv('PYTHONPATH').split(':') + sys.path

from pytorch_lightning import Trainer # noqa: E402
from pytorch_lightning.callbacks import ModelCheckpoint # noqa: E402
-from pytorch_lightning.utilities import HOROVOD_AVAILABLE # noqa: E402
+from pytorch_lightning.utilities import _HOROVOD_AVAILABLE # noqa: E402

-if HOROVOD_AVAILABLE:
+if _HOROVOD_AVAILABLE:
import horovod.torch as hvd # noqa: E402
else:
print('You requested to import Horovod which is missing or not supported for your OS.')

@@ -22,7 +22,7 @@ import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities import APEX_AVAILABLE
+from pytorch_lightning.utilities import _APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate

@@ -188,7 +188,7 @@ def test_amp_without_apex(tmpdir):

@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
+@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
def test_amp_with_apex(tmpdir):
"""Check calling apex scaling in training."""


@@ -29,12 +29,12 @@ from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator
from pytorch_lightning.core.step_result import EvalResult, Result, TrainResult
from pytorch_lightning.metrics.classification.accuracy import Accuracy
-from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE, HOROVOD_AVAILABLE, _module_available
+from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, _HOROVOD_AVAILABLE, _module_available
from tests.base import EvalModelTemplate
from tests.base.boring_model import BoringModel
from tests.base.models import BasicGAN

-if HOROVOD_AVAILABLE:
+if _HOROVOD_AVAILABLE:
import horovod
import horovod.torch as hvd

@@ -128,7 +128,7 @@ def test_horovod_multi_gpu(tmpdir):
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
+@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
def test_horovod_apex(tmpdir):
"""Test Horovod with multi-GPU support using apex amp."""
trainer_options = dict(
@@ -152,7 +152,7 @@ def test_horovod_apex(tmpdir):
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp")
def test_horovod_amp(tmpdir):
"""Test Horovod with multi-GPU support using native amp."""
trainer_options = dict(
@@ -239,7 +239,7 @@ def test_horovod_multi_optimizer(enable_pl_optimizer, tmpdir):
assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1])


-@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable")
+@pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
def test_result_reduce_horovod(enable_pl_optimizer, tmpdir):
@@ -290,7 +290,7 @@ def test_result_reduce_horovod(enable_pl_optimizer, tmpdir):
horovod.run(hvd_test_fn, np=2)


-@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable")
+@pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_accuracy_metric_horovod():
num_batches = 10

@@ -22,14 +22,14 @@ import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
-from pytorch_lightning.utilities import TPU_AVAILABLE
+from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datasets import TrialMNIST
from tests.base.develop_utils import pl_multi_process_test


-if TPU_AVAILABLE:
+if _TPU_AVAILABLE:
import torch_xla
import torch_xla.distributed.xla_multiprocessing as xmp
SERIAL_EXEC = xmp.MpSerialExecutor()
@@ -43,7 +43,7 @@ def _serial_train_loader():
return DataLoader(_LARGER_DATASET, batch_size=32)


-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_tpu_cores_1(tmpdir):
"""Make sure model trains on TPU."""
@@ -61,7 +61,7 @@ def test_model_tpu_cores_1(tmpdir):


@pytest.mark.parametrize('tpu_core', [1, 5])
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_tpu_index(tmpdir, tpu_core):
"""Make sure model trains on TPU."""
@@ -79,7 +79,7 @@ def test_model_tpu_index(tmpdir, tpu_core):
assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'


-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_tpu_cores_8(tmpdir):
"""Make sure model trains on TPU."""
@@ -100,7 +100,7 @@ def test_model_tpu_cores_8(tmpdir):
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_16bit_tpu_cores_1(tmpdir):
"""Make sure model trains on TPU."""
@@ -120,7 +120,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):


@pytest.mark.parametrize('tpu_core', [1, 5])
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_16bit_tpu_index(tmpdir, tpu_core):
"""Make sure model trains on TPU."""
@@ -140,7 +140,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"


-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_16bit_tpu_cores_8(tmpdir):
"""Make sure model trains on TPU."""
@@ -162,7 +162,7 @@ def test_model_16bit_tpu_cores_8(tmpdir):
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_tpu_early_stop(tmpdir):
"""Test if single TPU core training works"""
@@ -179,7 +179,7 @@ def test_model_tpu_early_stop(tmpdir):
trainer.fit(model)


-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_tpu_grad_norm(tmpdir):
"""Test if grad_norm works on TPU."""
@@ -197,7 +197,7 @@ def test_tpu_grad_norm(tmpdir):
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)


-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_dataloaders_passed_to_fit(tmpdir):
"""Test if dataloaders passed to trainer works on TPU"""
@@ -217,7 +217,7 @@ def test_dataloaders_passed_to_fit(tmpdir):
['tpu_cores', 'expected_tpu_id'],
[pytest.param(1, None), pytest.param(8, None), pytest.param([1], 1), pytest.param([8], 8)],
)
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires missing TPU")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires missing TPU")
def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id):
"""Test if trainer.tpu_id is set as expected"""
assert Trainer(tpu_cores=tpu_cores).tpu_id == expected_tpu_id
@@ -229,7 +229,7 @@ def test_tpu_misconfiguration():
Trainer(tpu_cores=[1, 8])


-@pytest.mark.skipif(TPU_AVAILABLE, reason="test requires missing TPU")
+@pytest.mark.skipif(_TPU_AVAILABLE, reason="test requires missing TPU")
def test_exception_when_no_tpu_found(tmpdir):
"""Test if exception is thrown when xla devices are not available"""

@@ -238,13 +238,13 @@ def test_exception_when_no_tpu_found(tmpdir):


@pytest.mark.parametrize('tpu_cores', [1, 8, [1]])
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):
"""Test if distributed_backend is set to `tpu` when tpu_cores is not None"""
assert Trainer(tpu_cores=tpu_cores).distributed_backend == "tpu"


-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_broadcast_on_tpu():
""" Checks if an object from the master process is broadcasted to other processes correctly"""
@@ -274,7 +274,7 @@ def test_broadcast_on_tpu():
pytest.param(10, None, True),
],
)
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
if error_expected:
@@ -291,7 +291,7 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
pytest.param("--tpu_cores=1,",
{'tpu_cores': '1,'})
])
-@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_tpu_cores_with_argparse(cli_args, expected):
"""Test passing tpu_cores in command line"""

@@ -7,11 +7,11 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
-from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
+from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from tests.base.boring_model import BoringModel


-@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
@@ -45,7 +45,7 @@ def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
trainer.fit(model)


-@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
@@ -89,7 +89,7 @@ class GradientUnscaleBoringModel(BoringModel):
assert norm.item() < 15.


-@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_amp_gradient_unscale(tmpdir):
model = GradientUnscaleBoringModel()
@@ -118,7 +118,7 @@ class UnscaleAccumulateGradBatchesBoringModel(BoringModel):
assert norm.item() < 15.


-@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
+@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir):
model = UnscaleAccumulateGradBatchesBoringModel()

@@ -1,5 +1,5 @@
from pytorch_lightning.callbacks import Callback
-from pytorch_lightning.utilities import APEX_AVAILABLE
+from pytorch_lightning.utilities import _APEX_AVAILABLE
from tests.base.boring_model import BoringModel
from pytorch_lightning import Trainer
import pytest
@@ -8,7 +8,7 @@ from unittest import mock
from pytorch_lightning.plugins.apex import ApexPlugin


-@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
+@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
@@ -42,7 +42,7 @@ def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
trainer.fit(model)


-@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
+@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",

@@ -8,7 +8,7 @@ from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
-from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE
+from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel

@@ -104,7 +104,7 @@ def test_ddp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed sharded plugin is not supported on Windows")
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_choice_string_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
@@ -169,7 +169,7 @@ def test_ddp_invalid_choice_string_ddp_cpu(tmpdir, ddp_backend, gpus, num_proces
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed sharded plugin is not supported on Windows")
-@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_invalid_choice_string_and_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
"""
Test passing a lightning custom ddp plugin and a default ddp plugin throws an error.


@@ -21,7 +21,7 @@ from torch import nn

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import RandomDataset

@@ -33,7 +33,7 @@ def cleanup(ctx, model):
    del model


@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
@@ -60,7 +60,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):
    trainer.accelerator_backend.ddp_plugin.exit_rpc_process()


@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
@@ -87,7 +87,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None):
        assert str(e) == 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision'


@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
@@ -115,7 +115,7 @@ def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None):
    trainer.accelerator_backend.ddp_plugin.exit_rpc_process()


@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
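The sequential-plugin tests above also gate themselves on an environment variable in addition to the FairScale check. A hedged sketch of that second gate (only the PL_RUNNING_SPECIAL_TESTS variable and the reason string are taken from this diff; the test body is a placeholder):

import os

import pytest


@pytest.mark.skipif(os.getenv("PL_RUNNING_SPECIAL_TESTS", "0") != "1",
                    reason="test should be run outside of pytest")
def test_special_only():
    # Runs only when a dedicated job exports PL_RUNNING_SPECIAL_TESTS=1.
    assert True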

@@ -8,7 +8,7 @@ import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import RPC_AVAILABLE
from pytorch_lightning.utilities import _RPC_AVAILABLE
from tests.base.boring_model import BoringModel


@@ -28,7 +28,7 @@ from tests.base.boring_model import BoringModel
    ["ddp_backend", "gpus", "num_processes"],
    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(not RPC_AVAILABLE, reason="RPC is not available")
@pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available")
def test_rpc_choice(tmpdir, ddp_backend, gpus, num_processes):
    class CB(Callback):
        def on_fit_start(self, trainer, pl_module):
@@ -88,7 +88,7 @@ class CustomRPCPlugin(RPCPlugin):

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not RPC_AVAILABLE, reason="RPC is not available")
@pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
                    reason="test should be run outside of pytest")
def test_rpc_function_calls_ddp(tmpdir):

@@ -8,8 +8,8 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE, APEX_AVAILABLE
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel

@@ -30,7 +30,7 @@ from tests.base.boring_model import BoringModel
    ["ddp_backend", "gpus", "num_processes"],
    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes):
    """
    Test to ensure that plugin is correctly chosen
@@ -55,8 +55,8 @@ def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes):
    trainer.fit(model)


@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_invalid_apex_sharded(tmpdir):
    """
    Test to ensure that we raise an error when we try to use apex and sharded
@@ -91,8 +91,8 @@ def test_invalid_apex_sharded(tmpdir):
    ["ddp_backend", "gpus", "num_processes"],
    [("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
def test_ddp_choice_sharded_amp(tmpdir, ddp_backend, gpus, num_processes):
    """
    Test to ensure that plugin native amp plugin is correctly chosen when using sharded
@@ -121,7 +121,7 @@ def test_ddp_choice_sharded_amp(tmpdir, ddp_backend, gpus, num_processes):

@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir):
    """
    Test to ensure that checkpoint is saved correctly
@@ -147,7 +147,7 @@ def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir):
    """
    Test to ensure that checkpoint is saved correctly when using multiple GPUs
@@ -174,7 +174,7 @@ def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_finetune(tmpdir):
    """
    Test to ensure that we can save and restart training (simulate fine-tuning)
@@ -200,7 +200,7 @@ def test_ddp_sharded_plugin_finetune(tmpdir):

@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir):
    """
    Test to ensure that resuming from checkpoint works
@@ -234,7 +234,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir):
    """
    Test to ensure that resuming from checkpoint works when downsizing number of GPUS
@@ -268,7 +268,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir):
    """
    Test to ensure that resuming from checkpoint works when going from GPUs- > CPU
@@ -300,7 +300,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir):

@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_test(tmpdir):
    """
    Test to ensure we can use test without fit
@@ -318,7 +318,7 @@ def test_ddp_sharded_plugin_test(tmpdir):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_test_multigpu(tmpdir):
    """
    Test to ensure we can use test without fit

@@ -22,7 +22,7 @@ import torch.distributed as torch_distrib
import torch.nn.functional as F

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities import _APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel

@@ -301,7 +301,7 @@ def test_multiple_optimizers_manual_native_amp(tmpdir):

@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
def test_multiple_optimizers_manual_apex(tmpdir):
    """
    Tests that only training_step can be used

@@ -16,7 +16,7 @@ import torch
import pytest
from tests.base.boring_model import BoringModel, RandomDataset
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities import _APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException


@@ -33,7 +33,7 @@ from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel, EvalModelTemplate
@@ -968,7 +968,7 @@ def test_gradient_clipping(tmpdir):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
def test_gradient_clipping_fp16(tmpdir):
    """
    Test gradient clipping with fp16

@@ -19,7 +19,7 @@ from torch.utils.data import RandomSampler, SequentialSampler, DataLoader

import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import AMPType, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datamodules import MNISTDataModule
@@ -328,7 +328,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):


@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
def test_auto_scale_batch_size_with_amp(tmpdir):
    model = EvalModelTemplate()
    batch_size_before = model.batch_size

@@ -16,21 +16,21 @@ import time
import pytest

import pytorch_lightning.utilities.xla_device_utils as xla_utils
from pytorch_lightning.utilities import XLA_AVAILABLE, TPU_AVAILABLE
from pytorch_lightning.utilities import _XLA_AVAILABLE, _TPU_AVAILABLE
from tests.base.develop_utils import pl_multi_process_test

if XLA_AVAILABLE:
if _XLA_AVAILABLE:
    import torch_xla.core.xla_model as xm


# lets hope that in or env we have installed XLA only for TPU devices, otherwise, it is testing in the cycle "if I am true test that I am true :D"
@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
    """Check tpu_device_exists returns None when torch_xla is not available"""
    assert xla_utils.XLADeviceUtils.tpu_device_exists() is None


@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires torch_xla to be installed")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires torch_xla to be installed")
@pl_multi_process_test
def test_tpu_device_presence():
    """Check tpu_device_exists returns True when TPU is available"""
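Since the old public names disappear with this rename, external code that imported them can keep working across versions with a small compatibility shim. Only constants that appear in this diff are assumed here; the fallback branch is a sketch rather than a guaranteed API:

# Prefer the new protected name, fall back to the pre-rename public one.
try:
    from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
except ImportError:  # older pytorch_lightning without the underscore prefix
    from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE as _NATIVE_AMP_AVAILABLE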