diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index 2cc6098778..84d53b5add 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -11,21 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, Iterable, Optional, Union import torch from torch.optim import Optimizer from torch.utils.data import DataLoader from pytorch_lightning.core import LightningModule -from pytorch_lightning.plugins.precision import ( - ApexMixedPrecisionPlugin, - MixedPrecisionPlugin, - NativeMixedPrecisionPlugin, - PrecisionPlugin, -) +from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin from pytorch_lightning.plugins.training_type import TrainingTypePlugin -from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin from pytorch_lightning.utilities.apply_func import move_data_to_device from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available from pytorch_lightning.utilities.enums import AMPType, LightningEnum @@ -64,7 +58,7 @@ class Accelerator(object): self.lr_schedulers = None self.optimizer_frequencies = None - def setup(self, trainer: "Trainer", model: LightningModule) -> None: + def setup(self, trainer, model: LightningModule) -> None: """ Connects the plugins to the training process, creates optimizers @@ -76,13 +70,13 @@ class Accelerator(object): self.setup_optimizers(trainer) self.connect_precision_plugin(self.precision_plugin) - def start_training(self, trainer: 'Trainer'): + def start_training(self, trainer): self.training_type_plugin.start_training(trainer) - def start_testing(self, trainer: 'Trainer'): + def start_testing(self, trainer): self.training_type_plugin.start_testing(trainer) - def start_predicting(self, trainer: 'Trainer'): + def start_predicting(self, trainer): self.training_type_plugin.start_predicting(trainer) def pre_dispatch(self) -> None: @@ -310,7 +304,7 @@ class Accelerator(object): """Hook to do something at the end of the training""" pass - def setup_optimizers(self, trainer: "Trainer"): + def setup_optimizers(self, trainer): """creates optimizers and schedulers Args: diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py index e495d9ffad..19033a2099 100644 --- a/pytorch_lightning/plugins/base_plugin.py +++ b/pytorch_lightning/plugins/base_plugin.py @@ -13,9 +13,8 @@ # limitations under the License. import contextlib from abc import ABC, abstractmethod -from typing import Any, Callable, Generator, Optional, overload, Sequence, Tuple +from typing import Generator, Optional, Sequence, Tuple -import torch from torch.nn import Module diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index 2216d3ae46..34879e514a 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -24,7 +24,8 @@ from pytorch_lightning.plugins.base_plugin import Plugin class PrecisionPlugin(Plugin): """ Plugin handling the precision-specific parts of the training. - The static classattributes EPSILON and precision must be overwritten in child-classes and their default values reflect fp32 training + The static classattributes EPSILON and precision must be overwritten in child-classes and their + default values reflect fp32 training. """ EPSILON = 1e-6 precision = 32 diff --git a/pytorch_lightning/plugins/precision/tpu_bfloat.py b/pytorch_lightning/plugins/precision/tpu_bfloat.py index c911bf6918..7f4916dd26 100644 --- a/pytorch_lightning/plugins/precision/tpu_bfloat.py +++ b/pytorch_lightning/plugins/precision/tpu_bfloat.py @@ -25,4 +25,4 @@ class TPUHalfPrecisionPlugin(PrecisionPlugin): def connect(self, model: torch.nn.Module, optimizers, lr_schedulers): os.environ["XLA_USE_BF16"] = str(1) - return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers) \ No newline at end of file + return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers) diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py index b73c6351de..30723d67da 100644 --- a/pytorch_lightning/plugins/training_type/__init__.py +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -1,15 +1,15 @@ -from pytorch_lightning.plugins.training_type.ddp import DDPPlugin -from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin -from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin -from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin -from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin -from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin -from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin -from pytorch_lightning.plugins.training_type.rpc import RPCPlugin -from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin -from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin -from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin -from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin -from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin -from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin -from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin +from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401 +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index ec0ff1a308..6a4e948e89 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -27,7 +27,6 @@ from pytorch_lightning import _logger as log from pytorch_lightning.distributed import LightningDistributed from pytorch_lightning.overrides import LightningDistributedModule from pytorch_lightning.overrides.distributed import prepare_for_backward -from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn @@ -120,7 +119,7 @@ class DDPPlugin(ParallelPlugin): command = sys.argv try: full_path = path_lib(command[0]) - except Exception as e: + except Exception: full_path = os.path.abspath(command[0]) command[0] = full_path diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index d9949b97f6..b6545c9b40 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -263,7 +263,7 @@ class DeepSpeedPlugin(DDPPlugin): distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank) return distributed_sampler_kwargs - def init_optimizers(self, trainer: "Trainer", model: LightningModule) -> Tuple[List, List, List]: + def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]: # Skip initializing optimizers here as DeepSpeed handles optimizers via config. # User may have specified config options instead in configure_optimizers, but this is handled # via `_initialize_deepspeed_train` diff --git a/pytorch_lightning/plugins/training_type/horovod.py b/pytorch_lightning/plugins/training_type/horovod.py index f9cfa43fe9..351d945675 100644 --- a/pytorch_lightning/plugins/training_type/horovod.py +++ b/pytorch_lightning/plugins/training_type/horovod.py @@ -101,14 +101,14 @@ class HorovodPlugin(ParallelPlugin): hvd.join() def start_testing(self, trainer): - with ExitStack() as stack: + with ExitStack(): self._results = trainer.run_test() # Make sure all workers have finished training before returning to the user hvd.join() def start_predicting(self, trainer): - with ExitStack() as stack: + with ExitStack(): # set up training routine self._results = trainer.run_predict() diff --git a/pytorch_lightning/plugins/training_type/parallel.py b/pytorch_lightning/plugins/training_type/parallel.py index e1bc52a513..f3c825fe9c 100644 --- a/pytorch_lightning/plugins/training_type/parallel.py +++ b/pytorch_lightning/plugins/training_type/parallel.py @@ -24,7 +24,6 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.base import unwrap_lightning_module from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin -from pytorch_lightning.utilities import rank_zero_info from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp diff --git a/pytorch_lightning/plugins/training_type/rpc.py b/pytorch_lightning/plugins/training_type/rpc.py index 3e86eec778..3c016f3cb8 100644 --- a/pytorch_lightning/plugins/training_type/rpc.py +++ b/pytorch_lightning/plugins/training_type/rpc.py @@ -13,11 +13,10 @@ # limitations under the License. import os from contextlib import suppress -from typing import List, Optional, Sequence +from typing import List, Optional import torch -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.training_type.ddp import DDPPlugin from pytorch_lightning.utilities import _RPC_AVAILABLE diff --git a/pytorch_lightning/plugins/training_type/rpc_sequential.py b/pytorch_lightning/plugins/training_type/rpc_sequential.py index fc707afb3e..3878aa9db3 100644 --- a/pytorch_lightning/plugins/training_type/rpc_sequential.py +++ b/pytorch_lightning/plugins/training_type/rpc_sequential.py @@ -329,11 +329,11 @@ class RPCSequentialPlugin(RPCPlugin): if self.main_rpc_process: super().post_training_step() - def start_training(self, trainer: 'Trainer') -> None: + def start_training(self, trainer) -> None: if self.main_rpc_process: super().start_training(trainer) - def start_testing(self, trainer: 'Trainer') -> None: + def start_testing(self, trainer) -> None: if self.main_rpc_process: super().start_testing(trainer) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index c38690473b..a8d497cd11 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -1,7 +1,6 @@ from typing import Optional from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.core.optimizer import is_lightning_optimizer from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only diff --git a/pytorch_lightning/plugins/training_type/single_device.py b/pytorch_lightning/plugins/training_type/single_device.py index 1e3fe6b851..4b1d24301b 100644 --- a/pytorch_lightning/plugins/training_type/single_device.py +++ b/pytorch_lightning/plugins/training_type/single_device.py @@ -1,7 +1,6 @@ from typing import Any, Union import torch -from torch._C import device from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin diff --git a/pytorch_lightning/plugins/training_type/single_tpu.py b/pytorch_lightning/plugins/training_type/single_tpu.py index 40fc9fba3a..3ddfd98128 100644 --- a/pytorch_lightning/plugins/training_type/single_tpu.py +++ b/pytorch_lightning/plugins/training_type/single_tpu.py @@ -1,4 +1,3 @@ -import io import os from typing import Optional, Union @@ -11,7 +10,6 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn from pytorch_lightning.utilities.apply_func import move_data_to_device if _TPU_AVAILABLE: - import torch_xla import torch_xla.core.xla_model as xm @@ -68,4 +66,4 @@ class SingleTPUPlugin(SingleDevicePlugin): @property def is_distributed(self): - return False \ No newline at end of file + return False diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index cd8a132c07..692a4426a6 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -1,7 +1,7 @@ import io import os import re -from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Union import torch import torch.multiprocessing as mp @@ -206,7 +206,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin): # restore main state with best weights best_path = self.mp_queue.get() last_path = self.mp_queue.get() - results = self.mp_queue.get() + self._results = self.mp_queue.get() # transfer back the best path to the trainer if self.lightning_module.trainer.checkpoint_callback is not None: diff --git a/setup.cfg b/setup.cfg index f622581b5a..be8f7cd50f 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,11 +67,6 @@ exclude = *.egg build temp - # TODO: temporary until accelerator refactor finished - pytorch_lightning/accelerators/accelerator.py - pytorch_lightning/plugins/training_type - pytorch_lightning/plugins/precision - pytorch_lightning/plugins/base_plugin.py select = E,W,F doctests = True diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index daea22968b..03da7c81b2 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -79,7 +79,7 @@ def test_if_test_works_after_train(tmpdir): model = BoringModel() trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True) trainer.fit(model) - assert trainer.test(model) == 1 + assert len(trainer.test(model)) == 1 @pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") @@ -119,4 +119,4 @@ def test_if_weights_tied(tmpdir, capsys=None): assert result assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list)) - assert trainer.test(model) == 1 + assert len(trainer.test(model)) == 1