refactor - check E501 (#5200)

Jirka Borovec 2020-12-21 09:53:09 +01:00 committed by GitHub
parent 6d2c564bc6
commit 35fd6e93c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 85 additions and 69 deletions

View File

@@ -84,7 +84,8 @@ class ExternalSourcePipeline(Pipeline):
class DALIClassificationLoader(DALIClassificationIterator):
"""
This class extends DALI's original DALIClassificationIterator with the __len__() function so that we can call len() on it
This class extends DALI's original `DALIClassificationIterator` with the `__len__()` function
so that we can call `len()` on it
"""
def __init__(
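
As context for the docstring above (an editor's sketch, not part of this commit), such a wrapper only needs a `__len__` that reports the number of batches; the `_size` and `batch_size` attribute names below are assumptions about DALI's iterator, not taken from the diff:

# Hedged sketch, not from this commit: minimal __len__ for DALI's iterator.
# _size and batch_size are assumed attribute names on DALIClassificationIterator.
import math
from nvidia.dali.plugin.pytorch import DALIClassificationIterator

class DALIClassificationLoader(DALIClassificationIterator):
    """Extends DALIClassificationIterator so that len() can be called on it."""

    def __len__(self) -> int:
        # number of batches = ceil(number of samples / batch size)
        return math.ceil(self._size / self.batch_size)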

View File

@@ -1188,7 +1188,8 @@ class LightningModule(
By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
once per optimizer.
.. tip:: With `Trainer(enable_pl_optimizer=True)`, you can use `optimizer.step()` directly and it will handle zero_grad, accumulated gradients, AMP, TPU and more automatically for you.
.. tip:: With `Trainer(enable_pl_optimizer=True)`, you can use `optimizer.step()` directly
and it will handle zero_grad, accumulated gradients, AMP, TPU and more automatically for you.
Warning:
If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
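
A rough sketch of such an override that keeps the closure (an editor's illustration, not part of this commit; the real hook signature carries a few more flags than shown):

# Hedged sketch, not from this commit: override optimizer_step and forward the closure.
# With Trainer(enable_pl_optimizer=True) the `optimizer` here is a LightningOptimizer,
# so step() also takes care of zero_grad, gradient accumulation, AMP, TPU, etc.
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx,
                   optimizer_closure=None, **kwargs):
    optimizer.step(closure=optimizer_closure)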

View File

@@ -51,7 +51,8 @@ class LightningOptimizer:
# For Horovod
if hasattr(optimizer, "skip_synchronize"):
self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__.__bases__[0]), {})
self.__class__ = type("Lightning" + optimizer.__class__.__name__,
(self.__class__, optimizer.__class__.__bases__[0]), {})
self.skip_synchronize = optimizer.skip_synchronize
self.synchronize = optimizer.synchronize
else:
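
The three-argument `type()` call above builds a subclass at runtime that inherits from both the wrapper and the wrapped optimizer's base class. A self-contained toy version of the same trick (class names here are placeholders, not Lightning or Horovod types):

# Toy illustration of dynamic subclassing via type(name, bases, dict).
class Wrapper:
    def step(self):
        return "wrapper step"

class FakeSGD:
    def zero_grad(self):
        return "zero grad"

LightningFakeSGD = type("Lightning" + FakeSGD.__name__, (Wrapper, FakeSGD), {})
opt = LightningFakeSGD()
assert opt.step() == "wrapper step"      # resolved on Wrapper
assert opt.zero_grad() == "zero grad"    # resolved on FakeSGD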

View File

@@ -174,7 +174,8 @@ class ModelIO(object):
cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key))
# 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded,
checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
# 4. Update cls_kwargs_new with cls_kwargs_old, such that new has higher priority
args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
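
As a purely hypothetical illustration of step 3 (not the library's actual helper, which may behave differently), restoring the recorded hparams container type could look roughly like this:

# Hypothetical sketch only; the real _convert_loaded_hparams may differ.
from argparse import Namespace

def _convert_loaded_hparams(hparams: dict, hparams_type=None):
    # if the checkpoint recorded that hparams were an argparse.Namespace, restore it
    if hparams_type is not None and "namespace" in str(hparams_type).lower():
        return Namespace(**hparams)
    return hparams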

View File

@@ -127,7 +127,8 @@ def _check_num_classes_binary(num_classes: int, is_multiclass: bool):
if num_classes == 1 and is_multiclass:
raise ValueError(
"You have binary data and have set `is_multiclass=True`, but `num_classes` is 1."
" Either set `is_multiclass=None`(default) or set `num_classes=2` to transform binary data to multi-class format."
" Either set `is_multiclass=None`(default) or set `num_classes=2`"
" to transform binary data to multi-class format."
)

View File

@@ -73,11 +73,12 @@ class PrecisionRecallCurve(Metric):
>>> target = torch.tensor([0, 1, 3, 2])
>>> pr_curve = PrecisionRecallCurve(num_classes=5)
>>> precision, recall, thresholds = pr_curve(pred, target)
>>> precision
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> precision # doctest: +NORMALIZE_WHITESPACE
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
>>> thresholds
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
"""

View File

@@ -18,7 +18,10 @@ import torch
from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
from pytorch_lightning.metrics.functional.f_beta import fbeta as __fb, f1 as __f1
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve, precision_recall_curve as __prc
from pytorch_lightning.metrics.functional.precision_recall_curve import (
_binary_clf_curve,
precision_recall_curve as __prc
)
from pytorch_lightning.metrics.functional.roc import roc as __roc
from pytorch_lightning.metrics.utils import (
to_categorical as __tc,
@@ -821,7 +824,8 @@ def precision_recall_curve(
"""
Computes precision-recall pairs for different thresholds.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
.. warning :: Deprecated in favor of
:func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
"""
rank_zero_warn(
"This `precision_recall_curve` was deprecated in v1.1.0 in favor of"
@@ -841,7 +845,8 @@ def multiclass_precision_recall_curve(
"""
Computes precision-recall pairs for different thresholds given multiclass scores.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
.. warning :: Deprecated in favor of
:func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
"""
rank_zero_warn(
"This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of"
@@ -863,7 +868,8 @@ def average_precision(
"""
Compute average precision from prediction scores.
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
.. warning :: Deprecated in favor of
:func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
"""
rank_zero_warn(
"This `average_precision` was deprecated in v1.1.0 in favor of"

View File

@@ -208,8 +208,9 @@ def precision_recall_curve(
... [0.05, 0.05, 0.05, 0.75, 0.05]])
>>> target = torch.tensor([0, 1, 3, 2])
>>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5)
>>> precision
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> precision # doctest: +NORMALIZE_WHITESPACE
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
>>> recall
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
>>> thresholds

View File

@@ -172,7 +172,8 @@ def _load_long_description(path_dir: str) -> str:
# readthedocs badge
text = text.replace('badge/?version=stable', f'badge/?version={__version__}')
text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{__version__}')
text = text.replace('pytorch-lightning.readthedocs.io/en/stable/',
f'pytorch-lightning.readthedocs.io/en/{__version__}')
# codecov badge
text = text.replace('/branch/master/graph/badge.svg', f'/release/{__version__}/graph/badge.svg')
# replace github badges for release ones

View File

@@ -282,7 +282,9 @@ class CheckpointConnector:
checkpoint['lr_schedulers'] = lr_schedulers
# dump amp scaling
if self.trainer.amp_backend == AMPType.NATIVE and not self.trainer.use_tpu and self.trainer.scaler is not None:
if (self.trainer.amp_backend == AMPType.NATIVE
and not self.trainer.use_tpu
and self.trainer.scaler is not None):
checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
elif self.trainer.amp_backend == AMPType.APEX:
checkpoint['amp_scaling_state'] = amp.state_dict()
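
To illustrate what the native-AMP branch dumps (an editor's sketch, not part of this diff), a `torch.cuda.amp.GradScaler` round-trips its state through plain dicts:

# Sketch, not from this commit: native AMP scaler state is saved/restored as a dict.
import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
checkpoint = {'native_amp_scaling_state': scaler.state_dict()}
# ... later, when resuming ...
restored = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
restored.load_state_dict(checkpoint['native_amp_scaling_state'])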

View File

@@ -168,8 +168,9 @@ class LoggerConnector:
metrics (dict): Metric values
grad_norm_dic (dict): Gradient norms
step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
log_train_step_metrics (bool): Used to track if log_metrics function is being called during training steps.
In training steps, we will log metrics on step: total_nb_idx (for accumulated gradients) and global_step for the rest.
log_train_step_metrics (bool): Used to track if the `log_metrics` function is being called during training
steps. In training steps, we will log metrics on step: `total_nb_idx` (for accumulated gradients)
and `global_step` for the rest.
"""
# add gpu memory
if self.trainer.on_gpu and self.trainer.log_gpu_memory:

View File

@@ -125,7 +125,8 @@ class TrainerOptimizersMixin(ABC):
if monitor is None:
raise MisconfigurationException(
'`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
' For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
' For example:'
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
)
lr_schedulers.append(
{**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor}
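
The error message spells out the expected return shape; a minimal `configure_optimizers` matching it would be (the monitored metric name `val_loss` is a placeholder):

# Minimal example of the dict shape the error message asks for.
import torch

def configure_optimizers(self):
    optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')
    # 'val_loss' is a placeholder for whatever metric the module logs
    return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}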

View File

@@ -65,13 +65,12 @@ verbose = 2
# https://pep8.readthedocs.io/en/latest/intro.html#error-codes
format = pylint
ignore =
E731
W504
F401
E203 # E203 - whitespace before ':'. Opposite convention enforced by black
E231 # E231: missing whitespace after ',', ';', or ':'; for black
E501 # E501 - line too long. Handled by black, we have longer lines
W503 # W503 - line break before binary operator, need for black
E731 # do not assign a lambda expression, use a def
W504 # line break occurred after a binary operator
F401 # module imported but unused
E203 # whitespace before ':'. Opposite convention enforced by black
E231 # missing whitespace after ',', ';', or ':'; for black
W503 # line break before binary operator, need for black
# setup.cfg or tox.ini
[check-manifest]

View File

@@ -73,7 +73,8 @@ class Discriminator(nn.Module):
class BasicGAN(LightningModule):
"""Implements a basic GAN for the purpose of illustrating multiple optimizers."""
def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001, b1: float = 0.5, b2: float = 0.999, **kwargs):
def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001,
b1: float = 0.5, b2: float = 0.999, **kwargs):
super().__init__()
self.hidden_dim = hidden_dim
self.learning_rate = learning_rate

View File

@@ -213,11 +213,14 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi
IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being triggered,
THEN the trainer should continue until reaching `trainer.global_step` == `min_steps`, and stop.
IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is being triggered,
THEN the trainer should continue until reaching `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step`
when `early_stopping` is being triggered,
THEN the trainer should continue until reaching
`trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
This test validates this expected behaviour.
IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when `early_stopping` is being triggered,
IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step`
when `early_stopping` is being triggered,
THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.
Caveat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
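
Worked through with made-up numbers (not taken from the test), the rule described above reads:

# Made-up numbers illustrating the stopping rule described in the docstring above.
min_steps = 100
min_epochs = 3
len_train_dataloader = 64          # batches per epoch
step_at_early_stop_trigger = 40    # trainer.global_step when early stopping fires

expected_final_step = max(step_at_early_stop_trigger,
                          min_steps,
                          min_epochs * len_train_dataloader)
assert expected_final_step == 192  # min_epochs * len(train_dataloader) wins here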

View File

@@ -269,7 +269,12 @@ def test_model_checkpoint_file_extension(tmpdir):
"""
model = LogInTwoMethods()
model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True)
model_checkpoint = ModelCheckpointExtensionTest(
monitor='early_stop_on',
dirpath=tmpdir,
save_top_k=1,
save_last=True,
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[model_checkpoint],

View File

@@ -87,7 +87,8 @@ def test_result_reduce_ddp(result_cls):
7, True, 0, id='write_dict_predictions'
),
pytest.param(
0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires single-GPU machine")
0, True, 1, id='full_loop_single_gpu',
marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires single-GPU machine")
)
]
)

View File

@@ -86,23 +86,14 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0):
(_binary_inputs.preds, _binary_inputs.target, _sk_fbeta_binary, 1, False),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True),
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
(_multilabel_inputs_no_match.preds, _multilabel_inputs_no_match.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
(_multilabel_inputs_no_match.preds, _multilabel_inputs_no_match.target,
_sk_fbeta_multilabel, NUM_CLASSES, True),
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False),
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_fbeta_multiclass, NUM_CLASSES, False),
(
_multidim_multiclass_prob_inputs.preds,
_multidim_multiclass_prob_inputs.target,
_sk_fbeta_multidim_multiclass_prob,
NUM_CLASSES,
False,
),
(
_multidim_multiclass_inputs.preds,
_multidim_multiclass_inputs.target,
_sk_fbeta_multidim_multiclass,
NUM_CLASSES,
False,
),
(_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target,
_sk_fbeta_multidim_multiclass_prob, NUM_CLASSES, False),
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target,
_sk_fbeta_multidim_multiclass, NUM_CLASSES, False),
],
)
@pytest.mark.parametrize("average", ['micro', 'macro', 'weighted', None])

View File

@@ -85,24 +85,16 @@ def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, av
[
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_prec_recall_binary_prob, 1, False),
(_binary_inputs.preds, _binary_inputs.target, _sk_prec_recall_binary, 1, False),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_prec_recall_multilabel_prob, NUM_CLASSES, True),
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target,
_sk_prec_recall_multilabel_prob, NUM_CLASSES, True),
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_prec_recall_multilabel, NUM_CLASSES, True),
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_prec_recall_multiclass_prob, NUM_CLASSES, False),
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target,
_sk_prec_recall_multiclass_prob, NUM_CLASSES, False),
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_prec_recall_multiclass, NUM_CLASSES, False),
(
_multidim_multiclass_prob_inputs.preds,
_multidim_multiclass_prob_inputs.target,
_sk_prec_recall_multidim_multiclass_prob,
NUM_CLASSES,
False,
),
(
_multidim_multiclass_inputs.preds,
_multidim_multiclass_inputs.target,
_sk_prec_recall_multidim_multiclass,
NUM_CLASSES,
False,
),
(_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target,
_sk_prec_recall_multidim_multiclass_prob, NUM_CLASSES, False),
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target,
_sk_prec_recall_multidim_multiclass, NUM_CLASSES, False),
],
)
@pytest.mark.parametrize(

View File

@@ -153,7 +153,8 @@ def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expect
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean',
torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8))
])
def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
def test_stat_scores_multiclass(pred, target, reduction,
expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction)
assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)

View File

@@ -520,7 +520,8 @@ def test_log_works_in_train_callback(tmpdir):
def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx,
on_steps=[], on_epochs=[], prob_bars=[]):
self.funcs_called_count[func_name] += 1
for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))):
iterate = list(itertools.product(*[on_steps, on_epochs, prob_bars]))
for idx, (on_step, on_epoch, prog_bar) in enumerate(iterate):
# run logging
custom_func_name = f"{func_idx}_{idx}_{func_name}"
pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step,

View File

@@ -942,7 +942,8 @@ def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, m
@patch("torch.optim.Adam.step")
@patch("torch.optim.SGD.step")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
reason="test should be run outside of pytest")
def test_step_with_optimizer_closure_with_different_frequencies_ddp(mock_sgd_step, mock_adam_step, tmpdir):
"""
Tests that `step` works with optimizer_closure and different accumulated_gradient frequency

View File

@@ -316,9 +316,11 @@ def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_bat
self.on_train_batch_start_end_dict = self.state_dict()
for key in self.on_train_batch_start_end_dict.keys():
if (batch_idx + 1) == self.trainer.num_training_batches:
assert torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
assert torch.equal(self.on_train_batch_start_state_dict[key],
self.on_train_batch_start_end_dict[key])
else:
assert not torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
assert not torch.equal(self.on_train_batch_start_state_dict[key],
self.on_train_batch_start_end_dict[key])
model = CurrentModel()

View File

@@ -23,7 +23,8 @@ if _XLA_AVAILABLE:
import torch_xla.core.xla_model as xm
# let's hope that in our env we have installed XLA only for TPU devices, otherwise, it is testing in the cycle "if I am true, test that I am true :D"
# let's hope that in our env we have installed XLA only for TPU devices, otherwise,
# it is testing in the cycle "if I am true, test that I am true :D"
@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
def test_tpu_device_absence():
"""Check tpu_device_exists returns None when torch_xla is not available"""