parent 440f837f6d
commit ac2b0f0f06
@@ -17,10 +17,10 @@ import platform
 from abc import ABC, abstractmethod
 from typing import Union, List, Tuple, Callable, Optional

 import torch.distributed as torch_distrib
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler

+from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.accelerators.base_backend import BackendType
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
@@ -75,6 +75,7 @@ class TrainerDataLoadingMixin(ABC):
     limit_val_batches: Union[int, float]
     limit_test_batches: Union[int, float]
     replace_sampler_ddp: bool
+    accelerator_backend: Accelerator
     num_nodes: int
     num_processes: int
     distributed_backend: Optional[str]
@@ -337,18 +338,6 @@
         """
         dataloader = dataloader_fx()

-        # get the function we'll use to get data
-        if self.use_ddp or self.use_ddp2:
-            # all processes wait until data download has happened
-            torch_distrib.barrier()
-
-        # data download/load on TPU
-        elif self.use_tpu and XLA_AVAILABLE:
-            # all processes wait until data download has happened
-            torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')
-
-        elif self.use_horovod:
-            # all processes wait until data download has happened
-            hvd.join()
-
+        if self.accelerator_backend is not None:
+            self.accelerator_backend.barrier('get_dataloaders')
         return dataloader
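The three backend-specific branches removed above all implemented the same "every process waits until the data download has happened" synchronization. A minimal sketch of how that dispatch could look once folded behind `Accelerator.barrier()` — the subclass names, the `name` parameter, and the lazy imports are assumptions for illustration, not taken from this diff:

    # Hypothetical sketch: per-backend barriers behind one Accelerator hook.
    import torch.distributed as torch_distrib


    class Accelerator:
        def barrier(self, name: str = None):
            # base class: single-process backends need no synchronization
            pass


    class DDPBackend(Accelerator):
        def barrier(self, name: str = None):
            # all DDP processes wait here, e.g. until rank 0 finished a download
            if torch_distrib.is_initialized():
                torch_distrib.barrier()


    class TPUBackend(Accelerator):
        def barrier(self, name: str = None):
            import torch_xla.core.xla_model as xm
            # rendezvous needs a tag that is unique per synchronization point
            xm.rendezvous(f'pl.{name}')


    class HorovodBackend(Accelerator):
        def barrier(self, name: str = None):
            import horovod.torch as hvd
            # hvd.join() blocks until every worker reaches this point
            hvd.join()

With this shape, call sites like `request_dataloader` no longer branch on `use_ddp` / `use_tpu` / `use_horovod`; they make one polymorphic call and pass a tag naming the synchronization point.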
@@ -696,9 +696,6 @@ class Trainer(
         # --------------------
         self.verbose_test = verbose

-        if self.global_rank != 0:
-            return
-
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
         if test_dataloaders and datamodule:
             raise MisconfigurationException(
@@ -738,6 +735,8 @@
                 f'specify a path for a checkpoint .test(ckpt_path=PATH)'
             )
             return {}
+        if self.accelerator_backend is not None:
+            self.accelerator_backend.barrier()

         ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
         model.load_state_dict(ckpt['state_dict'])
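Note the ordering in the hunk above: the new barrier runs before `pl_load`, which presumably lets rank 0 finish writing the best-model checkpoint before the other ranks read it. It also pairs with the removal of the `if self.global_rank != 0: return` guard in the previous hunk — all ranks now execute the test path, so they must stay in lockstep rather than having non-zero ranks bail out early.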
@@ -181,13 +181,8 @@ class TrainLoop:
         if self.trainer.global_rank == 0:
             self.trainer.profiler.describe()

-        if self.trainer.global_rank == 0:
-            for proc in self.trainer.interactive_ddp_procs:
-                subprocess.Popen.kill(proc)
-
-        # clean up dist group
-        if self.trainer.use_ddp or self.trainer.use_ddp2:
-            torch_distrib.destroy_process_group()
+        # give accelerators a chance to finish
+        self.trainer.accelerator_backend.on_train_end()

         # clear mem
         if self.trainer.on_gpu:
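The process-killing and process-group teardown removed from `TrainLoop` presumably move into the DDP accelerator's `on_train_end` hook. A sketch of where that cleanup could land, assuming the accelerator keeps a reference to its trainer (the method placement and class name are assumptions):

    # Hypothetical sketch: DDP-specific teardown behind on_train_end().
    import subprocess

    import torch.distributed as torch_distrib

    from pytorch_lightning.accelerators.base_backend import Accelerator


    class DDPBackend(Accelerator):
        def on_train_end(self):
            # rank 0 terminates the interactively spawned DDP worker processes
            if self.trainer.global_rank == 0:
                for proc in self.trainer.interactive_ddp_procs:
                    subprocess.Popen.kill(proc)

            # tear down the distributed process group
            if torch_distrib.is_initialized():
                torch_distrib.destroy_process_group()

For non-distributed backends, the base `on_train_end` would simply be a no-op, which is why the call site no longer needs the `use_ddp or use_ddp2` check.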
@@ -0,0 +1,43 @@
+"""
+Runs either `.fit()` or `.test()` on a single node across multiple gpus.
+"""
+from argparse import ArgumentParser
+
+from pytorch_lightning import Trainer, seed_everything
+from tests.base import EvalModelTemplate
+import os
+import torch
+
+
+def main():
+    seed_everything(1234)
+    parser = ArgumentParser(add_help=False)
+    parser = Trainer.add_argparse_args(parser)
+    parser.add_argument('--trainer_method', default='fit')
+    parser.add_argument('--tmpdir')
+    parser.set_defaults(gpus=2)
+    parser.set_defaults(distributed_backend="ddp")
+    args = parser.parse_args()
+
+    model = EvalModelTemplate()
+    trainer = Trainer.from_argparse_args(args)
+
+    result = {}
+    if args.trainer_method == 'fit':
+        trainer.fit(model)
+        result = {'status': 'complete', 'method': args.trainer_method, 'result': None}
+    if args.trainer_method == 'test':
+        result = trainer.test(model)
+        result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
+    if args.trainer_method == 'fit_test':
+        trainer.fit(model)
+        result = trainer.test(model)
+        result = {'status': 'complete', 'method': args.trainer_method, 'result': result}
+
+    if len(result) > 0:
+        file_path = os.path.join(args.tmpdir, 'ddp.result')
+        torch.save(result, file_path)
+
+
+if __name__ == '__main__':
+    main()
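Because DDP spawns one process per GPU, this new script is meant to be launched as its own interpreter process and report back through a file rather than a return value. A sketch of how a driver test could invoke it and collect the saved result — the script filename and location are assumptions here, only the `--trainer_method` / `--tmpdir` flags and the `ddp.result` filename come from the diff:

    # Hypothetical driver: launch the DDP script and read back its result file.
    import os
    import subprocess
    import sys
    import tempfile

    import torch


    def run_ddp_script(method: str):
        tmpdir = tempfile.mkdtemp()
        subprocess.check_call([
            sys.executable, 'train_test_variations.py',
            '--trainer_method', method,
            '--tmpdir', tmpdir,
        ])
        # the script saves its outcome via torch.save(result, tmpdir/'ddp.result')
        result = torch.load(os.path.join(tmpdir, 'ddp.result'))
        assert result['status'] == 'complete'
        return result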