set xxx_AVAILABLE as protected (#5082)

* sett xxx_AVAILABLE as protected

* docs
This commit is contained in:
Jirka Borovec 2020-12-14 15:49:05 +01:00 committed by GitHub
parent 0327f6b4c2
commit 059eaecbb4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
50 changed files with 188 additions and 189 deletions

View File

@ -9,14 +9,14 @@ import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
from tests.backends import DDPLauncher
from tests.base.boring_model import BoringModel, RandomDataset
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_device():
plugin_parity_test(
accelerator='ddp_cpu',
@ -29,7 +29,7 @@ def test_ddp_sharded_plugin_correctness_one_device():
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_gpu():
plugin_parity_test(
gpus=1,
@ -39,11 +39,11 @@ def test_ddp_sharded_plugin_correctness_one_gpu():
)
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
plugin_parity_test(
gpus=1,
@ -58,7 +58,7 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu():
plugin_parity_test(
gpus=2,
@ -69,11 +69,11 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
)
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
plugin_parity_test(
gpus=2,
@ -85,11 +85,11 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
)
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
plugin_parity_test(
gpus=2,
@ -101,7 +101,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
)
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
@ -116,7 +116,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_ddp(tmpdir, args=None):
)
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
@ -135,7 +135,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
"""
Ensures same results using multiple optimizers across multiple GPUs
@ -153,7 +153,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
"""
Ensures using multiple optimizers across multiple GPUs with manual optimization

View File

@ -31,7 +31,7 @@ Native torch
When using PyTorch 1.6+ Lightning uses the native amp implementation to support 16-bit.
.. testcode::
:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVAILABLE
:skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE
# turn on 16-bit
trainer = Trainer(precision=16)
@ -73,7 +73,7 @@ Enable 16-bit
^^^^^^^^^^^^^
.. testcode::
:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVAILABLE
:skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE
# turn on 16-bit
trainer = Trainer(amp_level='O2', precision=16)
@ -88,7 +88,7 @@ TPU 16-bit
16-bit on TPUs is much simpler. To use 16-bit with TPUs set precision to 16 when using the TPU flag
.. testcode::
:skipif: not TPU_AVAILABLE
:skipif: not _TPU_AVAILABLE
# DEFAULT
trainer = Trainer(tpu_cores=8, precision=32)

View File

@ -359,12 +359,12 @@ import os
import torch
from pytorch_lightning.utilities import (
NATIVE_AMP_AVAILABLE,
APEX_AVAILABLE,
XLA_AVAILABLE,
TPU_AVAILABLE,
_NATIVE_AMP_AVAILABLE,
_APEX_AVAILABLE,
_XLA_AVAILABLE,
_TPU_AVAILABLE,
)
TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None
"""

View File

@ -135,7 +135,7 @@ Data
Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIST.
.. testcode::
:skipif: not TORCHVISION_AVAILABLE
:skipif: not _TORCHVISION_AVAILABLE
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
@ -153,7 +153,7 @@ Lightning operates on pure dataloaders. Here's the PyTorch code for loading MNIS
.. testoutput::
:hide:
:skipif: os.path.isdir(os.path.join(os.getcwd(), 'MNIST')) or not TORCHVISION_AVAILABLE
:skipif: os.path.isdir(os.path.join(os.getcwd(), 'MNIST')) or not _TORCHVISION_AVAILABLE
Downloading ...
Extracting ...

View File

@ -1180,7 +1180,7 @@ If used on TPU will use torch.bfloat16 but tensor printing
will still show torch.float32.
.. testcode::
:skipif: not APEX_AVAILABLE and not NATIVE_AMP_AVAILABLE
:skipif: not _APEX_AVAILABLE and not _NATIVE_AMP_AVAILABLE
# default used by the Trainer
trainer = Trainer(precision=32)

View File

@ -46,7 +46,7 @@ Example: Imagenet (computer Vision)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. testcode::
:skipif: not TORCHVISION_AVAILABLE
:skipif: not _TORCHVISION_AVAILABLE
import torchvision.models as models

View File

@ -32,9 +32,9 @@ import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
from pytorch_lightning.utilities import BOLTS_AVAILABLE, FAIRSCALE_PIPE_AVAILABLE
from pytorch_lightning.utilities import _BOLTS_AVAILABLE, _FAIRSCALE_PIPE_AVAILABLE
if BOLTS_AVAILABLE:
if _BOLTS_AVAILABLE:
import pl_bolts
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
@ -196,8 +196,8 @@ if __name__ == "__main__":
parser = pl_bolts.datamodules.CIFAR10DataModule.add_argparse_args(parser)
args = parser.parse_args()
assert BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts"
assert FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example."
assert _BOLTS_AVAILABLE, "Bolts is required for this example, install it via pip install pytorch-lightning-bolts"
assert _FAIRSCALE_PIPE_AVAILABLE, "FairScale and PyTorch 1.6 is required for this example."
cifar10_dm = instantiate_datamodule(args)

View File

@ -15,17 +15,17 @@ import os
import torch
from pytorch_lightning.utilities import HOROVOD_AVAILABLE
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE
from pytorch_lightning import _logger as log
from pytorch_lightning import accelerators
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
from pytorch_lightning.utilities import device_parser, rank_zero_only, TPU_AVAILABLE
from pytorch_lightning.utilities import device_parser, rank_zero_only, _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if HOROVOD_AVAILABLE:
if _HOROVOD_AVAILABLE:
import horovod.torch as hvd
@ -348,7 +348,7 @@ class AcceleratorConnector:
rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.trainer.on_gpu}')
num_cores = self.trainer.tpu_cores if self.trainer.tpu_cores is not None else 0
rank_zero_info(f'TPU available: {TPU_AVAILABLE}, using: {num_cores} TPU cores')
rank_zero_info(f'TPU available: {_TPU_AVAILABLE}, using: {num_cores} TPU cores')
if torch.cuda.is_available() and not self.trainer.on_gpu:
rank_zero_warn('GPU available but not used. Set the --gpus flag when calling the script.')
@ -365,7 +365,7 @@ class AcceleratorConnector:
def check_horovod(self):
"""Raises a `MisconfigurationException` if the Trainer is not configured correctly for Horovod."""
if not HOROVOD_AVAILABLE:
if not _HOROVOD_AVAILABLE:
raise MisconfigurationException(
'Requested `accelerator="horovod"`, but Horovod is not installed.'
'Install with \n $HOROVOD_WITH_PYTORCH=1 pip install horovod[pytorch]'

View File

@ -26,10 +26,10 @@ from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available
if HYDRA_AVAILABLE:
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

View File

@ -30,7 +30,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
@ -40,7 +40,7 @@ from pytorch_lightning.utilities.distributed import (
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
if HYDRA_AVAILABLE:
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path
@ -94,7 +94,7 @@ class DDPAccelerator(Accelerator):
os.environ['LOCAL_RANK'] = '0'
# when user is using hydra find the absolute path
path_lib = abspath if not HYDRA_AVAILABLE else to_absolute_path
path_lib = abspath if not _HYDRA_AVAILABLE else to_absolute_path
# pull out the commands used to run the script and resolve the abs file path
command = sys.argv
@ -135,7 +135,7 @@ class DDPAccelerator(Accelerator):
# start process
# if hydra is available and initialized, make sure to set the cwd correctly
cwd: Optional[str] = None
if HYDRA_AVAILABLE:
if _HYDRA_AVAILABLE:
if HydraConfig.initialized():
cwd = get_original_cwd()
proc = subprocess.Popen(command, env=env_copy, cwd=cwd)

View File

@ -17,9 +17,9 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE
from pytorch_lightning.utilities import _HYDRA_AVAILABLE
if HYDRA_AVAILABLE:
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

View File

@ -26,7 +26,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
@ -35,7 +35,7 @@ from pytorch_lightning.utilities.distributed import (
sync_ddp_if_available,
)
if HYDRA_AVAILABLE:
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

View File

@ -26,10 +26,10 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available
if HYDRA_AVAILABLE:
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

View File

@ -27,7 +27,7 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
@ -39,7 +39,7 @@ from pytorch_lightning.utilities.distributed import (
)
from pytorch_lightning.utilities.seed import seed_everything
if HYDRA_AVAILABLE:
if _HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path

View File

@ -20,10 +20,10 @@ from torch.optim.lr_scheduler import _LRScheduler
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
if HOROVOD_AVAILABLE:
if _HOROVOD_AVAILABLE:
import horovod.torch as hvd

View File

@ -26,7 +26,7 @@ from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
TPU_AVAILABLE,
_TPU_AVAILABLE,
move_data_to_device,
rank_zero_info,
rank_zero_only,
@ -35,7 +35,7 @@ from pytorch_lightning.utilities import (
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if TPU_AVAILABLE:
if _TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
@ -63,7 +63,7 @@ class TPUAccelerator(Accelerator):
rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')
# TODO: Move this check to Trainer __init__ or device parser
if not TPU_AVAILABLE:
if not _TPU_AVAILABLE:
raise MisconfigurationException('PyTorch XLA not installed.')
# see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
@ -179,7 +179,7 @@ class TPUAccelerator(Accelerator):
See Also:
- :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
"""
if not TPU_AVAILABLE:
if not _TPU_AVAILABLE:
raise MisconfigurationException(
'Requested to transfer batch to TPU but XLA is not available.'
' Are you sure this machine has TPUs?'

View File

@ -26,7 +26,7 @@ import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, TPU_AVAILABLE
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, _TPU_AVAILABLE
class EarlyStopping(Callback):
@ -204,7 +204,7 @@ class EarlyStopping(Callback):
if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)
if trainer.use_tpu and TPU_AVAILABLE:
if trainer.use_tpu and _TPU_AVAILABLE:
current = current.cpu()
if self.monitor_op(current - self.min_delta, self.best_score):

View File

@ -37,12 +37,12 @@ from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, PRIMITIVE_TYPES, ModelIO
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, get_init_args
if TPU_AVAILABLE:
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

View File

@ -17,10 +17,10 @@ from weakref import proxy
from torch.optim.optimizer import Optimizer
from pytorch_lightning.utilities import TPU_AVAILABLE
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if TPU_AVAILABLE:
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

View File

@ -24,7 +24,7 @@ import torch
import yaml
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict, OMEGACONF_AVAILABLE
from pytorch_lightning.utilities import rank_zero_warn, AttributeDict, _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.parsing import parse_class_init_keys
@ -32,7 +32,7 @@ from pytorch_lightning.utilities.parsing import parse_class_init_keys
PRIMITIVE_TYPES = (bool, int, float, str)
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
if OMEGACONF_AVAILABLE:
if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf
# the older shall be on the top
@ -360,7 +360,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
hparams = dict(hparams)
# saving with OmegaConf objects
if OMEGACONF_AVAILABLE:
if _OMEGACONF_AVAILABLE:
if OmegaConf.is_config(hparams):
with fs.open(config_yaml, "w", encoding="utf-8") as fp:
OmegaConf.save(hparams, fp, resolve=True)

View File

@ -17,10 +17,10 @@ from typing import Any
import torch
from torch import distributed as torch_distrib
from pytorch_lightning.utilities import GROUP_AVAILABLE
from pytorch_lightning.utilities import _GROUP_AVAILABLE
WORLD = None
if GROUP_AVAILABLE:
if _GROUP_AVAILABLE:
from torch.distributed import group
WORLD = group.WORLD

View File

@ -29,10 +29,10 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, OMEGACONF_AVAILABLE
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn, _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.cloud_io import get_filesystem
if OMEGACONF_AVAILABLE:
if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf
@ -149,7 +149,7 @@ class TensorBoardLogger(LightningLoggerBase):
params = self._convert_params(params)
# store params to output
if OMEGACONF_AVAILABLE and isinstance(params, Container):
if _OMEGACONF_AVAILABLE and isinstance(params, Container):
self.hparams = OmegaConf.merge(self.hparams, params)
else:
self.hparams.update(params)

View File

@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
LightningShardedDataParallel = None
if FAIRSCALE_AVAILABLE:
if _FAIRSCALE_AVAILABLE:
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
class LightningShardedDataParallel(ShardedDataParallel):

View File

@ -19,10 +19,10 @@ from torch.optim.optimizer import Optimizer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_warn
if APEX_AVAILABLE:
if _APEX_AVAILABLE:
from apex import amp

View File

@ -23,10 +23,10 @@ from pytorch_lightning import LightningModule
from pytorch_lightning import _logger as log
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if FAIRSCALE_PIPE_AVAILABLE:
if _FAIRSCALE_PIPE_AVAILABLE:
import fairscale.nn.model_parallel as mpu
from fairscale.nn import PipeRPCWrapper
from fairscale.nn.pipe import balance as pipe_balance
@ -327,7 +327,7 @@ class DDPSequentialPlugin(RPCPlugin):
torch_distrib.barrier(group=self.data_parallel_group)
def _check_pipe_available(self):
if not FAIRSCALE_PIPE_AVAILABLE:
if not _FAIRSCALE_PIPE_AVAILABLE:
raise MisconfigurationException(
'PipeRPCPlugin requires FairScale and currently is only supported on PyTorch 1.6.'
)

View File

@ -18,9 +18,9 @@ import torch
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import RPC_AVAILABLE
from pytorch_lightning.utilities import _RPC_AVAILABLE
if RPC_AVAILABLE:
if _RPC_AVAILABLE:
from torch.distributed import rpc

View File

@ -16,9 +16,9 @@ from typing import Union, cast
from torch.optim import Optimizer
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, _NATIVE_AMP_AVAILABLE
if NATIVE_AMP_AVAILABLE and FAIRSCALE_AVAILABLE:
if _NATIVE_AMP_AVAILABLE and _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

View File

@ -17,10 +17,10 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE, AMPType, rank_zero_only
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, AMPType, rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if FAIRSCALE_AVAILABLE:
if _FAIRSCALE_AVAILABLE:
from fairscale.optim import OSS
from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel
@ -46,7 +46,7 @@ class DDPShardedPlugin(DDPPlugin):
return model.transfer_batch_to_device(args, model.trainer.root_gpu)
def _check_fairscale(self):
if not FAIRSCALE_AVAILABLE:
if not _FAIRSCALE_AVAILABLE:
raise MisconfigurationException(
'Sharded DDP Plugin requires Fairscale to be installed.'
)

View File

@ -13,8 +13,8 @@
# limitations under the License.
import os
from pathlib import Path
import re
from pathlib import Path
from typing import Union, Optional
import torch
@ -22,16 +22,16 @@ import torch
import pytorch_lightning
from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import APEX_AVAILABLE, AMPType, OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if APEX_AVAILABLE:
if _APEX_AVAILABLE:
from apex import amp
if OMEGACONF_AVAILABLE:
if _OMEGACONF_AVAILABLE:
from omegaconf import Container
@ -297,7 +297,7 @@ class CheckpointConnector:
if hasattr(model, '_hparams_name'):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
# dump arguments
if OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
else:

View File

@ -15,7 +15,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.plugins.apex import ApexPlugin
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE, AMPType, rank_zero_warn
from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, AMPType, rank_zero_warn
class PrecisionConnector:
@ -48,7 +48,7 @@ class PrecisionConnector:
amp_type = amp_type.lower()
assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
if amp_type == 'native':
if not NATIVE_AMP_AVAILABLE:
if not _NATIVE_AMP_AVAILABLE:
rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
' Consider upgrading with `pip install torch>=1.6`.'
' We will attempt to use NVIDIA Apex for this session.')
@ -59,7 +59,7 @@ class PrecisionConnector:
self.backend = NativeAMPPlugin(self.trainer)
if amp_type == 'apex':
if not APEX_AVAILABLE:
if not _APEX_AVAILABLE:
rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
else:

View File

@ -27,14 +27,14 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import Checkpoint
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import HOROVOD_AVAILABLE, TPU_AVAILABLE, argparse_utils, rank_zero_warn
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE, argparse_utils, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.model_utils import is_overridden
if TPU_AVAILABLE:
if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm
if HOROVOD_AVAILABLE:
if _HOROVOD_AVAILABLE:
import horovod.torch as hvd

View File

@ -26,7 +26,7 @@ from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.supporters import Accumulator, TensorRunningAccum
from pytorch_lightning.utilities import TPU_AVAILABLE, AMPType, parsing
from pytorch_lightning.utilities import _TPU_AVAILABLE, AMPType, parsing
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import recursive_detach
@ -496,7 +496,7 @@ class TrainLoop:
optimizer,
opt_idx,
train_step_and_backward_closure,
on_tpu=self.trainer.use_tpu and TPU_AVAILABLE,
on_tpu=self.trainer.use_tpu and _TPU_AVAILABLE,
using_native_amp=using_native_amp,
using_lbfgs=is_lbfgs,
)

