fix flake8 for new plugins (#5951)

* flake8

* fix cyclic import

* isort
Adrian Wälchli 2021-02-18 19:28:23 +01:00 committed by GitHub
parent 6cc1a06078
commit fc9bb53e13
17 changed files with 38 additions and 56 deletions

View File

@@ -11,21 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Iterable, Optional, Union
import torch
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import (
ApexMixedPrecisionPlugin,
MixedPrecisionPlugin,
NativeMixedPrecisionPlugin,
PrecisionPlugin,
)
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
@@ -64,7 +58,7 @@ class Accelerator(object):
self.lr_schedulers = None
self.optimizer_frequencies = None
def setup(self, trainer: "Trainer", model: LightningModule) -> None:
def setup(self, trainer, model: LightningModule) -> None:
"""
Connects the plugins to the training process, creates optimizers
@@ -76,13 +70,13 @@ class Accelerator(object):
self.setup_optimizers(trainer)
self.connect_precision_plugin(self.precision_plugin)
def start_training(self, trainer: 'Trainer'):
def start_training(self, trainer):
self.training_type_plugin.start_training(trainer)
def start_testing(self, trainer: 'Trainer'):
def start_testing(self, trainer):
self.training_type_plugin.start_testing(trainer)
def start_predicting(self, trainer: 'Trainer'):
def start_predicting(self, trainer):
self.training_type_plugin.start_predicting(trainer)
def pre_dispatch(self) -> None:
@@ -310,7 +304,7 @@ class Accelerator(object):
"""Hook to do something at the end of the training"""
pass
def setup_optimizers(self, trainer: "Trainer"):
def setup_optimizers(self, trainer):
"""creates optimizers and schedulers
Args:
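
This file drops the quoted `"Trainer"` annotations along with the `TYPE_CHECKING` import that supported them. For readers unfamiliar with the pattern, `TYPE_CHECKING` is the standard way to keep such an annotation without a runtime import (and hence without risking an import cycle); the sketch below is illustrative only, with hypothetical module names rather than Lightning's actual layout.

```python
# Minimal sketch of the TYPE_CHECKING pattern; `my_pkg.trainer.Trainer` is a
# hypothetical module/class, not part of this repository.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so it cannot
    # participate in a circular import.
    from my_pkg.trainer import Trainer


class Accelerator:
    def setup(self, trainer: "Trainer", model) -> None:
        # The quoted forward reference keeps type information for tools like
        # mypy while the module never imports Trainer at runtime.
        self.trainer = trainer
        self.model = model
```

Dropping the annotation entirely, as the diff does, achieves the same decoupling with less machinery, at the cost of the type information.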

View File

@@ -13,9 +13,8 @@
# limitations under the License.
import contextlib
from abc import ABC, abstractmethod
from typing import Any, Callable, Generator, Optional, overload, Sequence, Tuple
from typing import Generator, Optional, Sequence, Tuple
import torch
from torch.nn import Module

View File

@@ -24,7 +24,8 @@ from pytorch_lightning.plugins.base_plugin import Plugin
class PrecisionPlugin(Plugin):
""" Plugin handling the precision-specific parts of the training.
The static classattributes EPSILON and precision must be overwritten in child-classes and their default values reflect fp32 training
The static classattributes EPSILON and precision must be overwritten in child-classes and their
default values reflect fp32 training.
"""
EPSILON = 1e-6
precision = 32
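
As the reflowed docstring says, concrete precision plugins are expected to override `EPSILON` and `precision`. A minimal illustrative subclass (not part of this diff; the class name and values are made up):

```python
from pytorch_lightning.plugins.precision import PrecisionPlugin


class MyReducedPrecisionPlugin(PrecisionPlugin):
    # Override the fp32 defaults documented above; these values are examples.
    EPSILON = 1e-4
    precision = 16
```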

View File

@@ -25,4 +25,4 @@ class TPUHalfPrecisionPlugin(PrecisionPlugin):
def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
os.environ["XLA_USE_BF16"] = str(1)
return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)

View File

@@ -1,15 +1,15 @@
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin
from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin
from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401
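
Every re-export in this `__init__` gains a `# noqa: F401` marker because flake8 (via pyflakes) reports F401, "imported but unused", for names that are imported solely so downstream code can import them from the package root. A small self-contained illustration of the rule and of the `__all__` alternative, using stdlib imports as stand-ins:

```python
# Hypothetical package __init__.py; the stdlib imports stand in for re-exports.
from collections import OrderedDict  # noqa: F401  (re-export only, silences F401)
from functools import lru_cache      # noqa: F401

# Listing re-exported names in __all__ also documents the public API and is
# another way to keep pyflakes from flagging them as unused.
__all__ = ["OrderedDict", "lru_cache"]
```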

View File

@@ -27,7 +27,6 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.overrides.distributed import prepare_for_backward
from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
@@ -120,7 +119,7 @@ class DDPPlugin(ParallelPlugin):
command = sys.argv
try:
full_path = path_lib(command[0])
except Exception as e:
except Exception:
full_path = os.path.abspath(command[0])
command[0] = full_path

View File

@@ -263,7 +263,7 @@ class DeepSpeedPlugin(DDPPlugin):
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
return distributed_sampler_kwargs
def init_optimizers(self, trainer: "Trainer", model: LightningModule) -> Tuple[List, List, List]:
def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]:
# Skip initializing optimizers here as DeepSpeed handles optimizers via config.
# User may have specified config options instead in configure_optimizers, but this is handled
# via `_initialize_deepspeed_train`
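
The comment explains the design: DeepSpeed builds optimizers and schedulers from its own config, so this plugin opts out of Lightning's usual optimizer setup. The method body is not shown in this hunk; the sketch below is a purely illustrative guess at what such an override looks like (returning empty lists is an assumption, not necessarily the real implementation).

```python
from typing import List, Tuple


class DeepSpeedLikePlugin:  # illustrative stand-in, not the real plugin
    def init_optimizers(self, trainer, model) -> Tuple[List, List, List]:
        # Hand back empty optimizer/scheduler/frequency lists so the trainer
        # does not build optimizers itself; the DeepSpeed engine is assumed to
        # construct them later from its config.
        return [], [], []
```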

View File

@@ -101,14 +101,14 @@ class HorovodPlugin(ParallelPlugin):
hvd.join()
def start_testing(self, trainer):
with ExitStack() as stack:
with ExitStack():
self._results = trainer.run_test()
# Make sure all workers have finished training before returning to the user
hvd.join()
def start_predicting(self, trainer):
with ExitStack() as stack:
with ExitStack():
# set up training routine
self._results = trainer.run_predict()
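
The `except Exception as e:` change in the DDP plugin above and the `as stack` removals here are the same kind of cleanup: a name is bound and then never read. For the `except` case flake8 reports F841 ("local variable is assigned to but never used"), and dropping the binding is the usual fix. A minimal illustration with hypothetical function names:

```python
from contextlib import ExitStack


def read_or_default(path: str) -> str:
    try:
        with open(path) as handle:
            return handle.read()
    except OSError:
        # Writing `except OSError as err:` here without using `err` would be
        # flagged as F841; omit the binding when the exception object is unused.
        return ""


def run_block() -> None:
    with ExitStack():
        # No `as stack` binding is taken because nothing is registered on it.
        pass
```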

View File

@@ -24,7 +24,6 @@ from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp

View File

@@ -13,11 +13,10 @@
# limitations under the License.
import os
from contextlib import suppress
from typing import List, Optional, Sequence
from typing import List, Optional
import torch
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.utilities import _RPC_AVAILABLE

View File

@@ -329,11 +329,11 @@ class RPCSequentialPlugin(RPCPlugin):
if self.main_rpc_process:
super().post_training_step()
def start_training(self, trainer: 'Trainer') -> None:
def start_training(self, trainer) -> None:
if self.main_rpc_process:
super().start_training(trainer)
def start_testing(self, trainer: 'Trainer') -> None:
def start_testing(self, trainer) -> None:
if self.main_rpc_process:
super().start_testing(trainer)

View File

@@ -1,7 +1,6 @@
from typing import Optional
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.optimizer import is_lightning_optimizer
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only

View File

@@ -1,7 +1,6 @@
from typing import Any, Union
import torch
from torch._C import device
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin

View File

@@ -1,4 +1,3 @@
import io
import os
from typing import Optional, Union
@@ -11,7 +10,6 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
from pytorch_lightning.utilities.apply_func import move_data_to_device
if _TPU_AVAILABLE:
import torch_xla
import torch_xla.core.xla_model as xm
@@ -68,4 +66,4 @@ class SingleTPUPlugin(SingleDevicePlugin):
@property
def is_distributed(self):
return False
return False

View File

@@ -1,7 +1,7 @@
import io
import os
import re
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Union
import torch
import torch.multiprocessing as mp
@@ -206,7 +206,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
# restore main state with best weights
best_path = self.mp_queue.get()
last_path = self.mp_queue.get()
results = self.mp_queue.get()
self._results = self.mp_queue.get()
# transfer back the best path to the trainer
if self.lightning_module.trainer.checkpoint_callback is not None:

View File

@@ -67,11 +67,6 @@ exclude =
*.egg
build
temp
# TODO: temporary until accelerator refactor finished
pytorch_lightning/accelerators/accelerator.py
pytorch_lightning/plugins/training_type
pytorch_lightning/plugins/precision
pytorch_lightning/plugins/base_plugin.py
select = E,W,F
doctests = True

View File

@@ -79,7 +79,7 @@ def test_if_test_works_after_train(tmpdir):
model = BoringModel()
trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)
assert trainer.test(model) == 1
assert len(trainer.test(model)) == 1
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@@ -119,4 +119,4 @@ def test_if_weights_tied(tmpdir, capsys=None):
assert result
assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
assert trainer.test(model) == 1
assert len(trainer.test(model)) == 1
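
Both assertions change because `Trainer.test` returns a list with one metrics dict per test dataloader, so checking the list length is the meaningful test. A hedged sketch of the same check outside the TPU-only context (the `BoringModel` import path is assumed from the test suite):

```python
from pytorch_lightning import Trainer
from tests.helpers.boring_model import BoringModel  # import path assumed

model = BoringModel()
trainer = Trainer(fast_dev_run=True)
trainer.fit(model)

results = trainer.test(model)
assert isinstance(results, list)
assert len(results) == 1  # one entry because BoringModel has a single test dataloader
print(results[0])         # the metrics dict logged during the test run
```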