From 622c5c3982719e5335afa832cc719bd8edb863fb Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 1 Oct 2020 11:26:58 -0400 Subject: [PATCH] ref: part 4 of #3733 (#3773) * ref: part 4 of #3733 * ref: part 4 of #3733 * ref: part 4 of #3733 * ref: part 4 of #3733 --- .../accelerators/ddp_cpu_spawn_backend.py | 52 +++++++++----- tests/backends/test_ddp_spawn.py | 71 +++++++++++++++++++ tests/models/test_gpu.py | 41 ----------- 3 files changed, 104 insertions(+), 60 deletions(-) create mode 100644 tests/backends/test_ddp_spawn.py diff --git a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py index 2f0c6d29c7..e91428568d 100644 --- a/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py +++ b/pytorch_lightning/accelerators/ddp_cpu_spawn_backend.py @@ -68,19 +68,6 @@ class DDPCPUSpawnBackend(Accelerator): self.__recover_child_process_weights(model, best_path, last_path) return results - def __recover_child_process_weights(self, model, best_path, last_path): - # transfer back the best path to the trainer - if self.trainer.checkpoint_callback: - self.trainer.checkpoint_callback.best_model_path = best_path - # todo, pass also best score - - # load last weights - if last_path is not None and not self.trainer.testing: - ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) - model.load_state_dict(ckpt) - - self.trainer.model = model - def ddp_train(self, process_idx, mp_queue, model): """ Entry point for ddp @@ -95,9 +82,7 @@ class DDPCPUSpawnBackend(Accelerator): self.trainer.progress_bar_callback.disable() # determine which process we are and world size - self.trainer.local_rank = process_idx - self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx - self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes + self.set_world_ranks(process_idx) # set warning rank rank_zero_only.rank = self.trainer.global_rank @@ -116,7 +101,7 @@ class DDPCPUSpawnBackend(Accelerator): self.trainer.call_setup_hook(model) # on world_size=0 let everyone know training is starting - if self.trainer.is_global_zero: + if self.trainer.is_global_zero and not torch.distributed.is_initialized(): log.info('-' * 100) log.info(f'distributed_backend={self.trainer.distributed_backend}') log.info(f'All DDP processes registered. Starting ddp with {self.trainer.world_size} processes') @@ -126,6 +111,9 @@ class DDPCPUSpawnBackend(Accelerator): if self.trainer.sync_batchnorm: model = model.configure_sync_batchnorm(model) + # move the model to the correct device + self.model_to_device(model, process_idx) + # CHOOSE OPTIMIZER # allow for lr schedulers as well self.setup_optimizers(model) @@ -137,7 +125,7 @@ class DDPCPUSpawnBackend(Accelerator): model = self.trainer.precision_connector.connect(model) # DDP spawn already spawned off each process... 
no need to do anything - device_ids = None + device_ids = self.get_device_ids() # allow user to configure ddp model = model.configure_ddp(model, device_ids) @@ -174,7 +162,8 @@ class DDPCPUSpawnBackend(Accelerator): return output def barrier(self, name: str = None): - torch_distrib.barrier() + if torch_distrib.is_initialized(): + torch_distrib.barrier() def broadcast(self, obj, src=0): return self.dist.broadcast(obj) @@ -186,6 +175,31 @@ class DDPCPUSpawnBackend(Accelerator): should_stop = stop == self.trainer.world_size return should_stop + def set_world_ranks(self, process_idx): + self.trainer.local_rank = process_idx + self.trainer.global_rank = self.trainer.node_rank * self.trainer.num_processes + process_idx + self.trainer.world_size = self.trainer.num_nodes * self.trainer.num_processes + + def model_to_device(self, model, process_idx): + model.cpu() + + def get_device_ids(self): + device_ids = None + return device_ids + + def __recover_child_process_weights(self, model, best_path, last_path): + # transfer back the best path to the trainer + if self.trainer.checkpoint_callback: + self.trainer.checkpoint_callback.best_model_path = best_path + # todo, pass also best score + + # load last weights + if last_path is not None and not self.trainer.testing: + ckpt = torch.load(last_path, map_location=lambda storage, loc: storage) + model.load_state_dict(ckpt) + + self.trainer.model = model + def transfer_distrib_spawn_state_on_fit_end(self, model, mp_queue, results): # track the best model path best_model_path = None diff --git a/tests/backends/test_ddp_spawn.py b/tests/backends/test_ddp_spawn.py new file mode 100644 index 0000000000..0c5db6b1a0 --- /dev/null +++ b/tests/backends/test_ddp_spawn.py @@ -0,0 +1,71 @@ +import pytest +import torch + +import tests.base.develop_pipelines as tpipes +import tests.base.develop_utils as tutils +from tests.base import EvalModelTemplate +from pytorch_lightning.core import memory +from pytorch_lightning.trainer import Trainer + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_multi_gpu_early_stop_ddp_spawn(tmpdir): + """Make sure DDP works. 
with early stopping""" + tutils.set_random_master_port() + + trainer_options = dict( + default_root_dir=tmpdir, + early_stop_callback=True, + max_epochs=50, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + distributed_backend='ddp_spawn', + ) + + model = EvalModelTemplate() + tpipes.run_model_test(trainer_options, model) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_multi_gpu_model_ddp_spawn(tmpdir): + tutils.set_random_master_port() + + trainer_options = dict( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=10, + limit_val_batches=10, + gpus=[0, 1], + distributed_backend='ddp_spawn', + progress_bar_refresh_rate=0 + ) + + model = EvalModelTemplate() + + tpipes.run_model_test(trainer_options, model) + + # test memory helper functions + memory.get_memory_profile('min_max') + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_ddp_all_dataloaders_passed_to_fit(tmpdir): + """Make sure DDP works with dataloaders passed to fit()""" + tutils.set_random_master_port() + + model = EvalModelTemplate() + fit_options = dict(train_dataloader=model.train_dataloader(), + val_dataloaders=model.val_dataloader()) + + trainer = Trainer( + default_root_dir=tmpdir, + progress_bar_refresh_rate=0, + max_epochs=1, + limit_train_batches=0.2, + limit_val_batches=0.2, + gpus=[0, 1], + distributed_backend='ddp_spawn' + ) + result = trainer.fit(model, **fit_options) + assert result == 1, "DDP doesn't work with dataloaders passed to fit()." diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 83416444bb..56a58760ee 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -62,25 +62,6 @@ def test_multi_gpu_none_backend(tmpdir): tpipes.run_model_test(trainer_options, model) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_multi_gpu_early_stop_ddp_spawn(tmpdir): - """Make sure DDP works. with early stopping""" - tutils.set_random_master_port() - - trainer_options = dict( - default_root_dir=tmpdir, - early_stop_callback=True, - max_epochs=50, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0, 1], - distributed_backend='ddp_spawn', - ) - - model = EvalModelTemplate() - tpipes.run_model_test(trainer_options, model) - - @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") def test_multi_gpu_model_dp(tmpdir): tutils.set_random_master_port() @@ -131,28 +112,6 @@ def test_multi_gpu_model_ddp(tmpdir, cli_args, variation): pytest.fail(err) -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -def test_multi_gpu_model_ddp_spawn(tmpdir): - tutils.set_random_master_port() - - trainer_options = dict( - default_root_dir=tmpdir, - max_epochs=1, - limit_train_batches=10, - limit_val_batches=10, - gpus=[0, 1], - distributed_backend='ddp_spawn', - progress_bar_refresh_rate=0 - ) - - model = EvalModelTemplate() - - tpipes.run_model_test(trainer_options, model) - - # test memory helper functions - memory.get_memory_profile('min_max') - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.parametrize('gpus', [1, [0], [1]]) def test_single_gpu_model(tmpdir, gpus):
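
Note on the refactor above: ddp_train in DDPCPUSpawnBackend now delegates rank bookkeeping, device placement, and DDP device-id selection to three small hooks (set_world_ranks, model_to_device, get_device_ids), and barrier() is guarded by torch_distrib.is_initialized(). The sketch below shows that template-method shape in isolation so the intent is easier to see; it is illustrative only, and the GPUSpawnVariant class (mapping process_idx to a CUDA device) is a hypothetical example, not code from this patch.

import torch


class SpawnBackendSketch:
    # Shared spawn entry point; device-specific behaviour lives in the hooks below.
    def train_entry(self, process_idx, model):
        self.set_world_ranks(process_idx)         # rank bookkeeping
        self.model_to_device(model, process_idx)  # device placement
        device_ids = self.get_device_ids()        # later handed to configure_ddp
        return device_ids

    # CPU-spawn defaults, mirroring the hooks added in this patch
    def set_world_ranks(self, process_idx):
        self.local_rank = process_idx

    def model_to_device(self, model, process_idx):
        model.cpu()

    def get_device_ids(self):
        return None


class GPUSpawnVariant(SpawnBackendSketch):
    # Hypothetical override (not part of this patch): one process per visible GPU.
    def model_to_device(self, model, process_idx):
        torch.cuda.set_device(process_idx)
        model.cuda(process_idx)

    def get_device_ids(self):
        return [torch.cuda.current_device()]

The is_initialized() guard on barrier() serves the same goal: the shared code path can run before the process group exists (for example in a single-process debug run) without every caller having to special-case it.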