diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 35c1a12d4d..6411c2b61f 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -17,10 +17,10 @@ import platform from abc import ABC, abstractmethod from typing import Union, List, Tuple, Callable, Optional -import torch.distributed as torch_distrib from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler +from pytorch_lightning.accelerators.base_backend import Accelerator from pytorch_lightning.accelerators.base_backend import BackendType from pytorch_lightning.core import LightningModule from pytorch_lightning.utilities import rank_zero_warn @@ -75,6 +75,7 @@ class TrainerDataLoadingMixin(ABC): limit_val_batches: Union[int, float] limit_test_batches: Union[int, float] replace_sampler_ddp: bool + accelerator_backend: Accelerator num_nodes: int num_processes: int distributed_backend: Optional[str] @@ -337,18 +338,6 @@ class TrainerDataLoadingMixin(ABC): """ dataloader = dataloader_fx() - # get the function we'll use to get data - if self.use_ddp or self.use_ddp2: - # all processes wait until data download has happened - torch_distrib.barrier() - - # data download/load on TPU - elif self.use_tpu and XLA_AVAILABLE: - # all processes wait until data download has happened - torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders') - - elif self.use_horovod: - # all processes wait until data download has happened - hvd.join() - + if self.accelerator_backend is not None: + self.accelerator_backend.barrier('get_dataloaders') return dataloader diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 138bdf92c3..446591a436 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -696,9 +696,6 @@ class Trainer( # -------------------- self.verbose_test = verbose - if self.global_rank != 0: - return - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders if test_dataloaders and datamodule: raise MisconfigurationException( @@ -738,6 +735,8 @@ class Trainer( f'specify a path for a checkpoint .test(ckpt_path=PATH)' ) return {} + if self.accelerator_backend is not None: + self.accelerator_backend.barrier() ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage) model.load_state_dict(ckpt['state_dict']) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index e670c01f04..99318b9f34 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -181,13 +181,8 @@ class TrainLoop: if self.trainer.global_rank == 0: self.trainer.profiler.describe() - if self.trainer.global_rank == 0: - for proc in self.trainer.interactive_ddp_procs: - subprocess.Popen.kill(proc) - - # clean up dist group - if self.trainer.use_ddp or self.trainer.use_ddp2: - torch_distrib.destroy_process_group() + # give accelerators a chance to finish + self.trainer.accelerator_backend.on_train_end() # clear mem if self.trainer.on_gpu: diff --git a/tests/backends/__init__.py b/tests/backends/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/backends/ddp_model.py b/tests/backends/ddp_model.py new file mode 100644 index 0000000000..9f75415fe6 --- /dev/null +++ b/tests/backends/ddp_model.py @@ -0,0 +1,43 @@ +""" +Runs either `.fit()` or `.test()` on a single node across multiple gpus. +""" +from argparse import ArgumentParser + +from pytorch_lightning import Trainer, seed_everything +from tests.base import EvalModelTemplate +import os +import torch + + +def main(): + seed_everything(1234) + parser = ArgumentParser(add_help=False) + parser = Trainer.add_argparse_args(parser) + parser.add_argument('--trainer_method', default='fit') + parser.add_argument('--tmpdir') + parser.set_defaults(gpus=2) + parser.set_defaults(distributed_backend="ddp") + args = parser.parse_args() + + model = EvalModelTemplate() + trainer = Trainer.from_argparse_args(args) + + result = {} + if args.trainer_method == 'fit': + trainer.fit(model) + result = {'status': 'complete', 'method': args.trainer_method, 'result': None} + if args.trainer_method == 'test': + result = trainer.test(model) + result = {'status': 'complete', 'method': args.trainer_method, 'result': result} + if args.trainer_method == 'fit_test': + trainer.fit(model) + result = trainer.test(model) + result = {'status': 'complete', 'method': args.trainer_method, 'result': result} + + if len(result) > 0: + file_path = os.path.join(args.tmpdir, 'ddp.result') + torch.save(result, file_path) + + +if __name__ == '__main__': + main()