diff --git a/pl_examples/basic_examples/dali_image_classifier.py b/pl_examples/basic_examples/dali_image_classifier.py
index 0a39f1cb9a..291490d6f9 100644
--- a/pl_examples/basic_examples/dali_image_classifier.py
+++ b/pl_examples/basic_examples/dali_image_classifier.py
@@ -84,7 +84,8 @@ class ExternalSourcePipeline(Pipeline):
 
 class DALIClassificationLoader(DALIClassificationIterator):
     """
-    This class extends DALI's original DALIClassificationIterator with the __len__() function so that we can call len() on it
+    This class extends DALI's original `DALIClassificationIterator` with the `__len__()` function
+    so that we can call `len()` on it
     """
 
     def __init__(
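The hunk above only reflows the docstring, but the `__len__()` it mentions is the interesting part: `len()` on the loader is what lets the Trainer size epochs and progress bars. A minimal, self-contained sketch of that idea, assuming a wrapper that knows its dataset size and batch size (the class and attribute names here are illustrative, not the actual DALI or pl_examples implementation):

import math


class SizedLoaderWrapper:
    """Illustrative wrapper: adds the __len__() that len()-based consumers expect."""

    def __init__(self, iterator, dataset_size: int, batch_size: int):
        self._iterator = iterator
        self._dataset_size = dataset_size
        self._batch_size = batch_size

    def __iter__(self):
        return iter(self._iterator)

    def __len__(self) -> int:
        # number of batches per epoch, counting a possibly smaller final batch
        return math.ceil(self._dataset_size / self._batch_size)

With such a wrapper, `len(loader)` reports batches per epoch without having to iterate the data once.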
+ " Either set `is_multiclass=None`(default) or set `num_classes=2`" + " to transform binary data to multi-class format." ) diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py index 6209048985..4af8041ce9 100644 --- a/pytorch_lightning/metrics/classification/precision_recall_curve.py +++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py @@ -73,11 +73,12 @@ class PrecisionRecallCurve(Metric): >>> target = torch.tensor([0, 1, 3, 2]) >>> pr_curve = PrecisionRecallCurve(num_classes=5) >>> precision, recall, thresholds = pr_curve(pred, target) - >>> precision - [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), + tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] >>> recall [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] - >>> thresholds # doctest: +NORMALIZE_WHITESPACE + >>> thresholds [tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])] """ diff --git a/pytorch_lightning/metrics/functional/classification.py b/pytorch_lightning/metrics/functional/classification.py index ffada88402..140dff7159 100644 --- a/pytorch_lightning/metrics/functional/classification.py +++ b/pytorch_lightning/metrics/functional/classification.py @@ -18,7 +18,10 @@ import torch from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap from pytorch_lightning.metrics.functional.f_beta import fbeta as __fb, f1 as __f1 -from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve, precision_recall_curve as __prc +from pytorch_lightning.metrics.functional.precision_recall_curve import ( + _binary_clf_curve, + precision_recall_curve as __prc +) from pytorch_lightning.metrics.functional.roc import roc as __roc from pytorch_lightning.metrics.utils import ( to_categorical as __tc, @@ -821,7 +824,8 @@ def precision_recall_curve( """ Computes precision-recall pairs for different thresholds. - .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` """ rank_zero_warn( "This `precision_recall_curve` was deprecated in v1.1.0 in favor of" @@ -841,7 +845,8 @@ def multiclass_precision_recall_curve( """ Computes precision-recall pairs for different thresholds given a multiclass scores. - .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` + .. warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve` """ rank_zero_warn( "This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of" @@ -863,7 +868,8 @@ def average_precision( """ Compute average precision from prediction scores. - .. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision` + .. 
warning :: Deprecated in favor of + :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision` """ rank_zero_warn( "This `average_precision` was deprecated in v1.1.0 in favor of" diff --git a/pytorch_lightning/metrics/functional/precision_recall_curve.py b/pytorch_lightning/metrics/functional/precision_recall_curve.py index 6c112fe010..e497c5f7b3 100644 --- a/pytorch_lightning/metrics/functional/precision_recall_curve.py +++ b/pytorch_lightning/metrics/functional/precision_recall_curve.py @@ -208,8 +208,9 @@ def precision_recall_curve( ... [0.05, 0.05, 0.05, 0.75, 0.05]]) >>> target = torch.tensor([0, 1, 3, 2]) >>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5) - >>> precision - [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] + >>> precision # doctest: +NORMALIZE_WHITESPACE + [tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), + tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])] >>> recall [tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])] >>> thresholds diff --git a/pytorch_lightning/setup_tools.py b/pytorch_lightning/setup_tools.py index 3842bbe50c..26a607a295 100644 --- a/pytorch_lightning/setup_tools.py +++ b/pytorch_lightning/setup_tools.py @@ -172,7 +172,8 @@ def _load_long_description(path_dir: str) -> str: # readthedocs badge text = text.replace('badge/?version=stable', f'badge/?version={__version__}') - text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{__version__}') + text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', + f'pytorch-lightning.readthedocs.io/en/{__version__}') # codecov badge text = text.replace('/branch/master/graph/badge.svg', f'/release/{__version__}/graph/badge.svg') # replace github badges for release ones diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 7facbf810d..fb77905f18 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -282,7 +282,9 @@ class CheckpointConnector: checkpoint['lr_schedulers'] = lr_schedulers # dump amp scaling - if self.trainer.amp_backend == AMPType.NATIVE and not self.trainer.use_tpu and self.trainer.scaler is not None: + if (self.trainer.amp_backend == AMPType.NATIVE + and not self.trainer.use_tpu + and self.trainer.scaler is not None): checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict() elif self.trainer.amp_backend == AMPType.APEX: checkpoint['amp_scaling_state'] = amp.state_dict() diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index d57919d4dd..a3f86f6287 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -168,8 +168,9 @@ class LoggerConnector: metrics (dict): Metric values grad_norm_dic (dict): Gradient norms step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step` - log_train_step_metrics (bool): Used to track if log_metrics function is being called in during training steps. 
-                In training steps, we will log metrics on step: total_nb_idx (for accumulated gradients) and global_step for the rest.
+            log_train_step_metrics (bool): Used to track if the `log_metrics` function is being called during
+                training steps. In training steps, we will log metrics on step: `total_nb_idx` (for accumulated
+                gradients) and `global_step` for the rest.
         """
         # add gpu memory
         if self.trainer.on_gpu and self.trainer.log_gpu_memory:
diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py
index 479d401720..ee5f80207e 100644
--- a/pytorch_lightning/trainer/optimizers.py
+++ b/pytorch_lightning/trainer/optimizers.py
@@ -125,7 +125,8 @@ class TrainerOptimizersMixin(ABC):
                 if monitor is None:
                     raise MisconfigurationException(
                         '`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
-                        ' For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
+                        ' For example:'
+                        ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
                     )
                 lr_schedulers.append(
                     {**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor}
diff --git a/setup.cfg b/setup.cfg
index a1992016ba..2086758bf2 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -65,13 +65,12 @@ verbose = 2
 # https://pep8.readthedocs.io/en/latest/intro.html#error-codes
 format = pylint
 ignore =
-    E731
-    W504
-    F401
-    E203 # E203 - whitespace before ':'. Opposite convention enforced by black
-    E231 # E231: missing whitespace after ',', ';', or ':'; for black
-    E501 # E501 - line too long. Handled by black, we have longer lines
-    W503 # W503 - line break before binary operator, need for black
+    E731  # do not assign a lambda expression, use a def
+    W504  # line break occurred after a binary operator
+    F401  # module imported but unused
+    E203  # whitespace before ':'. Opposite convention enforced by black
+    E231  # missing whitespace after ',', ';', or ':'; for black
+    W503  # line break before binary operator, need for black
 
 # setup.cfg or tox.ini
 [check-manifest]
diff --git a/tests/base/models.py b/tests/base/models.py
index 79f3f843ec..b346e370c7 100644
--- a/tests/base/models.py
+++ b/tests/base/models.py
@@ -73,7 +73,8 @@ class Discriminator(nn.Module):
 class BasicGAN(LightningModule):
     """Implements a basic GAN for the purpose of illustrating multiple optimizers."""
 
-    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001, b1: float = 0.5, b2: float = 0.999, **kwargs):
+    def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001,
+                 b1: float = 0.5, b2: float = 0.999, **kwargs):
         super().__init__()
         self.hidden_dim = hidden_dim
         self.learning_rate = learning_rate
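The `BasicGAN` hunk above only rewraps the signature, but the docstring's point (illustrating multiple optimizers) is worth a concrete sketch. In Lightning, `configure_optimizers` may return several optimizers, one per sub-network; a minimal, hedged example (the tiny stand-in modules and hyper-parameter names are illustrative, not the actual test model):

import torch
from torch import nn
import pytorch_lightning as pl


class TinyGAN(pl.LightningModule):
    def __init__(self, lr: float = 1e-3, b1: float = 0.5, b2: float = 0.999):
        super().__init__()
        self.generator = nn.Linear(8, 8)       # stand-in sub-networks
        self.discriminator = nn.Linear(8, 1)
        self.lr, self.b1, self.b2 = lr, b1, b2

    def configure_optimizers(self):
        # one optimizer per sub-network; with automatic optimization, training_step
        # is then called once per optimizer and receives an optimizer_idx
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.b1, self.b2))
        return [opt_g, opt_d]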
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 7cecefad03..93830577d6 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -213,11 +213,14 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi
     IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being triggered,
     THEN the trainer should continue until reaching `trainer.global_step` == `min_steps`, and stop.
 
-    IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is being triggered,
-    THEN the trainer should continue until reaching `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
+    IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step`
+    when `early_stopping` is being triggered,
+    THEN the trainer should continue until reaching
+    `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
 
     This test validate this expected behaviour
-    IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when `early_stopping` is being triggered,
+    IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step`
+    when `early_stopping` is being triggered,
     THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.
 
     Caviat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 9817dfa452..a7260be0ea 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -269,7 +269,12 @@ def test_model_checkpoint_file_extension(tmpdir):
     """
 
     model = LogInTwoMethods()
-    model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True)
+    model_checkpoint = ModelCheckpointExtensionTest(
+        monitor='early_stop_on',
+        dirpath=tmpdir,
+        save_top_k=1,
+        save_last=True,
+    )
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=[model_checkpoint],
diff --git a/tests/core/test_results.py b/tests/core/test_results.py
index 797004b7f2..d9fd3a473a 100644
--- a/tests/core/test_results.py
+++ b/tests/core/test_results.py
@@ -87,7 +87,8 @@ def test_result_reduce_ddp(result_cls):
             7, True, 0, id='write_dict_predictions'
         ),
         pytest.param(
-            0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires single-GPU machine")
+            0, True, 1, id='full_loop_single_gpu',
+            marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires single-GPU machine")
         )
     ]
 )
diff --git a/tests/metrics/classification/test_f_beta.py b/tests/metrics/classification/test_f_beta.py
index 9cadd48bce..5939052cf1 100644
--- a/tests/metrics/classification/test_f_beta.py
+++ b/tests/metrics/classification/test_f_beta.py
@@ -86,23 +86,14 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0):
         (_binary_inputs.preds, _binary_inputs.target, _sk_fbeta_binary, 1, False),
         (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True),
         (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
-        (_multilabel_inputs_no_match.preds, _multilabel_inputs_no_match.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
+        (_multilabel_inputs_no_match.preds, _multilabel_inputs_no_match.target,
+         _sk_fbeta_multilabel, NUM_CLASSES, True),
         (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False),
         (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_fbeta_multiclass, NUM_CLASSES, False),
-        (
-            _multidim_multiclass_prob_inputs.preds,
-            _multidim_multiclass_prob_inputs.target,
-            _sk_fbeta_multidim_multiclass_prob,
-            NUM_CLASSES,
-            False,
-        ),
-        (
-            _multidim_multiclass_inputs.preds,
-            _multidim_multiclass_inputs.target,
-            _sk_fbeta_multidim_multiclass,
-            NUM_CLASSES,
-            False,
-        ),
+        (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target,
+         _sk_fbeta_multidim_multiclass_prob, NUM_CLASSES, False),
+        (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target,
+         _sk_fbeta_multidim_multiclass, NUM_CLASSES, False),
     ],
 )
 @pytest.mark.parametrize("average", ['micro', 'macro', 'weighted', None])
diff --git a/tests/metrics/classification/test_precision_recall.py b/tests/metrics/classification/test_precision_recall.py
index 967bc60e28..6a399ba7f6 100644
--- a/tests/metrics/classification/test_precision_recall.py
+++ b/tests/metrics/classification/test_precision_recall.py
@@ -85,24 +85,16 @@ def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, av
     [
         (_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_prec_recall_binary_prob, 1, False),
         (_binary_inputs.preds, _binary_inputs.target, _sk_prec_recall_binary, 1, False),
-        (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_prec_recall_multilabel_prob, NUM_CLASSES, True),
+        (_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target,
+         _sk_prec_recall_multilabel_prob, NUM_CLASSES, True),
         (_multilabel_inputs.preds, _multilabel_inputs.target, _sk_prec_recall_multilabel, NUM_CLASSES, True),
-        (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_prec_recall_multiclass_prob, NUM_CLASSES, False),
+        (_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target,
+         _sk_prec_recall_multiclass_prob, NUM_CLASSES, False),
         (_multiclass_inputs.preds, _multiclass_inputs.target, _sk_prec_recall_multiclass, NUM_CLASSES, False),
-        (
-            _multidim_multiclass_prob_inputs.preds,
-            _multidim_multiclass_prob_inputs.target,
-            _sk_prec_recall_multidim_multiclass_prob,
-            NUM_CLASSES,
-            False,
-        ),
-        (
-            _multidim_multiclass_inputs.preds,
-            _multidim_multiclass_inputs.target,
-            _sk_prec_recall_multidim_multiclass,
-            NUM_CLASSES,
-            False,
-        ),
+        (_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target,
+         _sk_prec_recall_multidim_multiclass_prob, NUM_CLASSES, False),
+        (_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target,
+         _sk_prec_recall_multidim_multiclass, NUM_CLASSES, False),
     ],
 )
 @pytest.mark.parametrize(
diff --git a/tests/metrics/functional/test_classification.py b/tests/metrics/functional/test_classification.py
index a6fbe9e849..4a2d690ec0 100644
--- a/tests/metrics/functional/test_classification.py
+++ b/tests/metrics/functional/test_classification.py
@@ -153,7 +153,8 @@ def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expect
     pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]),
                  'elementwise_mean', torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8))
 ])
-def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
+def test_stat_scores_multiclass(pred, target, reduction,
+                                expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
     tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction)
 
     assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
index 78e2ba2944..a77b4eb451 100644
--- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
+++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py
@@ -520,7 +520,8 @@ def test_log_works_in_train_callback(tmpdir):
         def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx,
                          on_steps=[], on_epochs=[], prob_bars=[]):
             self.funcs_called_count[func_name] += 1
-            for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))):
+            iterate = list(itertools.product(*[on_steps, on_epochs, prob_bars]))
+            for idx, (on_step, on_epoch, prog_bar) in enumerate(iterate):
                 # run logging
                 custom_func_name = f"{func_idx}_{idx}_{func_name}"
                 pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step,
diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py
index 319c5b2bb1..a7cf2593d8 100644
--- a/tests/trainer/optimization/test_manual_optimization.py
+++ b/tests/trainer/optimization/test_manual_optimization.py
@@ -942,7 +942,8 @@ def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, m
 @patch("torch.optim.Adam.step")
 @patch("torch.optim.SGD.step")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
 def test_step_with_optimizer_closure_with_different_frequencies_ddp(mock_sgd_step, mock_adam_step, tmpdir):
     """
     Tests that `step` works with optimizer_closure and different accumulated_gradient frequency
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 93690b606d..e5a6bbb1a5 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -316,9 +316,11 @@ def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_bat
             self.on_train_batch_start_end_dict = self.state_dict()
             for key in self.on_train_batch_start_end_dict.keys():
                 if (batch_idx + 1) == self.trainer.num_training_batches:
-                    assert torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
+                    assert torch.equal(self.on_train_batch_start_state_dict[key],
+                                       self.on_train_batch_start_end_dict[key])
                 else:
-                    assert not torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
+                    assert not torch.equal(self.on_train_batch_start_state_dict[key],
+                                           self.on_train_batch_start_end_dict[key])
 
     model = CurrentModel()
diff --git a/tests/utilities/test_xla_device_utils.py b/tests/utilities/test_xla_device_utils.py
index 0e9800c29d..a16cda958c 100644
--- a/tests/utilities/test_xla_device_utils.py
+++ b/tests/utilities/test_xla_device_utils.py
@@ -23,7 +23,8 @@
 if _XLA_AVAILABLE:
     import torch_xla.core.xla_model as xm
 
-# lets hope that in or env we have installed XLA only for TPU devices, otherwise, it is testing in the cycle "if I am true test that I am true :D"
+# let's hope that in our env we have installed XLA only for TPU devices, otherwise,
+# it is testing in the cycle "if I am true test that I am true :D"
 @pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
 def test_tpu_device_absence():
     """Check tpu_device_exists returns None when torch_xla is not available"""
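The final hunk's comment points at the convention the whole test file relies on: `torch_xla` is assumed to be installed only where XLA/TPU support actually exists, so its importability doubles as the availability probe. A minimal, self-contained sketch of that guarded-import plus `skipif` pattern (the local flag and test names are illustrative, not Lightning's actual utilities):

import pytest

try:
    import torch_xla.core.xla_model as xm  # importable only where XLA support is installed
    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


@pytest.mark.skipif(XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_absence_is_detected():
    # mirrors the test above: when torch_xla cannot be imported, the probe must report unavailability
    assert not XLA_AVAILABLE


@pytest.mark.skipif(not XLA_AVAILABLE, reason="test requires torch_xla")
def test_xla_device_is_reachable():
    # only meaningful where the import succeeded; xla_device() is torch_xla's standard entry point
    assert xm.xla_device() is not None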