simplify skip-if tests >> 0/n (#5920)
* skipif + yapf + isort * tests * docs * pp
This commit is contained in:
parent
15c477e9fc
commit
58a6d59784
|
@ -29,7 +29,6 @@ if TYPE_CHECKING:
|
|||
|
||||
from pytorch_lightning.trainer.trainer import Trainer
|
||||
|
||||
|
||||
_STEP_OUTPUT_TYPE = Union[torch.Tensor, Dict[str, torch.Tensor], None]
|
||||
|
||||
|
||||
|
@ -224,9 +223,7 @@ class Accelerator(object):
|
|||
with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
|
||||
return self.training_type_plugin.predict(*args)
|
||||
|
||||
def training_step_end(
|
||||
self, output: _STEP_OUTPUT_TYPE
|
||||
) -> _STEP_OUTPUT_TYPE:
|
||||
def training_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
|
||||
"""A hook to do something at the end of the training step
|
||||
|
||||
Args:
|
||||
|
@ -234,9 +231,7 @@ class Accelerator(object):
|
|||
"""
|
||||
return self.training_type_plugin.training_step_end(output)
|
||||
|
||||
def test_step_end(
|
||||
self, output: _STEP_OUTPUT_TYPE
|
||||
) -> _STEP_OUTPUT_TYPE:
|
||||
def test_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
|
||||
"""A hook to do something at the end of the test step
|
||||
|
||||
Args:
|
||||
|
@ -244,9 +239,7 @@ class Accelerator(object):
|
|||
"""
|
||||
return self.training_type_plugin.test_step_end(output)
|
||||
|
||||
def validation_step_end(
|
||||
self, output: _STEP_OUTPUT_TYPE
|
||||
) -> _STEP_OUTPUT_TYPE:
|
||||
def validation_step_end(self, output: _STEP_OUTPUT_TYPE) -> _STEP_OUTPUT_TYPE:
|
||||
"""A hook to do something at the end of the validation step
|
||||
|
||||
Args:
|
||||
|
@ -400,9 +393,7 @@ class Accelerator(object):
|
|||
"""
|
||||
return self.training_type_plugin.broadcast(obj, src)
|
||||
|
||||
def all_gather(
|
||||
self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
|
||||
) -> torch.Tensor:
|
||||
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Function to gather a tensor from several distributed processes.
|
||||
|
||||
|
|
|
@ -36,9 +36,7 @@ class TPUAccelerator(Accelerator):
|
|||
) -> None:
|
||||
xm.optimizer_step(optimizer, optimizer_args={'closure': lambda_closure, **kwargs})
|
||||
|
||||
def all_gather(
|
||||
self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
|
||||
) -> torch.Tensor:
|
||||
def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
|
||||
"""
|
||||
Function to gather a tensor from several distributed processes
|
||||
Args:
|
||||
|
|
|
@ -75,9 +75,7 @@ class DeepSpeedPrecisionPlugin(PrecisionPlugin):
|
|||
|
||||
return closure_loss
|
||||
|
||||
def clip_gradients(
|
||||
self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
|
||||
) -> None:
|
||||
def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
|
||||
"""
|
||||
DeepSpeed handles clipping gradients via the training type plugin.
|
||||
"""
|
||||
|
|
|
@ -32,8 +32,6 @@ class ShardedNativeMixedPrecisionPlugin(NativeMixedPrecisionPlugin):
|
|||
super().__init__()
|
||||
self.scaler = ShardedGradScaler()
|
||||
|
||||
def clip_gradients(
|
||||
self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0
|
||||
) -> None:
|
||||
def clip_gradients(self, optimizer: 'Optimizer', clip_val: Union[int, float], norm_type: float = 2.0) -> None:
|
||||
optimizer = cast(OSS, optimizer)
|
||||
optimizer.clip_grad_norm(clip_val, norm_type=norm_type)
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from inspect import signature
|
||||
from typing import List, Dict, Any, Type, Callable
|
||||
from typing import Any, Callable, Dict, List, Type
|
||||
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
|
@ -214,8 +214,7 @@ class TrainerCallbackHookMixin(ABC):
|
|||
rank_zero_warn(
|
||||
"`Callback.on_save_checkpoint` signature has changed in v1.3."
|
||||
" A `checkpoint` parameter has been added."
|
||||
" Support for the old signature will be removed in v1.5",
|
||||
DeprecationWarning
|
||||
" Support for the old signature will be removed in v1.5", DeprecationWarning
|
||||
)
|
||||
state = callback.on_save_checkpoint(self, self.lightning_module) # noqa: parameter-unfilled
|
||||
else:
|
||||
|
|
|
@ -22,8 +22,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE
|
||||
from pytorch_lightning.utilities.imports import _module_available
|
||||
from pytorch_lightning.utilities.imports import _module_available, _TORCHTEXT_AVAILABLE
|
||||
|
||||
if _TORCHTEXT_AVAILABLE:
|
||||
if _module_available("torchtext.legacy.data"):
|
||||
|
|
|
@ -14,9 +14,6 @@
|
|||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.utilities import _TORCH_LOWER_EQUAL_1_4, _TORCH_QUANTIZE_AVAILABLE
|
||||
|
||||
_TEST_ROOT = os.path.dirname(__file__)
|
||||
_PROJECT_ROOT = os.path.dirname(_TEST_ROOT)
|
||||
|
@ -34,13 +31,3 @@ RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
|
|||
|
||||
if not os.path.isdir(_TEMP_PATH):
|
||||
os.mkdir(_TEMP_PATH)
|
||||
|
||||
_MISS_QUANT_DEFAULT = 'fbgemm' not in torch.backends.quantized.supported_engines
|
||||
|
||||
_SKIPIF_ARGS_PT_LE_1_4 = dict(condition=_TORCH_LOWER_EQUAL_1_4, reason="test pytorch > 1.4")
|
||||
_SKIPIF_ARGS_NO_GPU = dict(condition=not torch.cuda.is_available(), reason="test requires single-GPU machine")
|
||||
_SKIPIF_ARGS_NO_GPUS = dict(condition=torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
_SKIPIF_ARGS_NO_PT_QUANT = dict(
|
||||
condition=not _TORCH_QUANTIZE_AVAILABLE or _MISS_QUANT_DEFAULT,
|
||||
reason="PyTorch quantization is needed for this test"
|
||||
)
|
||||
|
|
|
@ -20,16 +20,17 @@ from pytorch_lightning import seed_everything, Trainer
|
|||
from pytorch_lightning.callbacks import QuantizationAwareTraining
|
||||
from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests import _SKIPIF_ARGS_NO_PT_QUANT, _SKIPIF_ARGS_PT_LE_1_4
|
||||
from tests.helpers.datamodules import RegressDataModule
|
||||
from tests.helpers.simple_models import RegressionModel
|
||||
from tests.helpers.skipif import skipif_args
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"observe", ['average', pytest.param('histogram', marks=pytest.mark.skipif(**_SKIPIF_ARGS_PT_LE_1_4))]
|
||||
"observe",
|
||||
['average', pytest.param('histogram', marks=pytest.mark.skipif(**skipif_args(min_torch="1.5")))]
|
||||
)
|
||||
@pytest.mark.parametrize("fuse", [True, False])
|
||||
@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
|
||||
@pytest.mark.skipif(**skipif_args(quant_available=True))
|
||||
def test_quantization(tmpdir, observe, fuse):
|
||||
"""Parity test for quant model"""
|
||||
seed_everything(42)
|
||||
|
@ -64,7 +65,7 @@ def test_quantization(tmpdir, observe, fuse):
|
|||
assert torch.allclose(org_score, quant_score, atol=0.45)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
|
||||
@pytest.mark.skipif(**skipif_args(quant_available=True))
|
||||
def test_quantize_torchscript(tmpdir):
|
||||
"""Test converting to torchscipt """
|
||||
dm = RegressDataModule()
|
||||
|
@ -80,7 +81,7 @@ def test_quantize_torchscript(tmpdir):
|
|||
tsmodel(tsmodel.quant(batch[0]))
|
||||
|
||||
|
||||
@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
|
||||
@pytest.mark.skipif(**skipif_args(quant_available=True))
|
||||
def test_quantization_exceptions(tmpdir):
|
||||
"""Test wrong fuse layers"""
|
||||
with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):
|
||||
|
@ -123,7 +124,7 @@ def custom_trigger_last(trainer):
|
|||
(custom_trigger_last, 2),
|
||||
]
|
||||
)
|
||||
@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
|
||||
@pytest.mark.skipif(**skipif_args(quant_available=True))
|
||||
def test_quantization_triggers(tmpdir, trigger_fn, expected_count):
|
||||
"""Test how many times the quant is called"""
|
||||
dm = RegressDataModule()
|
||||
|
|
|
@ -25,8 +25,8 @@ import tests.helpers.utils as tutils
|
|||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.core.step_result import Result
|
||||
from pytorch_lightning.trainer.states import TrainerState
|
||||
from tests import _SKIPIF_ARGS_NO_GPU
|
||||
from tests.helpers import BoringDataModule, BoringModel
|
||||
from tests.helpers.skipif import skipif_args
|
||||
|
||||
|
||||
def _setup_ddp(rank, worldsize):
|
||||
|
@ -72,7 +72,7 @@ def test_result_reduce_ddp(result_cls):
|
|||
pytest.param(5, False, 0, id='nested_list_predictions'),
|
||||
pytest.param(6, False, 0, id='dict_list_predictions'),
|
||||
pytest.param(7, True, 0, id='write_dict_predictions'),
|
||||
pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**_SKIPIF_ARGS_NO_GPU))
|
||||
pytest.param(0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(**skipif_args(min_gpus=1)))
|
||||
]
|
||||
)
|
||||
def test_result_obj_predictions(tmpdir, test_option, do_train, gpus):
|
||||
|
|
|
@ -17,7 +17,7 @@ from unittest import mock
|
|||
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning import Trainer, Callback
|
||||
from pytorch_lightning import Callback, Trainer
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.utils import no_warning_call
|
||||
|
@ -30,7 +30,9 @@ def test_v1_5_0_wandb_unused_sync_step(tmpdir):
|
|||
|
||||
|
||||
def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
|
||||
|
||||
class OldSignature(Callback):
|
||||
|
||||
def on_save_checkpoint(self, trainer, pl_module): # noqa
|
||||
...
|
||||
|
||||
|
@ -49,14 +51,17 @@ def test_v1_5_0_old_callback_on_save_checkpoint(tmpdir):
|
|||
trainer.save_checkpoint(filepath)
|
||||
|
||||
class NewSignature(Callback):
|
||||
|
||||
def on_save_checkpoint(self, trainer, pl_module, checkpoint):
|
||||
...
|
||||
|
||||
class ValidSignature1(Callback):
|
||||
|
||||
def on_save_checkpoint(self, trainer, *args):
|
||||
...
|
||||
|
||||
class ValidSignature2(Callback):
|
||||
|
||||
def on_save_checkpoint(self, *args):
|
||||
...
|
||||
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from distutils.version import LooseVersion
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from pkg_resources import get_distribution
|
||||
|
||||
from pytorch_lightning.utilities import _TORCH_QUANTIZE_AVAILABLE
|
||||
|
||||
|
||||
def skipif_args(
|
||||
min_gpus: int = 0,
|
||||
min_torch: Optional[str] = None,
|
||||
quant_available: bool = False,
|
||||
) -> dict:
|
||||
""" Creating aggregated arguments for standard pytest skipif, sot the usecase is::
|
||||
|
||||
@pytest.mark.skipif(**create_skipif(min_torch="99"))
|
||||
def test_any_func(...):
|
||||
...
|
||||
|
||||
>>> from pprint import pprint
|
||||
>>> pprint(skipif_args(min_torch="99", min_gpus=0))
|
||||
{'condition': True, 'reason': 'Required: [torch>=99]'}
|
||||
>>> pprint(skipif_args(min_torch="0.0", min_gpus=0)) # doctest: +NORMALIZE_WHITESPACE
|
||||
{'condition': False, 'reason': 'Conditions satisfied, going ahead with the test.'}
|
||||
"""
|
||||
conditions = []
|
||||
reasons = []
|
||||
|
||||
if min_gpus:
|
||||
conditions.append(torch.cuda.device_count() < min_gpus)
|
||||
reasons.append(f"GPUs>={min_gpus}")
|
||||
|
||||
if min_torch:
|
||||
torch_version = LooseVersion(get_distribution("torch").version)
|
||||
conditions.append(torch_version < LooseVersion(min_torch))
|
||||
reasons.append(f"torch>={min_torch}")
|
||||
|
||||
if quant_available:
|
||||
_miss_default = 'fbgemm' not in torch.backends.quantized.supported_engines
|
||||
conditions.append(not _TORCH_QUANTIZE_AVAILABLE or _miss_default)
|
||||
reasons.append("PyTorch quantization is available")
|
||||
|
||||
if not any(conditions):
|
||||
return dict(condition=False, reason="Conditions satisfied, going ahead with the test.")
|
||||
|
||||
reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
|
||||
return dict(condition=any(conditions), reason=f"Required: [{' + '.join(reasons)}]",)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**skipif_args(min_torch="99"))
|
||||
def test_always_skip():
|
||||
exit(1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(**skipif_args(min_torch="0.0"))
|
||||
def test_always_pass():
|
||||
assert True
|
|
@ -34,7 +34,7 @@ def test_optimizer_with_scheduling(tmpdir):
|
|||
max_epochs=1,
|
||||
limit_val_batches=0.1,
|
||||
limit_train_batches=0.2,
|
||||
val_check_interval=0.5
|
||||
val_check_interval=0.5,
|
||||
)
|
||||
trainer.fit(model)
|
||||
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
|
||||
|
|
Loading…
Reference in New Issue