simplify skip-if tests >> 0/n (#5920)

* skipif + yapf + isort

* tests

* docs

* pp
Jirka Borovec 2021-03-01 13:17:09 +01:00 committed by GitHub
parent 15c477e9fc
commit 58a6d59784
12 changed files with 98 additions and 50 deletions
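In short: this commit deletes the hand-maintained `_SKIPIF_ARGS_*` dicts from `tests/__init__.py` and introduces a single `skipif_args` helper in the new `tests/helpers/skipif.py`, which builds the `condition`/`reason` kwargs for `pytest.mark.skipif` from declarative requirements (minimum GPU count, minimum torch version, quantization support); the remaining files are mechanical yapf/isort reflows. A minimal usage sketch, with an illustrative test name not taken from the diff:

import pytest

from tests.helpers.skipif import skipif_args

@pytest.mark.skipif(**skipif_args(min_gpus=1))
def test_runs_on_gpu():
    # skipped with reason "Required: [GPUs>=1]" on a machine without a visible GPU
    ...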

View File

@@ -29,7 +29,6 @@ if TYPE_CHECKING:
from pytorch_lightning.trainer.trainer import Trainer
_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
@@ -224,9 +223,7 @@ class Accelerator(object):
with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
return self.training_type_plugin.predict(*args)
- def training_step_end(
- self, output: _STEP_OUTPUT_TYPE
- ) -> _STEP_OUTPUT_TYPE:
+ def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
"""A hook to do something at the end of the training step
Args:
@@ -234,9 +231,7 @@ class Accelerator(object):
"""
return self.training_type_plugin.training_step_end(output)
- def test_step_end(
- self, output: _STEP_OUTPUT_TYPE
- ) -> _STEP_OUTPUT_TYPE:
+ def test_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
"""A hook to do something at the end of the test step
Args:
@@ -244,9 +239,7 @@ class Accelerator(object):
"""
return self.training_type_plugin.test_step_end(output)
- def validation_step_end(
- self, output: _STEP_OUTPUT_TYPE
- ) -> _STEP_OUTPUT_TYPE:
+ def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
"""A hook to do something at the end of the validation step
Args:
@@ -400,9 +393,7 @@ class Accelerator(object):
"""
return self.training_type_plugin.broadcast(obj, src)
- def all_gather(
- self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
- ) -> torch.Tensor:
+ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""
Function to gather a tensor from several distributed processes.

View File

@@ -36,9 +36,7 @@ class TPUAccelerator(Accelerator):
) -> None:
xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
- def all_gather(
- self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
- ) -> torch.Tensor:
+ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
"""
Function to gather a tensor from several distributed processes
Args:

View File

@@ -75,9 +75,7 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
return closure_loss
- def clip_gradients(
- self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
- ) -> None:
+ def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
"""
DeepSpeed handles clipping gradients via the training type plugin.
"""

View File

@@ -32,8 +32,6 @@ class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
super().__init__()
self.scaler = ShardedGradScaler()
- def clip_gradients(
- self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
- ) -> None:
+ def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
optimizer = cast(OSS, optimizer)
optimizer.clip_grad_norm(clip_val, norm_type=norm_type)

View File

@@ -15,7 +15,7 @@
from abc import ABC
from copy import deepcopy
from inspect import signature
- from typing import List, Dict, Any, Type, Callable
+ from typing import Any, Callable, Dict, List, Type
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
@@ -214,8 +214,7 @@ class TrainerCallbackHookMixin(ABC):
rank_zero_warn(
"`Callback.on_save_checkpoint` signature has changed in v1.3."
" A `checkpoint` parameter has been added."
- " Support for the old signature will be removed in v1.5",
- DeprecationWarning
+ " Support for the old signature will be removed in v1.5", DeprecationWarning
)
state = callback.on_save_checkpoint(self, self.lightning_module) # noqa: parameter-unfilled
else:

View File

@@ -22,8 +22,7 @@ import numpy as np
import torch
from pytorch_lightning.utilities.exceptions import MisconfigurationException
- from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE
- from pytorch_lightning.utilities.imports import _module_available
+ from pytorch_lightning.utilities.imports import _module_available, _TORCHTEXT_AVAILABLE
if _TORCHTEXT_AVAILABLE:
if _module_available("torchtext.legacy.data"):

View File

@@ -14,9 +14,6 @@
import os
import numpy as np
import torch
- from pytorch_lightning.utilities import _TORCH_LOWER_EQUAL_1_4, _TORCH_QUANTIZE_AVAILABLE
_TEST_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_TEST_ROOT)
@@ -34,13 +31,3 @@ RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
if not os.path.isdir(_TEMP_PATH):
os.mkdir(_TEMP_PATH)
- _MISS_QUANT_DEFAULT = 'fbgemm' not in torch.backends.quantized.supported_engines
- _SKIPIF_ARGS_PT_LE_1_4 = dict(condition=_TORCH_LOWER_EQUAL_1_4, reason="test pytorch > 1.4")
- _SKIPIF_ARGS_NO_GPU = dict(condition=not torch.cuda.is_available(), reason="test requires single-GPU machine")
- _SKIPIF_ARGS_NO_GPUS = dict(condition=torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
- _SKIPIF_ARGS_NO_PT_QUANT = dict(
- condition=not _TORCH_QUANTIZE_AVAILABLE or _MISS_QUANT_DEFAULT,
- reason="PyTorch quantization is needed for this test"
- )
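Each removed constant has a direct replacement via the new helper. The mapping below is inferred from the call sites updated elsewhere in this diff; the multi-GPU variant is not exercised in these files, so that line is an assumption:

# Inferred replacements (see the updated call sites later in this diff):
#   _SKIPIF_ARGS_PT_LE_1_4   -> skipif_args(min_torch="1.5")
#   _SKIPIF_ARGS_NO_GPU      -> skipif_args(min_gpus=1)
#   _SKIPIF_ARGS_NO_GPUS     -> skipif_args(min_gpus=2)   # assumption, not shown in this diff
#   _SKIPIF_ARGS_NO_PT_QUANT -> skipif_args(quant_available=True)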

View File

@@ -20,16 +20,17 @@ from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import QuantizationAwareTraining
from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error
from pytorch_lightning.utilities.exceptions import MisconfigurationException
- from tests import _SKIPIF_ARGS_NO_PT_QUANT, _SKIPIF_ARGS_PT_LE_1_4
from tests.helpers.datamodules import RegressDataModule
from tests.helpers.simple_models import RegressionModel
+ from tests.helpers.skipif import skipif_args
@pytest.mark.parametrize(
- "observe", ['average', pytest.param('histogram', marks=pytest.mark.skipif(**_SKIPIF_ARGS_PT_LE_1_4))]
+ "observe",
+ ['average', pytest.param('histogram', marks=pytest.mark.skipif(**skipif_args(min_torch="1.5")))]
)
@pytest.mark.parametrize("fuse", [True, False])
- @pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+ @pytest.mark.skipif(**skipif_args(quant_available=True))
def test_quantization(tmpdir, observe, fuse):
"""Parity test for quant model"""
seed_everything(42)
@@ -64,7 +65,7 @@ def test_quantization(tmpdir, observe, fuse):
assert torch.allclose(org_score, quant_score, atol=0.45)
- @pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+ @pytest.mark.skipif(**skipif_args(quant_available=True))
def test_quantize_torchscript(tmpdir):
"""Test converting to torchscript"""
dm = RegressDataModule()
@@ -80,7 +81,7 @@ def test_quantize_torchscript(tmpdir):
tsmodel(tsmodel.quant(batch[0]))
- @pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+ @pytest.mark.skipif(**skipif_args(quant_available=True))
def test_quantization_exceptions(tmpdir):
"""Test wrong fuse layers"""
with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):
@@ -123,7 +124,7 @@ def custom_trigger_last(trainer):
(custom_trigger_last, 2),
]
)
- @pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+ @pytest.mark.skipif(**skipif_args(quant_available=True))
def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
"""Test how many times the quant is called"""
dm = RegressDataModule()

View File

@@ -25,8 +25,8 @@ import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.trainer.states import TrainerState
- from tests import _SKIPIF_ARGS_NO_GPU
from tests.helpers import BoringDataModule, BoringModel
+ from tests.helpers.skipif import skipif_args
def _setup_ddp(rank, worldsize):
@@ -72,7 +72,7 @@ def test_result_reduce_ddp(result_cls):
pytest.param(5, False, 0, id='nested_list_predictions'),
pytest.param(6, False, 0, id='dict_list_predictions'),
pytest.param(7, True, 0, id='write_dict_predictions'),
- pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**_SKIPIF_ARGS_NO_GPU))
+ pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**skipif_args(min_gpus=1)))
]
)
def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):

View File

@@ -17,7 +17,7 @@ from unittest import mock
import pytest
- from pytorch_lightning import Trainer, Callback
+ from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import WandbLogger
from tests.helpers import BoringModel
from tests.helpers.utils import no_warning_call
@@ -30,7 +30,9 @@ def test_v1_5_0_wandb_unused_sync_step(tmpdir):
def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
class OldSignature(Callback):
def on_save_checkpoint(self, trainer, pl_module): # noqa
...
@@ -49,14 +51,17 @@ def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
trainer.save_checkpoint(filepath)
class NewSignature(Callback):
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
...
class ValidSignature1(Callback):
def on_save_checkpoint(self, trainer, *args):
...
class ValidSignature2(Callback):
def on_save_checkpoint(self, *args):
...

tests/helpers/skipif.py (new file, 72 lines)
View File

@@ -0,0 +1,72 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
from typing import Optional
import pytest
import torch
from pkg_resources import get_distribution
from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE
def skipif_args(
min_gpus: int = 0,
min_torch: Optional[str] = None,
quant_available: bool = False,
) -> dict:
""" Creating aggregated arguments for standard pytest skipif, sot the usecase is::
@pytest.mark.skipif(**create_skipif(min_torch="99"))
def test_any_func(...):
...
>>> from pprint import pprint
>>> pprint(skipif_args(min_torch="99", min_gpus=0))
{'condition': True, 'reason': 'Required: [torch>=99]'}
>>> pprint(skipif_args(min_torch="0.0", min_gpus=0)) # doctest: +NORMALIZE_WHITESPACE
{'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'}
"""
conditions = []
reasons = []
if min_gpus:
conditions.append(torch.cuda.device_count() < min_gpus)
reasons.append(f"GPUs>={min_gpus}")
if min_torch:
torch_version = LooseVersion(get_distribution("torch").version)
conditions.append(torch_version < LooseVersion(min_torch))
reasons.append(f"torch>={min_torch}")
if quant_available:
_miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
reasons.append("PyTorch quantization is available")
if not any(conditions):
return dict(condition=False, reason="Conditions satisfied, going ahead with the test.")
reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return dict(condition=any(conditions), reason=f"Required: [{' + '.join(reasons)}]",)
@pytest.mark.skipif(**skipif_args(min_torch="99"))
def test_always_skip():
exit(1)
@pytest.mark.skipif(**skipif_args(min_torch="0.0"))
def test_always_pass():
assert True
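The conditions compose: each unmet requirement contributes both to the skip decision and to the reason string. A sketch of combined use (the test name is illustrative; the expanded kwargs shown assume a CPU-only machine running torch older than 1.6):

import pytest

from tests.helpers.skipif import skipif_args

# On a CPU-only machine with, say, torch 1.4 this expands to roughly
#   @pytest.mark.skipif(condition=True, reason="Required: [GPUs>=2 + torch>=1.6]")
@pytest.mark.skipif(**skipif_args(min_gpus=2, min_torch="1.6"))
def test_needs_two_gpus_and_recent_torch():
    ...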

View File

@@ -34,7 +34,7 @@ def test_optimizer_with_scheduling(tmpdir):
max_epochs=1,
limit_val_batches=0.1,
limit_train_batches=0.2,
- val_check_interval=0.5
+ val_check_interval=0.5,
)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"