enable passing in custom accelerators (#4050)

* enable custom accelerators
* ref: finish decoupling apex, LM and backward
This commit is contained in:
parent 2b255a3df4
commit 5b261a230e
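In short: `Trainer` now accepts an `Accelerator` instance via the new `accelerator` argument, while plain strings keep the old `distributed_backend` behavior. A minimal usage sketch assembled from the docs and tests in this commit (the subclass name and its nickname value are illustrative, not part of the diff):

    from pytorch_lightning import Trainer
    from pytorch_lightning.accelerators import Accelerator

    class MyAccel(Accelerator):
        # hypothetical subclass; the base __init__ now defaults trainer to None
        def __init__(self):
            super().__init__()
            self.nickname = 'ddp_cpu'  # string alias the connector maps onto distributed_backend

    # the connector stores the instance, wires the trainer onto it, and uses it as the backend
    trainer = Trainer(accelerator=MyAccel(), num_processes=1, fast_dev_run=True)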
@@ -11,3 +11,4 @@ from pytorch_lightning.accelerators.ddp_slurm_backend import DDPSLURMBackend
 from pytorch_lightning.accelerators.ddp_torchelastic_backend import DDPTorchElasticBackend
 from pytorch_lightning.accelerators.ddp_cpu_torchelastic_backend import DDPCPUTorchElasticBackend
 from pytorch_lightning.accelerators.ddp_cpu_slurm_backend import DDPCPUSLURMBackend
+from pytorch_lightning.accelerators.base_accelerator import Accelerator
@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning import _logger as log
 from pytorch_lightning.cluster_environments.slurm_environment import SLURMEnvironment
 from pytorch_lightning.cluster_environments.torchelastic_environment import TorchElasticEnvironment
+from pytorch_lightning.accelerators.base_accelerator import Accelerator

 try:
     import torch_xla
@@ -29,11 +30,13 @@ class AcceleratorConnector:

     def __init__(self, trainer):
         self.trainer = trainer
+        self.accelerator = None

     def on_trainer_init(
         self,
         num_processes,
         tpu_cores,
+        accelerator,
         distributed_backend,
         auto_select_gpus,
         gpus,
@@ -44,6 +47,15 @@ class AcceleratorConnector:
         replace_sampler_ddp,
         deterministic,
     ):
+        # temporary mapping until we remove all the distributed_backend references
+        if accelerator is not None:
+            self.accelerator = accelerator
+            if isinstance(accelerator, Accelerator):
+                self.accelerator.trainer = self
+                distributed_backend = self.accelerator.nickname
+            else:
+                distributed_backend = accelerator
+
         self.trainer.deterministic = deterministic

         torch.backends.cudnn.deterministic = self.trainer.deterministic
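This mapping is the backward-compatibility hinge: a string flows through unchanged, while an `Accelerator` object is stored on the connector and its `nickname` is substituted for `distributed_backend`, so all existing backend checks keep working. A hedged illustration (values only, no Trainer involved):

    from pytorch_lightning.accelerators import Accelerator

    acc = Accelerator()     # nickname defaults to None (see the base class hunk below)
    acc.nickname = 'ddp2'   # pretend this is a ddp2-flavoured custom backend
    # on_trainer_init(accelerator=acc, ...) would now proceed as if
    # distributed_backend='ddp2' had been passed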
@@ -145,7 +157,18 @@ class AcceleratorConnector:
         if self.trainer.accelerator_backend is not None:
             return self.trainer.accelerator_backend

-        # SLURM ddp
+        # ----------------------------------
+        # Use the user provided accelerator
+        # ----------------------------------
+        # use the one the user passed in
+        if self.accelerator is not None and isinstance(self.accelerator, Accelerator):
+            self.accelerator.trainer = self.trainer
+            acc = self.accelerator
+            return acc
+
+        # ----------------------------------
+        # choose an accelerator for the user
+        # ----------------------------------
         use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks

         # torchelastic or general non_slurm ddp
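Selection order after this hunk: an already-instantiated backend wins, then a user-supplied `Accelerator` (with `trainer` re-wired onto it), and only then the automatic SLURM/torchelastic/ddp choice. One consequence worth spelling out, as a sketch reusing the hypothetical `MyAccel` from the top of this page:

    acc = MyAccel()
    trainer = Trainer(accelerator=acc, num_processes=1, fast_dev_run=True)
    trainer.fit(model)  # model: any LightningModule, e.g. BoringModel from the tests
    # the exact instance is returned by select_accelerator(), not a copy:
    assert trainer.accelerator_backend is acc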
@@ -23,13 +23,16 @@ EPSILON_FP16 = 1e-5

 class Accelerator(object):

-    def __init__(self, trainer, cluster_environment=None):
+    def __init__(self, trainer=None, cluster_environment=None):
         self.trainer = trainer
+        self.nickname = None
         self.cluster_environment = cluster_environment
         self.dist = AttributeDict(rank=0, device=None)
-        self.train_loop = self.trainer.train
-        self.validation_loop = self.trainer.run_evaluation
-        self.test_loop = self.trainer.run_evaluation
+
+        if trainer is not None:
+            self.train_loop = self.trainer.train
+            self.validation_loop = self.trainer.run_evaluation
+            self.test_loop = self.trainer.run_evaluation

     def setup(self, model):
         pass
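Making `trainer` optional is what allows user-side construction at all: the accelerator can exist before any Trainer does, with the loop bindings deferred until a trainer is attached. A minimal sketch assuming only what the hunk shows:

    from pytorch_lightning.accelerators import Accelerator

    acc = Accelerator()          # legal now; previously this required a trainer
    assert acc.trainer is None
    assert acc.nickname is None  # subclasses override this with their string alias
    # train_loop / validation_loop / test_loop are only bound when a trainer is passed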
@@ -22,6 +22,7 @@ class CPUBackend(Accelerator):

     def __init__(self, trainer, cluster_environment=None):
         super().__init__(trainer, cluster_environment)
+        self.nickname = None

     def setup(self, model):
         # run through amp wrapper
@@ -42,6 +42,7 @@ class DDP2Backend(Accelerator):
         super().__init__(trainer, cluster_environment)
         self.task_idx = None
         self.dist = LightningDistributed()
+        self.nickname = 'ddp2'

     def setup(self, model):
         self._resolve_task_idx()
@@ -52,6 +52,7 @@ class DDPBackend(Accelerator):
         self._has_spawned_children = False
         self.interactive_ddp_procs = []
         self.dist = LightningDistributed()
+        self.nickname = 'ddp'

     def setup(self, model):
         # first track model
@@ -48,6 +48,7 @@ class DDPCPUSLURMBackend(Accelerator):
         self.task_idx = None
         self._has_spawned_children = False
         self.dist = LightningDistributed()
+        self.nickname = 'ddp_cpu'

     def setup(self, model):
         self.trainer.model = model
@@ -46,6 +46,7 @@ class DDPCPUSpawnBackend(Accelerator):
         self.mp_queue = None
         self.nprocs = nprocs
         self.dist = LightningDistributed()
+        self.nickname = 'ddp_cpu'

     def setup(self, model):
         os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
@@ -47,6 +47,7 @@ class DDPCPUTorchElasticBackend(Accelerator):
         self.task_idx = None
         self._has_spawned_children = False
         self.dist = LightningDistributed()
+        self.nickname = 'ddp_cpu'

     def setup(self, model):
         self.trainer.model = model
@@ -47,6 +47,7 @@ class DDPSLURMBackend(Accelerator):
         self.task_idx = None
         self._has_spawned_children = False
         self.dist = LightningDistributed()
+        self.nickname = 'ddp'

     def setup(self, model):
         self.trainer.model = model
@@ -47,6 +47,7 @@ class DDPSpawnBackend(Accelerator):
         self.mp_queue = None
         self.nprocs = nprocs
         self.dist = LightningDistributed()
+        self.nickname = 'ddp'

     def setup(self, model):
         os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', str(find_free_network_port()))
@@ -48,6 +48,7 @@ class DDPTorchElasticBackend(Accelerator):
         self.task_idx = None
         self._has_spawned_children = False
         self.dist = LightningDistributed()
+        self.nickname = 'ddp'

     def setup(self, model):
         self.trainer.model = model
@@ -29,6 +29,7 @@ class DataParallelBackend(Accelerator):
         super().__init__(trainer, cluster_environment)
         self.model_autocast_original_forward = None
         self.dist = LightningDistributed()
+        self.nickname = 'dp'

     def setup(self, model):
         # call setup after the ddp process has connected
@@ -25,6 +25,7 @@ class GPUBackend(Accelerator):

     def __init__(self, trainer, cluster_environment=None):
         super().__init__(trainer, cluster_environment)
         self.dist = LightningDistributed()
+        self.nickname = None

     def setup(self, model):
@@ -33,6 +33,7 @@ class HorovodBackend(Accelerator):

     def __init__(self, trainer, cluster_environment=None):
         super().__init__(trainer, cluster_environment)
+        self.nickname = 'horovod'

     def setup(self, model):
         # call setup after the ddp process has connected
@@ -42,6 +42,7 @@ class TPUBackend(Accelerator):
         super().__init__(trainer, cluster_environment)
         self.start_method = None
         self.mp_queue = None
+        self.nickname = None

     def setup(self, model):
         rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')
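Taken together, the backend hunks above give every accelerator a stable alias. Collected here purely for reference (derived from this diff, not an authoritative table):

    # nickname each backend now advertises
    NICKNAME_BY_BACKEND = {
        'CPUBackend': None, 'GPUBackend': None, 'TPUBackend': None,
        'DDPBackend': 'ddp', 'DDPSLURMBackend': 'ddp',
        'DDPSpawnBackend': 'ddp', 'DDPTorchElasticBackend': 'ddp',
        'DDP2Backend': 'ddp2',
        'DDPCPUSLURMBackend': 'ddp_cpu', 'DDPCPUSpawnBackend': 'ddp_cpu',
        'DDPCPUTorchElasticBackend': 'ddp_cpu',
        'DataParallelBackend': 'dp',
        'HorovodBackend': 'horovod',
    }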
@@ -165,6 +165,57 @@ Example::

 Trainer flags
 -------------

+accelerator
+^^^^^^^^^^^
+
+.. raw:: html
+
+    <video width="50%" max-width="400px" controls
+    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/distributed_backend.jpg"
+    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/distributed_backend.mp4"></video>
+
+|
+
+The accelerator backend to use (previously known as distributed_backend).
+
+- (``dp``) is DataParallel (split batch among GPUs of same machine)
+- (``ddp``) is DistributedDataParallel (each gpu on each node trains, and syncs grads)
+- (``ddp_cpu``) is DistributedDataParallel on CPU (same as ``ddp``, but does not use GPUs.
+  Useful for multi-node CPU training or single-node debugging. Note that this will **not** give
+  a speedup on a single node, since Torch already makes efficient use of multiple CPUs on a single
+  machine.)
+- (``ddp2``) dp on node, ddp across nodes. Useful for things like increasing
+  the number of negative samples
+
+.. testcode::
+
+    # default used by the Trainer
+    trainer = Trainer(distributed_backend=None)
+
+Example::
+
+    # dp = DataParallel
+    trainer = Trainer(gpus=2, distributed_backend='dp')
+
+    # ddp = DistributedDataParallel
+    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')
+
+    # ddp2 = DistributedDataParallel + dp
+    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')
+
+.. note:: this option does not apply to TPU. TPUs use ``ddp`` by default (over each core)
+
+You can also modify hardware behavior by subclassing an existing accelerator to adjust for your needs.
+
+Example::
+
+    class MyOwnDDP(DDPBackend):
+        ...
+
+    Trainer(accelerator=MyOwnDDP())
+
+.. warning:: Passing in custom accelerators is experimental but work is in progress to enable full compatibility.
+
 accumulate_grad_batches
 ^^^^^^^^^^^^^^^^^^^^^^^
@@ -486,47 +537,7 @@ Example::

 distributed_backend
 ^^^^^^^^^^^^^^^^^^^
-
-.. raw:: html
-
-    <video width="50%" max-width="400px" controls
-    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/distributed_backend.jpg"
-    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/distributed_backend.mp4"></video>
-
-|
-
-The distributed backend to use.
-
-- (```dp```) is DataParallel (split batch among GPUs of same machine)
-- (```ddp```) is DistributedDataParallel (each gpu on each node trains, and syncs grads)
-- (```ddp_cpu```) is DistributedDataParallel on CPU (same as `ddp`, but does not use GPUs.
-  Useful for multi-node CPU training or single-node debugging. Note that this will **not** give
-  a speedup on a single node, since Torch already makes effient use of multiple CPUs on a single
-  machine.)
-- (```ddp2```) dp on node, ddp across nodes. Useful for things like increasing
-  the number of negative samples
-
-.. testcode::
-
-    # default used by the Trainer
-    trainer = Trainer(distributed_backend=None)
-
-Example::
-
-    # dp = DataParallel
-    trainer = Trainer(gpus=2, distributed_backend='dp')
-
-    # ddp = DistributedDataParallel
-    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp')
-
-    # ddp2 = DistributedDataParallel + dp
-    trainer = Trainer(gpus=2, num_nodes=2, distributed_backend='ddp2')
-
-.. note:: this option does not apply to TPU. TPUs use ```ddp``` by default (over each core)
-
-See Also:
-    - :ref:`Multi-GPU training guide <multi_gpu>`.
-    - :ref:`Multi-node (SLURM) guide <slurm>`.
+This has been renamed "accelerator".

 fast_dev_run
 ^^^^^^^^^^^^
@@ -58,6 +58,8 @@ from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer.properties import TrainerProperties
 from pytorch_lightning.plugins.plugin_connector import PluginConnector
+from pytorch_lightning.accelerators.base_accelerator import Accelerator
+from pytorch_lightning.accelerators.cpu_backend import CPUBackend

 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -111,7 +113,7 @@ class Trainer(
         val_check_interval: Union[int, float] = 1.0,
         flush_logs_every_n_steps: int = 100,
         log_every_n_steps: int = 50,
-        distributed_backend: Optional[str] = None,
+        accelerator: Optional[Union[str, Accelerator]] = None,
         sync_batchnorm: bool = False,
         precision: int = 32,
         weights_summary: Optional[str] = 'top',
@@ -131,12 +133,16 @@ class Trainer(
         plugins: list = None,
         amp_backend: str = 'native',
         amp_level: str = 'O2',
+        distributed_backend: Optional[str] = None,
     ):
         r"""
         Customize every aspect of training via flags

         Args:

+            accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...).
+                Can also take in an accelerator object for custom hardware.
+
             accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

             amp_backend: The mixed precision backend to use ("native" or "apex")
@@ -173,7 +179,7 @@ class Trainer(

             deterministic: If true enables cudnn.deterministic.

-            distributed_backend: The distributed backend to use (dp, ddp, ddp2, ddp_spawn, ddp_cpu)
+            distributed_backend: deprecated. Please use 'accelerator'

             fast_dev_run: runs 1 batch of train, test and val to find any bugs (ie: a sort of unit test).
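In user code the docstring change is a one-word rename; both spellings are accepted by the new signature (sketch):

    from pytorch_lightning import Trainer

    Trainer(gpus=2, distributed_backend='ddp')  # deprecated spelling
    Trainer(gpus=2, accelerator='ddp')          # preferred spelling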
@@ -318,6 +324,7 @@ class Trainer(
         self.accelerator_connector.on_trainer_init(
             num_processes,
             tpu_cores,
+            accelerator,
             distributed_backend,
             auto_select_gpus,
             gpus,
@@ -105,9 +105,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
     >>> args = get_init_arguments_and_types(Trainer)
     >>> import pprint
     >>> pprint.pprint(sorted(args))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
-    [('accumulate_grad_batches',
-      (<class 'int'>, typing.Dict[int, int], typing.List[list]),
-      1),
+    [('accelerator',
+      (<class 'str'>,
+       <class 'pytorch_lightning.accelerators.base_accelerator.Accelerator'>,
+       <class 'NoneType'>),
+      None),
     ...
     ('callbacks',
      (typing.List[pytorch_lightning.callbacks.base.Callback],
@@ -18,6 +18,7 @@ from tests.base.boring_model import BoringModel
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning import accelerators, Trainer
 from pytorch_lightning.cluster_environments import SLURMEnvironment, TorchElasticEnvironment, ClusterEnvironment
+from pytorch_lightning.accelerators import Accelerator
 from unittest import mock
@@ -297,3 +298,61 @@ def test_accelerator_choice_ddp_cpu_custom_cluster(tmpdir):

     with pytest.raises(SystemExit):
         trainer.fit(model)
+
+
+@mock.patch.dict(os.environ, {
+    "SLURM_NTASKS": "1",
+    "SLURM_JOB_NAME": "SOME_NAME",
+    "SLURM_NODEID": "0",
+    "LOCAL_RANK": "0",
+    "SLURM_LOCALID": "0"
+})
+@mock.patch('torch.cuda.device_count', return_value=0)
+def test_custom_accelerator(tmpdir):
+    class Accel(Accelerator):
+        def init_ddp_connection(
+                self, global_rank: int, world_size: int, is_slurm_managing_tasks: bool = True
+        ) -> None:
+            pass
+
+    class CB(Callback):
+        def on_fit_start(self, trainer, pl_module):
+            assert isinstance(trainer.accelerator_backend, Accel)
+            raise SystemExit()
+
+    model = BoringModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        accelerator=Accel(),
+        num_processes=1,
+        callbacks=[CB()]
+    )
+
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
+
+
+@mock.patch.dict(os.environ, {
+    "SLURM_NTASKS": "1",
+    "SLURM_JOB_NAME": "SOME_NAME",
+    "SLURM_NODEID": "0",
+    "LOCAL_RANK": "0",
+    "SLURM_LOCALID": "0"
+})
+@mock.patch('torch.cuda.device_count', return_value=0)
+def test_dist_backend_accelerator_mapping(tmpdir):
+    class CB(Callback):
+        def on_fit_start(self, trainer, pl_module):
+            assert isinstance(trainer.accelerator_backend, accelerators.DDPCPUSLURMBackend)
+            raise SystemExit()
+
+    model = BoringModel()
+    trainer = Trainer(
+        fast_dev_run=True,
+        accelerator='ddp_cpu',
+        num_processes=1,
+        callbacks=[CB()]
+    )
+
+    with pytest.raises(SystemExit):
+        trainer.fit(model)
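To exercise just the two new tests, selecting by name sidesteps the test module's path, which this diff does not show (hedged sketch):

    # hypothetical invocation from the repo root
    import pytest
    pytest.main(['-x', '-k', 'test_custom_accelerator or test_dist_backend_accelerator_mapping'])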