Initialize trainer with None in DDPAccelerator (#4915)
* Initialize trainer with None
* add typing to all accelerators
* resolve imports
* update
* add typing
* removed typo
* update
* Fix formatting and imports in accelerator

Co-authored-by: maxjeblick <maxjeblick@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Co-authored-by: SeanNaren <sean@grid.ai>
Co-authored-by: Roger Shieh <sh.rog@protonmail.ch>
parent d5fa02e798
commit 2c3d43dcb5
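The net effect of the hunks below is easiest to see at the call site: the base Accelerator and DDPAccelerator constructors accept trainer=None and now carry type hints for cluster_environment and ddp_plugin. The following is a minimal sketch of what the typed signatures permit, assuming the 1.1-era accelerator API shown in this diff; the DDPAccelerator import path and the "build first, attach trainer later" usage are my assumptions, only the DDPPlugin import and the constructor keywords come from the diff itself.

from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.ddp_accelerator import DDPAccelerator  # import path assumed for this PL version
from pytorch_lightning.plugins.ddp_plugin import DDPPlugin

# Build the accelerator without a trainer; the typed signature documents that
# trainer, cluster_environment and ddp_plugin may all be None.
accelerator = DDPAccelerator(
    trainer=None,
    cluster_environment=None,   # Optional[ClusterEnvironment]
    ddp_plugin=DDPPlugin(),     # Optional[DDPPlugin]
)

# Accelerator.__init__ only stores the reference (self.trainer = trainer), so
# attaching a trainer afterwards works; treating this as the intended workflow
# is an assumption, not something the diff states.
accelerator.trainer = Trainer(max_epochs=1)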
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
 from enum import Enum
 from typing import Any, Optional, Union

 import torch
 import torch.distributed as torch_distrib
 from torch.optim import Optimizer

 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.parsing import AttributeDict

@@ -33,7 +34,10 @@ else:

 class Accelerator(object):

-    def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer: Optional = None,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         self.trainer = trainer
         self.nickname = None
         self.cluster_environment = cluster_environment
@@ -16,13 +16,14 @@ from typing import Any, Optional, Union, Callable
 import torch

 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.utilities import AMPType, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class CPUAccelerator(Accelerator):

-    def __init__(self, trainer, cluster_environment=None):
+    def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
         """
         Runs training on CPU
@@ -20,12 +20,14 @@ from torch.nn.parallel import DistributedDataParallel

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

 if HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig

@@ -34,7 +36,10 @@ if HYDRA_AVAILABLE:

 class DDP2Accelerator(Accelerator):

-    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP2 strategy on a cluster
@@ -25,12 +25,18 @@ from torch.nn.parallel import DistributedDataParallel

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
-from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available
-from pytorch_lightning.utilities.distributed import find_free_network_port, rank_zero_only, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    all_gather_ddp_if_available,
+    find_free_network_port,
+    rank_zero_only,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.seed import seed_everything

@@ -41,7 +47,10 @@ if HYDRA_AVAILABLE:

 class DDPAccelerator(Accelerator):

-    def __init__(self, trainer=None, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer: Optional = None,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP strategy on a single machine (manually, not via cluster start)
@@ -11,17 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
 from typing import Optional

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.ddp_hpc_accelerator import DDPHPCAccelerator
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE

 if HYDRA_AVAILABLE:
-    from hydra.utils import to_absolute_path, get_original_cwd
     from hydra.core.hydra_config import HydraConfig
+    from hydra.utils import get_original_cwd, to_absolute_path


 class DDPCPUHPCAccelerator(DDPHPCAccelerator):

-    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP (with CPUs) strategy on a cluster
@@ -16,22 +16,23 @@ from typing import Any, List, Optional, Union

 import torch
 import torch.distributed as torch_distrib
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.distributed import (
+    all_gather_ddp_if_available,
     find_free_network_port,
     rank_zero_only,
     rank_zero_warn,
     sync_ddp_if_available,
-    all_gather_ddp_if_available,
 )

 if HYDRA_AVAILABLE:

@@ -41,7 +42,11 @@ if HYDRA_AVAILABLE:

 class DDPCPUSpawnAccelerator(Accelerator):

-    def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer,
+                 nprocs: int,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP (on a single machine or manually on multiple machines), using mp.spawn

@@ -197,8 +202,8 @@ class DDPCPUSpawnAccelerator(Accelerator):

     def early_stopping_should_stop(self, pl_module):
         stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
+        torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
+        torch_distrib.barrier()
         should_stop = stop == self.trainer.world_size
         return should_stop
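For context on the two early_stopping_should_stop hunks (here and in DDPSpawnAccelerator below): the logic all-reduces each rank's stop flag and stops only when every rank agrees, and the code change is purely a rename of the torch.distributed alias from dist to torch_distrib. Below is a self-contained sketch of that consensus pattern, runnable on CPU; the gloo backend, the localhost rendezvous and the use of ReduceOp.SUM in place of the older reduce_op alias from the diff are my assumptions, not part of this commit.

import os
import torch
import torch.distributed as torch_distrib
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    # Process-group setup is illustrative (gloo + localhost rendezvous are
    # assumptions, not something this diff configures).
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    torch_distrib.init_process_group("gloo", rank=rank, world_size=world_size)

    # Same consensus logic as early_stopping_should_stop: every rank votes 0 or 1,
    # the votes are summed across ranks, and the run stops only if *all* ranks
    # voted to stop (sum == world_size).
    local_should_stop = rank % 2 == 0           # pretend only even ranks want to stop
    stop = torch.tensor(int(local_should_stop))
    torch_distrib.all_reduce(stop, op=torch_distrib.ReduceOp.SUM)
    torch_distrib.barrier()
    should_stop = stop.item() == world_size
    print(f"rank {rank}: stop votes={stop.item()} -> should_stop={should_stop}")

    torch_distrib.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)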
@@ -21,11 +21,13 @@ from torch.nn.parallel import DistributedDataParallel

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
-from pytorch_lightning.utilities.distributed import rank_zero_only, sync_ddp_if_available, all_gather_ddp_if_available
+from pytorch_lightning.utilities.distributed import all_gather_ddp_if_available, rank_zero_only, sync_ddp_if_available

 if HYDRA_AVAILABLE:
     from hydra.core.hydra_config import HydraConfig

@@ -34,7 +36,10 @@ if HYDRA_AVAILABLE:

 class DDPHPCAccelerator(Accelerator):

-    def __init__(self, trainer, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP on an HPC cluster
@@ -17,24 +17,25 @@ from typing import Any, List, Optional, Union

 import torch
 import torch.distributed as torch_distrib
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.plugins.ddp_plugin import DDPPlugin
 from pytorch_lightning.plugins.rpc_plugin import RPCPlugin
 from pytorch_lightning.utilities import HYDRA_AVAILABLE, AMPType
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.distributed import (
+    all_gather_ddp_if_available,
     find_free_network_port,
     rank_zero_only,
     rank_zero_warn,
     sync_ddp_if_available,
-    all_gather_ddp_if_available,
 )
 from pytorch_lightning.utilities.seed import seed_everything

@@ -45,7 +46,11 @@ if HYDRA_AVAILABLE:

 class DDPSpawnAccelerator(Accelerator):

-    def __init__(self, trainer, nprocs, cluster_environment=None, ddp_plugin=None):
+    def __init__(self,
+                 trainer,
+                 nprocs: int,
+                 cluster_environment: Optional[ClusterEnvironment] = None,
+                 ddp_plugin: Optional[DDPPlugin] = None):
         """
         Runs training using DDP using mp.spawn via manual launch (not cluster launch)

@@ -226,8 +231,8 @@ class DDPSpawnAccelerator(Accelerator):

     def early_stopping_should_stop(self, pl_module):
         stop = torch.tensor(int(self.trainer.should_stop), device=pl_module.device)
-        dist.all_reduce(stop, op=dist.reduce_op.SUM)
-        dist.barrier()
+        torch_distrib.all_reduce(stop, op=torch_distrib.reduce_op.SUM)
+        torch_distrib.barrier()
         should_stop = stop == self.trainer.world_size
         return should_stop
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union
+from typing import Optional, Union

 import torch
 from torch import optim

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.distributed import LightningDistributed

@@ -27,7 +29,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException

 class DataParallelAccelerator(Accelerator):

-    def __init__(self, trainer, cluster_environment=None):
+    def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
         """
         Runs training using DP via manual start (not HPC cluster)
@@ -15,7 +15,9 @@ from typing import Any, Callable, Optional, Union

 import torch

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.utilities import AMPType

@@ -23,7 +25,7 @@ from pytorch_lightning.utilities import AMPType
 class GPUAccelerator(Accelerator):
     amp_backend: AMPType

-    def __init__(self, trainer, cluster_environment=None):
+    def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
         """
         Runs training using a single GPU
@@ -17,7 +17,9 @@ from typing import Any, Optional, Union, Callable
 import torch
 from torch.optim.lr_scheduler import _LRScheduler

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.utilities import HOROVOD_AVAILABLE, AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only

@@ -28,7 +30,7 @@ if HOROVOD_AVAILABLE:
 class HorovodAccelerator(Accelerator):
     amp_backend: AMPType

-    def __init__(self, trainer, cluster_environment=None):
+    def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
         """
         Runs training using horovod
@@ -22,6 +22,7 @@ from torch.optim import Optimizer

 from pytorch_lightning import _logger as log
 from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.cluster_environments import ClusterEnvironment
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
 from pytorch_lightning.utilities import (

@@ -43,7 +44,7 @@ if TPU_AVAILABLE:

 class TPUAccelerator(Accelerator):

-    def __init__(self, trainer, cluster_environment=None):
+    def __init__(self, trainer, cluster_environment: Optional[ClusterEnvironment] = None):
         """
         Runs training using TPUs (colab, single machine or pod)