Initialize trainer with None in DDPAccelerator (#4915)

* Initialize trainer with None

* add typing to all accelerators

* resolve imports

* update

* add typing

* removed typo

* update

* Fix formatting and imports in accelerator

Co-authored-by: maxjeblick <maxjeblick@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
Authored by chaton on 2020-12-10 15:24:44 +01:00; committed by GitHub.
parent d5fa02e798
commit 2c3d43dcb5
12 changed files with 76 additions and 27 deletions
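In effect, the base Accelerator and DDPAccelerator now default trainer to None, and the accelerators gain typed Optional parameters for their cluster environment and DDP plugin. A minimal sketch of using the patched DDPAccelerator under the Lightning 1.1-era API (the Trainer wiring below reflects typical usage and is not part of this diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator
    from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

    # After this commit, DDPAccelerator can be constructed without a trainer;
    # cluster_environment and ddp_plugin are typed Optionals defaulting to None.
    accelerator = DDPAccelerator(trainer=None, ddp_plugin=DDPPlugin())

    # The trainer is attached later, e.g. by handing the accelerator to Trainer.
    trainer = Trainer(accelerator=accelerator, gpus=2)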

pytorch_lightning/accelerators/accelerator.py

@@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional, Union
import torch
import torch.distributed as torch_distrib
from torch.optim import Optimizer
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.parsing import AttributeDict
@@ -33,7 +34,10 @@ else:
class Accelerator(object):
def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: Optional = None,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
self.trainer = trainer
self.nickname = None
self.cluster_environment = cluster_environment

pytorch_lightning/accelerators/cpu_accelerator.py

@@ -16,13 +16,14 @@ from typing import Any, Optional, Union, Callable
import torch
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import AMPType, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class CPUAccelerator(Accelerator):
def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training on CPU

pytorch_lightning/accelerators/ddp2_accelerator.py

@@ -20,12 +20,14 @@ from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available
if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -34,7 +36,10 @@ if HYDRA_AVAILABLE:
class DDP2Accelerator(Accelerator):
def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP2 strategy on a cluster

pytorch_lightning/accelerators/ddp_accelerator.py

@@ -25,12 +25,18 @@ from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
rank_zero_only,
sync_ddp_if_available,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
@@ -41,7 +47,10 @@ if HYDRA_AVAILABLE:
class DDPAccelerator(Accelerator):
def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer: Optional = None,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP strategy on a single machine (manually, not via cluster start)

pytorch_lightning/accelerators/ddp_cpu_hpc_accelerator.py

@@ -11,17 +11,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from typing import Optional
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE
if HYDRA_AVAILABLE:
from hydra.utils import to_absolute_path, get_original_cwd
from hydra.core.hydra_config import HydraConfig
from hydra.utils import get_original_cwd, to_absolute_path
class DDPCPUHPCAccelerator(DDPHPCAccelerator):
def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP (with CPUs) strategy on a cluster

pytorch_lightning/accelerators/ddp_cpu_spawn_accelerator.py

@@ -16,22 +16,23 @@ from typing import Any, List, Optional, Union
import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
all_gather_ddp_if_available,
)
if HYDRA_AVAILABLE:
@@ -41,7 +42,11 @@ if HYDRA_AVAILABLE:
class DDPCPUSpawnAccelerator(Accelerator):
def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer,
nprocs: int,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP (on a single machine or manually on multiple machines), using mp.spawn
@@ -197,8 +202,8 @@ class DDPCPUSpawnAccelerator(Accelerator):
def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
torch_distrib.barrier()
should_stop = stop == self.trainer.world_size
return should_stop
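Aside from swapping the dist alias for torch_distrib, the early-stopping hunk above keeps its consensus logic: each rank votes 0 or 1, and training stops only when the SUM across ranks equals the world size, i.e. when every rank agrees. A standalone sketch of that check (assumes an initialized default process group; ReduceOp.SUM is the non-deprecated spelling of the reduce_op.SUM alias used in the diff):

    import torch
    import torch.distributed as torch_distrib

    def all_ranks_should_stop(local_stop: bool, world_size: int, device: torch.device) -> bool:
        # Each rank contributes 1 (stop) or 0 (keep training).
        stop = torch.tensor(int(local_stop), device=device)
        # After the SUM all-reduce, stop == world_size only if every rank voted 1.
        torch_distrib.all_reduce(stop, op=torch_distrib.ReduceOp.SUM)
        torch_distrib.barrier()
        return bool(stop.item() == world_size)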

pytorch_lightning/accelerators/ddp_hpc_accelerator.py

@@ -21,11 +21,13 @@ from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available
if HYDRA_AVAILABLE:
from hydra.core.hydra_config import HydraConfig
@@ -34,7 +36,10 @@ if HYDRA_AVAILABLE:
class DDPHPCAccelerator(Accelerator):
def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP on an HPC cluster

pytorch_lightning/accelerators/ddp_spawn_accelerator.py

@@ -17,24 +17,25 @@ from typing import Any, List, Optional, Union
import torch
import torch.distributed as torch_distrib
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.distributed import LightningDistributed
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
from pytorch_lightning.utilities.cloud_io import atomic_save
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.distributed import (
all_gather_ddp_if_available,
find_free_network_port,
rank_zero_only,
rank_zero_warn,
sync_ddp_if_available,
all_gather_ddp_if_available,
)
from pytorch_lightning.utilities.seed import seed_everything
@@ -45,7 +46,11 @@ if HYDRA_AVAILABLE:
class DDPSpawnAccelerator(Accelerator):
def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
def __init__(self,
trainer,
nprocs: int,
cluster_environment: Optional[ClusterEnvironment] = None,
ddp_plugin: Optional[DDPPlugin] = None):
"""
Runs training using DDP using mp.spawn via manual launch (not cluster launch)
@@ -226,8 +231,8 @@ class DDPSpawnAccelerator(Accelerator):
def early_stopping_should_stop(self, pl_module):
stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
dist.all_reduce(stop, op=dist.reduce_op.SUM)
dist.barrier()
torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
torch_distrib.barrier()
should_stop = stop == self.trainer.world_size
return should_stop

pytorch_lightning/accelerators/dp_accelerator.py

@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Optional, Union
import torch
from torch import optim
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed import LightningDistributed
@@ -27,7 +29,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
class DataParallelAccelerator(Accelerator):
def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using DP via manual start (not HPC cluster)

pytorch_lightning/accelerators/gpu_accelerator.py

@@ -15,7 +15,9 @@ from typing import Any, Callable, Optional, Union
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.distributed.dist import LightningDistributed
from pytorch_lightning.utilities import AMPType
@@ -23,7 +25,7 @@ from pytorch_lightning.utilities import AMPType
class GPUAccelerator(Accelerator):
amp_backend: AMPType
def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using a single GPU

pytorch_lightning/accelerators/horovod_accelerator.py

@@ -17,7 +17,9 @@ from typing import Any, Optional, Union, Callable
import torch
from torch.optim.lr_scheduler import _LRScheduler
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType
from pytorch_lightning.utilities.distributed import rank_zero_only
@@ -28,7 +30,7 @@ if HOROVOD_AVAILABLE:
class HorovodAccelerator(Accelerator):
amp_backend: AMPType
def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using horovod

pytorch_lightning/accelerators/tpu_accelerator.py

@@ -22,6 +22,7 @@ from torch.optim import Optimizer
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.core import LightningModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.utilities import (
@@ -43,7 +44,7 @@ class TPUAccelerator(Accelerator):
class TPUAccelerator(Accelerator):
def __init__(self, trainer, cluster_environment=None):
def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
"""
Runs training using TPUs (colab, single machine or pod)