View File

@ -24,7 +24,7 @@ import torch
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import AllGatherGrad, rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable
from pytorch_lightning.utilities.xla_device_utils import XLA_AVAILABLE, XLADeviceUtils
from pytorch_lightning.utilities.xla_device_utils import _XLA_AVAILABLE, XLADeviceUtils
def _module_available(module_path: str) -> bool:
@ -49,19 +49,18 @@ def _module_available(module_path: str) -> bool:
return False
APEX_AVAILABLE = _module_available("apex.amp")
NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
OMEGACONF_AVAILABLE = _module_available("omegaconf")
HYDRA_AVAILABLE = _module_available("hydra")
HOROVOD_AVAILABLE = _module_available("horovod.torch")
BOLTS_AVAILABLE = _module_available("pl_bolts")
_APEX_AVAILABLE = _module_available("apex.amp")
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
_HYDRA_AVAILABLE = _module_available("hydra")
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc')
GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
FAIRSCALE_PIPE_AVAILABLE = FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) == LooseVersion("1.6.0")
BOLTS_AVAILABLE = _module_available('pl_bolts')
_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
_FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
_RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc')
_GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(torch.__version__) >= LooseVersion("1.6.0")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps

View File

@ -14,7 +14,7 @@
import torch
from typing import Union, Any, List, Optional, MutableSequence
from pytorch_lightning.utilities import TPU_AVAILABLE
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -106,7 +106,7 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
if not _tpu_cores_valid(tpu_cores):
raise MisconfigurationException("`tpu_cores` can only be 1, 8 or [<1-8>]")
if tpu_cores is not None and not TPU_AVAILABLE:
if tpu_cores is not None and not _TPU_AVAILABLE:
raise MisconfigurationException('No TPU devices were found.')
return tpu_cores

