From 58a6d59784b9601bab631058028af7d9cd780f23 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Mon, 1 Mar 2021 13:17:09 +0100
Subject: [PATCH] simplify skip-if tests >> 0/n (#5920)

* skipif + yapf + isort

* tests

* docs

* pp
---
 pytorch_lightning/accelerators/accelerator.py | 17 ++---
 pytorch_lightning/accelerators/tpu.py         |  4 +-
 .../plugins/precision/deepspeed_precision.py  |  4 +-
 .../plugins/precision/sharded_native_amp.py   |  4 +-
 pytorch_lightning/trainer/callback_hook.py    |  5 +-
 pytorch_lightning/utilities/apply_func.py     |  3 +-
 tests/__init__.py                             | 13 ----
 tests/callbacks/test_quantization.py          | 13 ++--
 tests/core/test_results.py                    |  4 +-
 tests/deprecated_api/test_remove_1-5.py       |  7 +-
 tests/helpers/skipif.py                       | 72 +++++++++++++++++++
 tests/trainer/optimization/test_optimizers.py |  2 +-
 12 files changed, 98 insertions(+), 50 deletions(-)
 create mode 100644 tests/helpers/skipif.py

diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 9863fab79c..38fb423d22 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -29,7 +29,6 @@
 if TYPE_CHECKING:
     from pytorch_lightning.trainer.trainer import Trainer
 
-
 _STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
 
 
@@ -224,9 +223,7 @@ class Accelerator(object):
         with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
             return self.training_type_plugin.predict(*args)
 
-    def training_step_end(
-        self, output: _STEP_OUTPUT_TYPE
-    ) -> _STEP_OUTPUT_TYPE:
+    def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the training step
 
         Args:
@@ -234,9 +231,7 @@ class Accelerator(object):
         """
         return self.training_type_plugin.training_step_end(output)
 
-    def test_step_end(
-        self, output: _STEP_OUTPUT_TYPE
-    ) -> _STEP_OUTPUT_TYPE:
+    def test_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the test step
 
         Args:
@@ -244,9 +239,7 @@ class Accelerator(object):
         """
         return self.training_type_plugin.test_step_end(output)
 
-    def validation_step_end(
-        self, output: _STEP_OUTPUT_TYPE
-    ) -> _STEP_OUTPUT_TYPE:
+    def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
         """A hook to do something at the end of the validation step
 
         Args:
@@ -400,9 +393,7 @@ class Accelerator(object):
         """
         return self.training_type_plugin.broadcast(obj, src)
 
-    def all_gather(
-        self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
-    ) -> torch.Tensor:
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
         """
         Function to gather a tensor from several distributed processes.
 
diff --git a/pytorch_lightning/accelerators/tpu.py b/pytorch_lightning/accelerators/tpu.py
index 4a8467e10b..bbadd571d1 100644
--- a/pytorch_lightning/accelerators/tpu.py
+++ b/pytorch_lightning/accelerators/tpu.py
@@ -36,9 +36,7 @@ class TPUAccelerator(Accelerator):
     ) -> None:
         xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
 
-    def all_gather(
-        self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
-    ) -> torch.Tensor:
+    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
         """
         Function to gather a tensor from several distributed processes
         Args:
diff --git a/pytorch_lightning/plugins/precision/deepspeed_precision.py b/pytorch_lightning/plugins/precision/deepspeed_precision.py
index 4d36097e1a..6bcbb5ad85 100644
--- a/pytorch_lightning/plugins/precision/deepspeed_precision.py
+++ b/pytorch_lightning/plugins/precision/deepspeed_precision.py
@@ -75,9 +75,7 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
 
         return closure_loss
 
-    def clip_gradients(
-        self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
-    ) -> None:
+    def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
         """
         DeepSpeed handles clipping gradients via the training type plugin.
         """
diff --git a/pytorch_lightning/plugins/precision/sharded_native_amp.py b/pytorch_lightning/plugins/precision/sharded_native_amp.py
index dc9c4903ec..12ae5d0bc6 100644
--- a/pytorch_lightning/plugins/precision/sharded_native_amp.py
+++ b/pytorch_lightning/plugins/precision/sharded_native_amp.py
@@ -32,8 +32,6 @@ class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
         super().__init__()
         self.scaler = ShardedGradScaler()
 
-    def clip_gradients(
-        self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
-    ) -> None:
+    def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
         optimizer = cast(OSS, optimizer)
         optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py
index 60e9183ac4..8f9fc3ad93 100644
--- a/pytorch_lightning/trainer/callback_hook.py
+++ b/pytorch_lightning/trainer/callback_hook.py
@@ -15,7 +15,7 @@
 from abc import ABC
 from copy import deepcopy
 from inspect import signature
-from typing import List, Dict, Any, Type, Callable
+from typing import Any, Callable, Dict, List, Type
 
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.core.lightning import LightningModule
@@ -214,8 +214,7 @@ class TrainerCallbackHookMixin(ABC):
                 rank_zero_warn(
                     "`Callback.on_save_checkpoint` signature has changed in v1.3."
                     " A `checkpoint` parameter has been added."
-                    " Support for the old signature will be removed in v1.5",
-                    DeprecationWarning
+                    " Support for the old signature will be removed in v1.5", DeprecationWarning
                 )
                 state = callback.on_save_checkpoint(self, self.lightning_module)  # noqa: parameter-unfilled
             else:
diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py
index 27ec0a5389..0599cccec8 100644
--- a/pytorch_lightning/utilities/apply_func.py
+++ b/pytorch_lightning/utilities/apply_func.py
@@ -22,8 +22,7 @@
 import numpy as np
 import torch
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE
-from pytorch_lightning.utilities.imports import _module_available
+from pytorch_lightning.utilities.imports import _module_available, _TORCHTEXT_AVAILABLE
 
 if _TORCHTEXT_AVAILABLE:
     if _module_available("torchtext.legacy.data"):
diff --git a/tests/__init__.py b/tests/__init__.py
index a833da7cbd..57feda6280 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -14,9 +14,6 @@
 import os
 
 import numpy as np
-import torch
-
-from pytorch_lightning.utilities import _TORCH_LOWER_EQUAL_1_4, _TORCH_QUANTIZE_AVAILABLE
 
 _TEST_ROOT = os.path.dirname(__file__)
 _PROJECT_ROOT = os.path.dirname(_TEST_ROOT)
@@ -34,13 +31,3 @@
 RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
 if not os.path.isdir(_TEMP_PATH):
     os.mkdir(_TEMP_PATH)
-
-_MISS_QUANT_DEFAULT = 'fbgemm' not in torch.backends.quantized.supported_engines
-
-_SKIPIF_ARGS_PT_LE_1_4 = dict(condition=_TORCH_LOWER_EQUAL_1_4, reason="test pytorch > 1.4")
-_SKIPIF_ARGS_NO_GPU = dict(condition=not torch.cuda.is_available(), reason="test requires single-GPU machine")
-_SKIPIF_ARGS_NO_GPUS = dict(condition=torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-_SKIPIF_ARGS_NO_PT_QUANT = dict(
-    condition=not _TORCH_QUANTIZE_AVAILABLE or _MISS_QUANT_DEFAULT,
-    reason="PyTorch quantization is needed for this test"
-)
diff --git a/tests/callbacks/test_quantization.py b/tests/callbacks/test_quantization.py
index 7b51b81e1b..37fccdb00f 100644
--- a/tests/callbacks/test_quantization.py
+++ b/tests/callbacks/test_quantization.py
@@ -20,16 +20,17 @@
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.callbacks import QuantizationAwareTraining
 from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests import _SKIPIF_ARGS_NO_PT_QUANT, _SKIPIF_ARGS_PT_LE_1_4
 from tests.helpers.datamodules import RegressDataModule
 from tests.helpers.simple_models import RegressionModel
+from tests.helpers.skipif import skipif_args
 
 @pytest.mark.parametrize(
-    "observe", ['average', pytest.param('histogram', marks=pytest.mark.skipif(**_SKIPIF_ARGS_PT_LE_1_4))]
+    "observe",
+    ['average', pytest.param('histogram', marks=pytest.mark.skipif(**skipif_args(min_torch="1.5")))]
 )
 @pytest.mark.parametrize("fuse", [True, False])
-@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+@pytest.mark.skipif(**skipif_args(quant_available=True))
 def test_quantization(tmpdir, observe, fuse):
     """Parity test for quant model"""
     seed_everything(42)
@@ -64,7 +65,7 @@ def test_quantization(tmpdir, observe, fuse):
     assert torch.allclose(org_score, quant_score, atol=0.45)
 
 
-@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+@pytest.mark.skipif(**skipif_args(quant_available=True))
 def test_quantize_torchscript(tmpdir):
     """Test converting to torchscipt """
     dm = RegressDataModule()
@@ -80,7 +81,7 @@ def test_quantize_torchscript(tmpdir):
     tsmodel(tsmodel.quant(batch[0]))
 
 
-@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+@pytest.mark.skipif(**skipif_args(quant_available=True))
 def test_quantization_exceptions(tmpdir):
     """Test wrong fuse layers"""
     with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):
@@ -123,7 +124,7 @@ def custom_trigger_last(trainer):
         (custom_trigger_last, 2),
     ]
 )
-@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+@pytest.mark.skipif(**skipif_args(quant_available=True))
 def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
     """Test how many times the quant is called"""
     dm = RegressDataModule()
diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index 1793f3e7bb..5db282b6e9 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -25,8 +25,8 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.trainer.states import TrainerState
-from tests import _SKIPIF_ARGS_NO_GPU
 from tests.helpers import BoringDataModule, BoringModel
+from tests.helpers.skipif import skipif_args
 
 def _setup_ddp(rank, worldsize):
@@ -72,7 +72,7 @@ def test_result_reduce_ddp(result_cls):
         pytest.param(5, False, 0, id='nested_list_predictions'),
         pytest.param(6, False, 0, id='dict_list_predictions'),
         pytest.param(7, True, 0, id='write_dict_predictions'),
-        pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**_SKIPIF_ARGS_NO_GPU))
+        pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**skipif_args(min_gpus=1)))
     ]
 )
 def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):
diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py
index 415f1d040b..cb1d461414 100644
--- a/tests/deprecated_api/test_remove_1-5.py
+++ b/tests/deprecated_api/test_remove_1-5.py
@@ -17,7 +17,7 @@
 from unittest import mock
 
 import pytest
-from pytorch_lightning import Trainer, Callback
+from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers import WandbLogger
 from tests.helpers import BoringModel
 from tests.helpers.utils import no_warning_call
@@ -30,7 +30,9 @@ def test_v1_5_0_wandb_unused_sync_step(tmpdir):
 
 
 def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
+
     class OldSignature(Callback):
+
         def on_save_checkpoint(self, trainer, pl_module):  # noqa
             ...
 
@@ -49,14 +51,17 @@ def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
         trainer.save_checkpoint(filepath)
 
     class NewSignature(Callback):
+
         def on_save_checkpoint(self, trainer, pl_module, checkpoint):
             ...
 
     class ValidSignature1(Callback):
+
         def on_save_checkpoint(self, trainer, *args):
             ...
 
     class ValidSignature2(Callback):
+
         def on_save_checkpoint(self, *args):
             ...
 
diff --git a/tests/helpers/skipif.py b/tests/helpers/skipif.py
new file mode 100644
index 0000000000..d8f5835dd6
--- /dev/null
+++ b/tests/helpers/skipif.py
@@ -0,0 +1,72 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from distutils.version import LooseVersion
+from typing import Optional
+
+import pytest
+import torch
+from pkg_resources import get_distribution
+
+from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE
+
+
+def skipif_args(
+    min_gpus: int = 0,
+    min_torch: Optional[str] = None,
+    quant_available: bool = False,
+) -> dict:
+    """ Creating aggregated arguments for standard pytest skipif, so the use case is::
+
+        @pytest.mark.skipif(**skipif_args(min_torch="99"))
+        def test_any_func(...):
+            ...
+
+    >>> from pprint import pprint
+    >>> pprint(skipif_args(min_torch="99", min_gpus=0))
+    {'condition': True, 'reason': 'Required: [torch>=99]'}
+    >>> pprint(skipif_args(min_torch="0.0", min_gpus=0))  # doctest: +NORMALIZE_WHITESPACE
+    {'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'}
+    """
+    conditions = []
+    reasons = []
+
+    if min_gpus:
+        conditions.append(torch.cuda.device_count() < min_gpus)
+        reasons.append(f"GPUs>={min_gpus}")
+
+    if min_torch:
+        torch_version = LooseVersion(get_distribution("torch").version)
+        conditions.append(torch_version < LooseVersion(min_torch))
+        reasons.append(f"torch>={min_torch}")
+
+    if quant_available:
+        _miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
+        conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
+        reasons.append("PyTorch quantization is available")
+
+    if not any(conditions):
+        return dict(condition=False, reason="Conditions satisfied, going ahead with the test.")
+
+    reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
+    return dict(condition=any(conditions), reason=f"Required: [{' + '.join(reasons)}]",)
+
+
+@pytest.mark.skipif(**skipif_args(min_torch="99"))
+def test_always_skip():
+    exit(1)
+
+
+@pytest.mark.skipif(**skipif_args(min_torch="0.0"))
+def test_always_pass():
+    assert True
diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index 36713de792..554fc98740 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -34,7 +34,7 @@ def test_optimizer_with_scheduling(tmpdir):
         max_epochs=1,
         limit_val_batches=0.1,
         limit_train_batches=0.2,
-        val_check_interval=0.5
+        val_check_interval=0.5,
    )
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
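
The new helper also accepts several requirements at once and joins the reasons for any unmet requirement into a single skip message. A minimal usage sketch of that combination, mirroring how the patch uses the helper in the quantization and results tests (the test name and the GPU/torch thresholds below are illustrative, not taken from the patch):

import pytest

from tests.helpers.skipif import skipif_args


# Combined requirements: the test is skipped unless at least 2 GPUs are visible
# and the installed torch is >= 1.6; any unmet requirement ends up in the
# aggregated skip reason, e.g. "Required: [GPUs>=2 + torch>=1.6]".
@pytest.mark.skipif(**skipif_args(min_gpus=2, min_torch="1.6"))
def test_needs_two_gpus_and_recent_torch():
    assert True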