ref: continue #3733 (#3767)

* ref: #3733 part 2

* ref: #3733 part 2
William Falcon 2020-10-01 09:25:33 -04:00 committed by GitHub
parent 440f837f6d
commit ac2b0f0f06
5 changed files with 51 additions and 25 deletions

View File

@@ -17,10 +17,10 @@ import platform
 from abc import ABC, abstractmethod
 from typing import Union, List, Tuple, Callable, Optional

-import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler

+from pytorch_lightning.accelerators.base_backend import Accelerator
 from pytorch_lightning.accelerators.base_backend import BackendType
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
@@ -75,6 +75,7 @@ class TrainerDataLoadingMixin(ABC):
     limit_val_batches: Union[int, float]
     limit_test_batches: Union[int, float]
     replace_sampler_ddp: bool
+    accelerator_backend: Accelerator
     num_nodes: int
     num_processes: int
     distributed_backend: Optional[str]
@@ -337,18 +338,6 @@
         """
         dataloader = dataloader_fx()

-        # get the function we'll use to get data
-        if self.use_ddp or self.use_ddp2:
-            # all processes wait until data download has happened
-            torch_distrib.barrier()
-
-        # data download/load on TPU
-        elif self.use_tpu and XLA_AVAILABLE:
-            # all processes wait until data download has happened
-            torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')
-
-        elif self.use_horovod:
-            # all processes wait until data download has happened
-            hvd.join()
-
+        if self.accelerator_backend is not None:
+            self.accelerator_backend.barrier('get_dataloaders')
         return dataloader
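
The three transport-specific branches removed above collapse into a single `accelerator_backend.barrier()` call because each backend now owns its own synchronization primitive. A minimal sketch of that idea, assuming one accelerator class per backend (the class names below are illustrative, not the library's actual accelerator implementations):

import torch.distributed as torch_distrib


class DDPAccelerator:
    def barrier(self, name: str = 'barrier'):
        # every DDP process blocks here until all ranks arrive
        if torch_distrib.is_initialized():
            torch_distrib.barrier()


class TPUAccelerator:
    def barrier(self, name: str = 'barrier'):
        import torch_xla.core.xla_model as xm
        # rendezvous plays the role of a barrier on TPU
        xm.rendezvous(f'pl.{name}')


class HorovodAccelerator:
    def barrier(self, name: str = 'barrier'):
        import horovod.torch as hvd
        # join() blocks until every Horovod worker reaches this point
        hvd.join()

With backends shaped like this, call sites such as `get_dataloaders` no longer need to know which distributed transport is active.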

View File

@@ -696,9 +696,6 @@ class Trainer(
         # --------------------
         self.verbose_test = verbose

-        if self.global_rank != 0:
-            return
-
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
         if test_dataloaders and datamodule:
             raise MisconfigurationException(
@@ -738,6 +735,8 @@
                     f'specify a path for a checkpoint .test(ckpt_path=PATH)'
                 )
                 return {}

+            if self.accelerator_backend is not None:
+                self.accelerator_backend.barrier()
             ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
             model.load_state_dict(ckpt['state_dict'])
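
The barrier added before `pl_load` gives every rank a common sync point, so a non-zero rank cannot try to read a best-checkpoint file that rank 0 is still writing. A minimal sketch of that intent, assuming a trainer-like object that exposes `accelerator_backend` (the helper function below is hypothetical, not part of this commit):

import torch


def load_best_weights_on_all_ranks(trainer, model, ckpt_path):
    # sync point: no rank proceeds to read until every rank has reached here
    if trainer.accelerator_backend is not None:
        trainer.accelerator_backend.barrier()

    # load on CPU first so every rank can restore the weights safely
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(ckpt['state_dict'])
    return model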

View File

@@ -181,13 +181,8 @@ class TrainLoop:
         if self.trainer.global_rank == 0:
             self.trainer.profiler.describe()

-        if self.trainer.global_rank == 0:
-            for proc in self.trainer.interactive_ddp_procs:
-                subprocess.Popen.kill(proc)
-
-        # clean up dist group
-        if self.trainer.use_ddp or self.trainer.use_ddp2:
-            torch_distrib.destroy_process_group()
+        # give accelerators a chance to finish
+        self.trainer.accelerator_backend.on_train_end()

         # clear mem
         if self.trainer.on_gpu:
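
The teardown the train loop used to do inline (killing the interactive DDP child processes and destroying the process group) now lives behind `accelerator_backend.on_train_end()`. A hedged sketch of what a DDP-flavored accelerator could do there, mirroring the removed lines (the class name and constructor are assumptions, not the actual backend code):

import subprocess

import torch.distributed as torch_distrib


class DDPAccelerator:
    def __init__(self, trainer):
        self.trainer = trainer

    def on_train_end(self):
        # only rank 0 spawned the interactive DDP children, so only it kills them
        if self.trainer.global_rank == 0:
            for proc in self.trainer.interactive_ddp_procs:
                subprocess.Popen.kill(proc)

        # tear down the distributed process group once training is finished
        if torch_distrib.is_initialized():
            torch_distrib.destroy_process_group()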

View File

@@ -0,0 +1,43 @@
"""
Runs either `.fit()` or `.test()` on a single node across multiple gpus.
"""
from argparse import ArgumentParser

from pytorch_lightning import Trainer, seed_everything
from tests.base import EvalModelTemplate
import os
import torch


def main():
    seed_everything(1234)

    parser = ArgumentParser(add_help=False)
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('--trainer_method', default='fit')
    parser.add_argument('--tmpdir')
    parser.set_defaults(gpus=2)
    parser.set_defaults(distributed_backend="ddp")
    args = parser.parse_args()

    model = EvalModelTemplate()
    trainer = Trainer.from_argparse_args(args)

    result = {}
    if args.trainer_method == 'fit':
        trainer.fit(model)
        result = {'status': 'complete', 'method': args.trainer_method, 'result': None}
    if args.trainer_method == 'test':
        result = trainer.test(model)
        result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
    if args.trainer_method == 'fit_test':
        trainer.fit(model)
        result = trainer.test(model)
        result = {'status': 'complete', 'method': args.trainer_method, 'result': result}

    # save the outcome so the parent test process can load it and assert on it
    if len(result) > 0:
        file_path = os.path.join(args.tmpdir, 'ddp.result')
        torch.save(result, file_path)


if __name__ == '__main__':
    main()
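
Because the script writes its outcome to `<tmpdir>/ddp.result`, a separate test process can launch it, wait for it to finish, and assert on the saved dictionary. A hypothetical companion test along those lines (the script filename, the extra CLI flag, and the pytest wrapper are assumptions, not part of this commit):

import os
import subprocess
import sys

import torch


def test_multi_gpu_fit_and_test(tmpdir):
    # launch the DDP script in a fresh interpreter so it can spawn its own workers
    cmd = [
        sys.executable, 'ddp_script.py',  # assumed filename for the script above
        '--trainer_method', 'fit_test',
        '--tmpdir', str(tmpdir),
        '--max_epochs', '1',
    ]
    subprocess.run(cmd, check=True)

    # read back the result dictionary the script saved with torch.save
    result = torch.load(os.path.join(str(tmpdir), 'ddp.result'))
    assert result['status'] == 'complete'
    assert result['method'] == 'fit_test'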