View File

@ -19,9 +19,9 @@ from multiprocessing import Process, Queue
import torch
XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
_XLA_AVAILABLE = importlib.util.find_spec("torch_xla") is not None
if XLA_AVAILABLE:
if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
@ -65,7 +65,7 @@ class XLADeviceUtils:
Return:
Returns a str of the device hardware type. i.e TPU
"""
if XLA_AVAILABLE:
if _XLA_AVAILABLE:
return xm.xla_device_hw(device)
@staticmethod
@ -76,7 +76,7 @@ class XLADeviceUtils:
Return:
A boolean value indicating if the xla device is a TPU device or not
"""
if XLA_AVAILABLE:
if _XLA_AVAILABLE:
device = xm.xla_device()
device_type = XLADeviceUtils._fetch_xla_device_type(device)
return device_type == "TPU"
@ -89,7 +89,7 @@ class XLADeviceUtils:
Return:
A boolean value indicating if a XLA is installed
"""
return XLA_AVAILABLE
return _XLA_AVAILABLE
@staticmethod
def tpu_device_exists() -> bool:
@ -99,6 +99,6 @@ class XLADeviceUtils:
Return:
A boolean value indicating if a TPU device exists on the system
"""
if XLADeviceUtils.TPU_AVAILABLE is None and XLA_AVAILABLE:
if XLADeviceUtils.TPU_AVAILABLE is None and _XLA_AVAILABLE:
XLADeviceUtils.TPU_AVAILABLE = pl_multi_process(XLADeviceUtils._is_device_tpu)()
return XLADeviceUtils.TPU_AVAILABLE

