fix flake8 for new plugins (#5951)
* flake8 * fix cyclic import * isort
This commit is contained in:
parent
6cc1a06078
commit
fc9bb53e13
|
@ -11,21 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, Iterable, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning.core import LightningModule
|
||||
from pytorch_lightning.plugins.precision import (
|
||||
ApexMixedPrecisionPlugin,
|
||||
MixedPrecisionPlugin,
|
||||
NativeMixedPrecisionPlugin,
|
||||
PrecisionPlugin,
|
||||
)
|
||||
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin, PrecisionPlugin
|
||||
from pytorch_lightning.plugins.training_type import TrainingTypePlugin
|
||||
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
|
||||
from pytorch_lightning.utilities.apply_func import move_data_to_device
|
||||
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
|
||||
from pytorch_lightning.utilities.enums import AMPType, LightningEnum
|
||||
|
@ -64,7 +58,7 @@ class Accelerator(object):
|
|||
self.lr_schedulers = None
|
||||
self.optimizer_frequencies = None
|
||||
|
||||
def setup(self, trainer: "Trainer", model: LightningModule) -> None:
|
||||
def setup(self, trainer, model: LightningModule) -> None:
|
||||
"""
|
||||
Connects the plugins to the training process, creates optimizers
|
||||
|
||||
|
@ -76,13 +70,13 @@ class Accelerator(object):
|
|||
self.setup_optimizers(trainer)
|
||||
self.connect_precision_plugin(self.precision_plugin)
|
||||
|
||||
def start_training(self, trainer: 'Trainer'):
|
||||
def start_training(self, trainer):
|
||||
self.training_type_plugin.start_training(trainer)
|
||||
|
||||
def start_testing(self, trainer: 'Trainer'):
|
||||
def start_testing(self, trainer):
|
||||
self.training_type_plugin.start_testing(trainer)
|
||||
|
||||
def start_predicting(self, trainer: 'Trainer'):
|
||||
def start_predicting(self, trainer):
|
||||
self.training_type_plugin.start_predicting(trainer)
|
||||
|
||||
def pre_dispatch(self) -> None:
|
||||
|
@ -310,7 +304,7 @@ class Accelerator(object):
|
|||
"""Hook to do something at the end of the training"""
|
||||
pass
|
||||
|
||||
def setup_optimizers(self, trainer: "Trainer"):
|
||||
def setup_optimizers(self, trainer):
|
||||
"""creates optimizers and schedulers
|
||||
|
||||
Args:
|
||||
|
|
|
@ -13,9 +13,8 @@
|
|||
# limitations under the License.
|
||||
import contextlib
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Generator, Optional, overload, Sequence, Tuple
|
||||
from typing import Generator, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
|
||||
|
|
|
@ -24,7 +24,8 @@ from pytorch_lightning.plugins.base_plugin import Plugin
|
|||
|
||||
class PrecisionPlugin(Plugin):
|
||||
""" Plugin handling the precision-specific parts of the training.
|
||||
The static classattributes EPSILON and precision must be overwritten in child-classes and their default values reflect fp32 training
|
||||
The static classattributes EPSILON and precision must be overwritten in child-classes and their
|
||||
default values reflect fp32 training.
|
||||
"""
|
||||
EPSILON = 1e-6
|
||||
precision = 32
|
||||
|
|
|
@ -25,4 +25,4 @@ class TPUHalfPrecisionPlugin(PrecisionPlugin):
|
|||
|
||||
def connect(self, model: torch.nn.Module, optimizers, lr_schedulers):
|
||||
os.environ["XLA_USE_BF16"] = str(1)
|
||||
return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
|
||||
return super().connect(model=model, optimizers=optimizers, lr_schedulers=lr_schedulers)
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
|
||||
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
|
||||
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
|
||||
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin
|
||||
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin
|
||||
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin
|
||||
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
|
||||
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin
|
||||
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin
|
||||
from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin
|
||||
from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin
|
||||
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin
|
||||
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin
|
||||
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin
|
||||
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
|
||||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.deepspeed import DeepSpeedPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.dp import DataParallelPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.horovod import HorovodPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.rpc import RPCPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.rpc_sequential import RPCSequentialPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.sharded import DDPShardedPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.sharded_spawn import DDPSpawnShardedPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.single_device import SingleDevicePlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.single_tpu import SingleTPUPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.tpu_spawn import TPUSpawnPlugin # noqa: F401
|
||||
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401
|
||||
|
|
|
@ -27,7 +27,6 @@ from pytorch_lightning import _logger as log
|
|||
from pytorch_lightning.distributed import LightningDistributed
|
||||
from pytorch_lightning.overrides import LightningDistributedModule
|
||||
from pytorch_lightning.overrides.distributed import prepare_for_backward
|
||||
from pytorch_lightning.plugins.environments import SLURMEnvironment, TorchElasticEnvironment
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
|
||||
from pytorch_lightning.utilities import _HYDRA_AVAILABLE, _TORCH_GREATER_EQUAL_1_7, rank_zero_warn
|
||||
|
@ -120,7 +119,7 @@ class DDPPlugin(ParallelPlugin):
|
|||
command = sys.argv
|
||||
try:
|
||||
full_path = path_lib(command[0])
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
full_path = os.path.abspath(command[0])
|
||||
|
||||
command[0] = full_path
|
||||
|
|
|
@ -263,7 +263,7 @@ class DeepSpeedPlugin(DDPPlugin):
|
|||
distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
|
||||
return distributed_sampler_kwargs
|
||||
|
||||
def init_optimizers(self, trainer: "Trainer", model: LightningModule) -> Tuple[List, List, List]:
|
||||
def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]:
|
||||
# Skip initializing optimizers here as DeepSpeed handles optimizers via config.
|
||||
# User may have specified config options instead in configure_optimizers, but this is handled
|
||||
# via `_initialize_deepspeed_train`
|
||||
|
|
|
@ -101,14 +101,14 @@ class HorovodPlugin(ParallelPlugin):
|
|||
hvd.join()
|
||||
|
||||
def start_testing(self, trainer):
|
||||
with ExitStack() as stack:
|
||||
with ExitStack():
|
||||
self._results = trainer.run_test()
|
||||
|
||||
# Make sure all workers have finished training before returning to the user
|
||||
hvd.join()
|
||||
|
||||
def start_predicting(self, trainer):
|
||||
with ExitStack() as stack:
|
||||
with ExitStack():
|
||||
# set up training routine
|
||||
self._results = trainer.run_predict()
|
||||
|
||||
|
|
|
@ -24,7 +24,6 @@ from pytorch_lightning.core.lightning import LightningModule
|
|||
from pytorch_lightning.overrides.base import unwrap_lightning_module
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, ReduceOp
|
||||
|
||||
|
||||
|
|
|
@ -13,11 +13,10 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
from contextlib import suppress
|
||||
from typing import List, Optional, Sequence
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
|
||||
from pytorch_lightning.utilities import _RPC_AVAILABLE
|
||||
|
|
|
@ -329,11 +329,11 @@ class RPCSequentialPlugin(RPCPlugin):
|
|||
if self.main_rpc_process:
|
||||
super().post_training_step()
|
||||
|
||||
def start_training(self, trainer: 'Trainer') -> None:
|
||||
def start_training(self, trainer) -> None:
|
||||
if self.main_rpc_process:
|
||||
super().start_training(trainer)
|
||||
|
||||
def start_testing(self, trainer: 'Trainer') -> None:
|
||||
def start_testing(self, trainer) -> None:
|
||||
if self.main_rpc_process:
|
||||
super().start_testing(trainer)
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.core.optimizer import is_lightning_optimizer
|
||||
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
|
||||
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
from torch._C import device
|
||||
|
||||
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
|
||||
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import io
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
|
@ -11,7 +10,6 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_warn
|
|||
from pytorch_lightning.utilities.apply_func import move_data_to_device
|
||||
|
||||
if _TPU_AVAILABLE:
|
||||
import torch_xla
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
|
@ -68,4 +66,4 @@ class SingleTPUPlugin(SingleDevicePlugin):
|
|||
|
||||
@property
|
||||
def is_distributed(self):
|
||||
return False
|
||||
return False
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import io
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
@ -206,7 +206,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
# restore main state with best weights
|
||||
best_path = self.mp_queue.get()
|
||||
last_path = self.mp_queue.get()
|
||||
results = self.mp_queue.get()
|
||||
self._results = self.mp_queue.get()
|
||||
|
||||
# transfer back the best path to the trainer
|
||||
if self.lightning_module.trainer.checkpoint_callback is not None:
|
||||
|
|
|
@ -67,11 +67,6 @@ exclude =
|
|||
*.egg
|
||||
build
|
||||
temp
|
||||
# TODO: temporary until accelerator refactor finished
|
||||
pytorch_lightning/accelerators/accelerator.py
|
||||
pytorch_lightning/plugins/training_type
|
||||
pytorch_lightning/plugins/precision
|
||||
pytorch_lightning/plugins/base_plugin.py
|
||||
|
||||
select = E,W,F
|
||||
doctests = True
|
||||
|
|
|
@ -79,7 +79,7 @@ def test_if_test_works_after_train(tmpdir):
|
|||
model = BoringModel()
|
||||
trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True)
|
||||
trainer.fit(model)
|
||||
assert trainer.test(model) == 1
|
||||
assert len(trainer.test(model)) == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
|
||||
|
@ -119,4 +119,4 @@ def test_if_weights_tied(tmpdir, capsys=None):
|
|||
assert result
|
||||
|
||||
assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
|
||||
assert trainer.test(model) == 1
|
||||
assert len(trainer.test(model)) == 1
|
||||
|
|
Loading…
Reference in New Issue