View File

@ -26,9 +26,9 @@ sys.path = os.getenv('PYTHONPATH').split(':') + sys.path
from pytorch_lightning import Trainer # noqa: E402
from pytorch_lightning.callbacks import ModelCheckpoint # noqa: E402
from pytorch_lightning.utilities import HOROVOD_AVAILABLE # noqa: E402
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE # noqa: E402
if HOROVOD_AVAILABLE:
if _HOROVOD_AVAILABLE:
import horovod.torch as hvd # noqa: E402
else:
print('You requested to import Horovod which is missing or not supported for your OS.')

View File

@ -22,7 +22,7 @@ import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities import _APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
@ -188,7 +188,7 @@ def test_amp_without_apex(tmpdir):
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
def test_amp_with_apex(tmpdir):
"""Check calling apex scaling in training."""

View File

@ -29,12 +29,12 @@ from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.horovod_accelerator import HorovodAccelerator
from pytorch_lightning.core.step_result import EvalResult, Result, TrainResult
from pytorch_lightning.metrics.classification.accuracy import Accuracy
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVAILABLE, HOROVOD_AVAILABLE, _module_available
from pytorch_lightning.utilities import _APEX_AVAILABLE, _NATIVE_AMP_AVAILABLE, _HOROVOD_AVAILABLE, _module_available
from tests.base import EvalModelTemplate
from tests.base.boring_model import BoringModel
from tests.base.models import BasicGAN
if HOROVOD_AVAILABLE:
if _HOROVOD_AVAILABLE:
import horovod
import horovod.torch as hvd
@ -128,7 +128,7 @@ def test_horovod_multi_gpu(tmpdir):
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
def test_horovod_apex(tmpdir):
"""Test Horovod with multi-GPU support using apex amp."""
trainer_options = dict(
@ -152,7 +152,7 @@ def test_horovod_apex(tmpdir):
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.skipif(not HOROVOD_NCCL_AVAILABLE, reason="test requires Horovod with NCCL support")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires torch.cuda.amp")
def test_horovod_amp(tmpdir):
"""Test Horovod with multi-GPU support using native amp."""
trainer_options = dict(
@ -239,7 +239,7 @@ def test_horovod_multi_optimizer(enable_pl_optimizer, tmpdir):
assert get_model_params(model.discriminator) == get_optimizer_params(trainer.optimizers[1])
@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable")
@pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
def test_result_reduce_horovod(enable_pl_optimizer, tmpdir):
@ -290,7 +290,7 @@ def test_result_reduce_horovod(enable_pl_optimizer, tmpdir):
horovod.run(hvd_test_fn, np=2)
@pytest.mark.skipif(not HOROVOD_AVAILABLE, reason="Horovod is unavailable")
@pytest.mark.skipif(not _HOROVOD_AVAILABLE, reason="Horovod is unavailable")
@pytest.mark.skipif(platform.system() == "Windows", reason="Horovod is not supported on Windows")
def test_accuracy_metric_horovod():
num_batches = 10

View File

@ -22,14 +22,14 @@ import tests.base.develop_pipelines as tpipes
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.utilities import TPU_AVAILABLE
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datasets import TrialMNIST
from tests.base.develop_utils import pl_multi_process_test
if TPU_AVAILABLE:
if _TPU_AVAILABLE:
import torch_xla
import torch_xla.distributed.xla_multiprocessing as xmp
SERIAL_EXEC = xmp.MpSerialExecutor()
@ -43,7 +43,7 @@ def _serial_train_loader():
return DataLoader(_LARGER_DATASET, batch_size=32)
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_tpu_cores_1(tmpdir):
"""Make sure model trains on TPU."""
@ -61,7 +61,7 @@ def test_model_tpu_cores_1(tmpdir):
@pytest.mark.parametrize('tpu_core', [1, 5])
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_tpu_index(tmpdir, tpu_core):
"""Make sure model trains on TPU."""
@ -79,7 +79,7 @@ def test_model_tpu_index(tmpdir, tpu_core):
assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_tpu_cores_8(tmpdir):
"""Make sure model trains on TPU."""
@ -100,7 +100,7 @@ def test_model_tpu_cores_8(tmpdir):
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_16bit_tpu_cores_1(tmpdir):
"""Make sure model trains on TPU."""
@ -120,7 +120,7 @@ def test_model_16bit_tpu_cores_1(tmpdir):
@pytest.mark.parametrize('tpu_core', [1, 5])
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_16bit_tpu_index(tmpdir, tpu_core):
"""Make sure model trains on TPU."""
@ -140,7 +140,7 @@ def test_model_16bit_tpu_index(tmpdir, tpu_core):
assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_16bit_tpu_cores_8(tmpdir):
"""Make sure model trains on TPU."""
@ -162,7 +162,7 @@ def test_model_16bit_tpu_cores_8(tmpdir):
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_model_tpu_early_stop(tmpdir):
"""Test if single TPU core training works"""
@ -179,7 +179,7 @@ def test_model_tpu_early_stop(tmpdir):
trainer.fit(model)
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_tpu_grad_norm(tmpdir):
"""Test if grad_norm works on TPU."""
@ -197,7 +197,7 @@ def test_tpu_grad_norm(tmpdir):
tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_dataloaders_passed_to_fit(tmpdir):
"""Test if dataloaders passed to trainer works on TPU"""
@ -217,7 +217,7 @@ def test_dataloaders_passed_to_fit(tmpdir):
['tpu_cores', 'expected_tpu_id'],
[pytest.param(1, None), pytest.param(8, None), pytest.param([1], 1), pytest.param([8], 8)],
)
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires missing TPU")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires missing TPU")
def test_tpu_id_to_be_as_expected(tpu_cores, expected_tpu_id):
"""Test if trainer.tpu_id is set as expected"""
assert Trainer(tpu_cores=tpu_cores).tpu_id == expected_tpu_id
@ -229,7 +229,7 @@ def test_tpu_misconfiguration():
Trainer(tpu_cores=[1, 8])
@pytest.mark.skipif(TPU_AVAILABLE, reason="test requires missing TPU")
@pytest.mark.skipif(_TPU_AVAILABLE, reason="test requires missing TPU")
def test_exception_when_no_tpu_found(tmpdir):
"""Test if exception is thrown when xla devices are not available"""
@ -238,13 +238,13 @@ def test_exception_when_no_tpu_found(tmpdir):
@pytest.mark.parametrize('tpu_cores', [1, 8, [1]])
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
def test_distributed_backend_set_when_using_tpu(tmpdir, tpu_cores):
"""Test if distributed_backend is set to `tpu` when tpu_cores is not None"""
assert Trainer(tpu_cores=tpu_cores).distributed_backend == "tpu"
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_broadcast_on_tpu():
""" Checks if an object from the master process is broadcasted to other processes correctly"""
@ -274,7 +274,7 @@ def test_broadcast_on_tpu():
pytest.param(10, None, True),
],
)
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
if error_expected:
@ -291,7 +291,7 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
pytest.param("--tpu_cores=1,",
{'tpu_cores': '1,'})
])
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_tpu_cores_with_argparse(cli_args, expected):
"""Test passing tpu_cores in command line"""

View File

@ -7,11 +7,11 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.native_amp import NativeAMPPlugin
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from tests.base.boring_model import BoringModel
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
@ -45,7 +45,7 @@ def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
trainer.fit(model)
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
@ -89,7 +89,7 @@ class GradientUnscaleBoringModel(BoringModel):
assert norm.item() < 15.
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_amp_gradient_unscale(tmpdir):
model = GradientUnscaleBoringModel()
@ -118,7 +118,7 @@ class UnscaleAccumulateGradBatchesBoringModel(BoringModel):
assert norm.item() < 15.
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_amp_gradient_unscale_accumulate_grad_batches(tmpdir):
model = UnscaleAccumulateGradBatchesBoringModel()

View File

@ -1,5 +1,5 @@
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities import _APEX_AVAILABLE
from tests.base.boring_model import BoringModel
from pytorch_lightning import Trainer
import pytest
@ -8,7 +8,7 @@ from unittest import mock
from pytorch_lightning.plugins.apex import ApexPlugin
@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
@ -42,7 +42,7 @@ def test_amp_choice_default_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
trainer.fit(model)
@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@mock.patch.dict(os.environ, {
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",

View File

@ -8,7 +8,7 @@ from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel
@ -104,7 +104,7 @@ def test_ddp_choice_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed sharded plugin is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_choice_string_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
@ -169,7 +169,7 @@ def test_ddp_invalid_choice_string_ddp_cpu(tmpdir, ddp_backend, gpus, num_proces
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed sharded plugin is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_invalid_choice_string_and_custom_ddp_cpu(tmpdir, ddp_backend, gpus, num_processes):
"""
Test passing a lightning custom ddp plugin and a default ddp plugin throws an error.

View File

@ -21,7 +21,7 @@ from torch import nn
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.plugins.ddp_sequential_plugin import DDPSequentialPlugin
from pytorch_lightning.utilities import FAIRSCALE_PIPE_AVAILABLE
from pytorch_lightning.utilities import _FAIRSCALE_PIPE_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import RandomDataset
@ -33,7 +33,7 @@ def cleanup(ctx, model):
del model
@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
@ -60,7 +60,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None):
trainer.accelerator_backend.ddp_plugin.exit_rpc_process()
@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
@ -87,7 +87,7 @@ def test_ddp_sequential_plugin_ddp_rpc_manual_amp(tmpdir, args=None):
assert str(e) == 'DDPSequentialPlugin is currently not supported in Automatic Mixed Precision'
@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
@ -115,7 +115,7 @@ def test_ddp_sequential_plugin_ddp_rpc_automatic(tmpdir, args=None):
trainer.accelerator_backend.ddp_plugin.exit_rpc_process()
@pytest.mark.skipif(not FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@pytest.mark.skipif(not _FAIRSCALE_PIPE_AVAILABLE, reason="test requires FairScale to be installed")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',

View File

@ -8,7 +8,7 @@ import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import RPC_AVAILABLE
from pytorch_lightning.utilities import _RPC_AVAILABLE
from tests.base.boring_model import BoringModel
@ -28,7 +28,7 @@ from tests.base.boring_model import BoringModel
["ddp_backend", "gpus", "num_processes"],
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(not RPC_AVAILABLE, reason="RPC is not available")
@pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available")
def test_rpc_choice(tmpdir, ddp_backend, gpus, num_processes):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
@ -88,7 +88,7 @@ class CustomRPCPlugin(RPCPlugin):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not RPC_AVAILABLE, reason="RPC is not available")
@pytest.mark.skipif(not _RPC_AVAILABLE, reason="RPC is not available")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_rpc_function_calls_ddp(tmpdir):

View File

@ -8,8 +8,8 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE, APEX_AVAILABLE
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, _APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel
@ -30,7 +30,7 @@ from tests.base.boring_model import BoringModel
["ddp_backend", "gpus", "num_processes"],
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes):
"""
Test to ensure that plugin is correctly chosen
@ -55,8 +55,8 @@ def test_ddp_choice_sharded(tmpdir, ddp_backend, gpus, num_processes):
trainer.fit(model)
@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_invalid_apex_sharded(tmpdir):
"""
Test to ensure that we raise an error when we try to use apex and sharded
@ -91,8 +91,8 @@ def test_invalid_apex_sharded(tmpdir):
["ddp_backend", "gpus", "num_processes"],
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
def test_ddp_choice_sharded_amp(tmpdir, ddp_backend, gpus, num_processes):
"""
Test to ensure that plugin native amp plugin is correctly chosen when using sharded
@ -121,7 +121,7 @@ def test_ddp_choice_sharded_amp(tmpdir, ddp_backend, gpus, num_processes):
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir):
"""
Test to ensure that checkpoint is saved correctly
@ -147,7 +147,7 @@ def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir):
"""
Test to ensure that checkpoint is saved correctly when using multiple GPUs
@ -174,7 +174,7 @@ def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_finetune(tmpdir):
"""
Test to ensure that we can save and restart training (simulate fine-tuning)
@ -200,7 +200,7 @@ def test_ddp_sharded_plugin_finetune(tmpdir):
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir):
"""
Test to ensure that resuming from checkpoint works
@ -234,7 +234,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint(tmpdir):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir):
"""
Test to ensure that resuming from checkpoint works when downsizing number of GPUS
@ -268,7 +268,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_downsize_gpus(tmpdir):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir):
"""
Test to ensure that resuming from checkpoint works when going from GPUs- > CPU
@ -300,7 +300,7 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir):
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_test(tmpdir):
"""
Test to ensure we can use test without fit
@ -318,7 +318,7 @@ def test_ddp_sharded_plugin_test(tmpdir):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
@pytest.mark.skipif(not _FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_test_multigpu(tmpdir):
"""
Test to ensure we can use test without fit

View File

@ -22,7 +22,7 @@ import torch.distributed as torch_distrib
import torch.nn.functional as F
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities import _APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel
@ -301,7 +301,7 @@ def test_multiple_optimizers_manual_native_amp(tmpdir):
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(not APEX_AVAILABLE, reason="test requires apex")
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
def test_multiple_optimizers_manual_apex(tmpdir):
"""
Tests that only training_step can be used

View File

@ -16,7 +16,7 @@ import torch
import pytest
from tests.base.boring_model import BoringModel, RandomDataset
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import APEX_AVAILABLE
from pytorch_lightning.utilities import _APEX_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException

View File

@ -33,7 +33,7 @@ from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import BoringModel, EvalModelTemplate
@ -968,7 +968,7 @@ def test_gradient_clipping(tmpdir):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
def test_gradient_clipping_fp16(tmpdir):
"""
Test gradient clipping with fp16

View File

@ -19,7 +19,7 @@ from torch.utils.data import RandomSampler, SequentialSampler, DataLoader
import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities import AMPType, NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities import AMPType, _NATIVE_AMP_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.base.datamodules import MNISTDataModule
@ -328,7 +328,7 @@ def test_error_on_dataloader_passed_to_fit(tmpdir):
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
@pytest.mark.skipif(not _NATIVE_AMP_AVAILABLE, reason="test requires native AMP.")
def test_auto_scale_batch_size_with_amp(tmpdir):
model = EvalModelTemplate()
batch_size_before = model.batch_size

View File

@ -16,21 +16,21 @@ import time
import pytest
import pytorch_lightning.utilities.xla_device_utils as xla_utils
from pytorch_lightning.utilities import XLA_AVAILABLE, TPU_AVAILABLE
from pytorch_lightning.utilities import _XLA_AVAILABLE, _TPU_AVAILABLE
from tests.base.develop_utils import pl_multi_process_test
if XLA_AVAILABLE:
if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
# lets hope that in or env we have installed XLA only for TPU devices, otherwise, it is testing in the cycle "if I am true test that I am true :D"
@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
"""Check tpu_device_exists returns None when torch_xla is not available"""
assert xla_utils.XLADeviceUtils.tpu_device_exists() is None
@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires torch_xla to be installed")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires torch_xla to be installed")
@pl_multi_process_test
def test_tpu_device_presence():
"""Check tpu_device_exists returns True when TPU is available"""