parent
440f837f6d
commit
ac2b0f0f06
|
@ -17,10 +17,10 @@ import platform
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Union, List, Tuple, Callable, Optional
|
from typing import Union, List, Tuple, Callable, Optional
|
||||||
|
|
||||||
import torch.distributed as torch_distrib
|
|
||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
|
||||||
|
from pytorch_lightning.accelerators.base_backend import Accelerator
|
||||||
from pytorch_lightning.accelerators.base_backend import BackendType
|
from pytorch_lightning.accelerators.base_backend import BackendType
|
||||||
from pytorch_lightning.core import LightningModule
|
from pytorch_lightning.core import LightningModule
|
||||||
from pytorch_lightning.utilities import rank_zero_warn
|
from pytorch_lightning.utilities import rank_zero_warn
|
||||||
|
@ -75,6 +75,7 @@ class TrainerDataLoadingMixin(ABC):
|
||||||
limit_val_batches: Union[int, float]
|
limit_val_batches: Union[int, float]
|
||||||
limit_test_batches: Union[int, float]
|
limit_test_batches: Union[int, float]
|
||||||
replace_sampler_ddp: bool
|
replace_sampler_ddp: bool
|
||||||
|
accelerator_backend: Accelerator
|
||||||
num_nodes: int
|
num_nodes: int
|
||||||
num_processes: int
|
num_processes: int
|
||||||
distributed_backend: Optional[str]
|
distributed_backend: Optional[str]
|
||||||
|
@ -337,18 +338,6 @@ class TrainerDataLoadingMixin(ABC):
|
||||||
"""
|
"""
|
||||||
dataloader = dataloader_fx()
|
dataloader = dataloader_fx()
|
||||||
|
|
||||||
# get the function we'll use to get data
|
if self.accelerator_backend is not None:
|
||||||
if self.use_ddp or self.use_ddp2:
|
self.accelerator_backend.barrier('get_dataloaders')
|
||||||
# all processes wait until data download has happened
|
|
||||||
torch_distrib.barrier()
|
|
||||||
|
|
||||||
# data download/load on TPU
|
|
||||||
elif self.use_tpu and XLA_AVAILABLE:
|
|
||||||
# all processes wait until data download has happened
|
|
||||||
torch_xla.core.xla_model.rendezvous('pl.TrainerDataLoadingMixin.get_dataloaders')
|
|
||||||
|
|
||||||
elif self.use_horovod:
|
|
||||||
# all processes wait until data download has happened
|
|
||||||
hvd.join()
|
|
||||||
|
|
||||||
return dataloader
|
return dataloader
|
||||||
|
|
|
@ -696,9 +696,6 @@ class Trainer(
|
||||||
# --------------------
|
# --------------------
|
||||||
self.verbose_test = verbose
|
self.verbose_test = verbose
|
||||||
|
|
||||||
if self.global_rank != 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# If you supply a datamodule you can't supply test_dataloaders
|
# If you supply a datamodule you can't supply test_dataloaders
|
||||||
if test_dataloaders and datamodule:
|
if test_dataloaders and datamodule:
|
||||||
raise MisconfigurationException(
|
raise MisconfigurationException(
|
||||||
|
@ -738,6 +735,8 @@ class Trainer(
|
||||||
f'specify a path for a checkpoint .test(ckpt_path=PATH)'
|
f'specify a path for a checkpoint .test(ckpt_path=PATH)'
|
||||||
)
|
)
|
||||||
return {}
|
return {}
|
||||||
|
if self.accelerator_backend is not None:
|
||||||
|
self.accelerator_backend.barrier()
|
||||||
|
|
||||||
ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
|
ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
|
||||||
model.load_state_dict(ckpt['state_dict'])
|
model.load_state_dict(ckpt['state_dict'])
|
||||||
|
|
|
@ -181,13 +181,8 @@ class TrainLoop:
|
||||||
if self.trainer.global_rank == 0:
|
if self.trainer.global_rank == 0:
|
||||||
self.trainer.profiler.describe()
|
self.trainer.profiler.describe()
|
||||||
|
|
||||||
if self.trainer.global_rank == 0:
|
# give accelerators a chance to finish
|
||||||
for proc in self.trainer.interactive_ddp_procs:
|
self.trainer.accelerator_backend.on_train_end()
|
||||||
subprocess.Popen.kill(proc)
|
|
||||||
|
|
||||||
# clean up dist group
|
|
||||||
if self.trainer.use_ddp or self.trainer.use_ddp2:
|
|
||||||
torch_distrib.destroy_process_group()
|
|
||||||
|
|
||||||
# clear mem
|
# clear mem
|
||||||
if self.trainer.on_gpu:
|
if self.trainer.on_gpu:
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
"""
|
||||||
|
Runs either `.fit()` or `.test()` on a single node across multiple gpus.
|
||||||
|
"""
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
from pytorch_lightning import Trainer, seed_everything
|
||||||
|
from tests.base import EvalModelTemplate
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the requested trainer method(s) and save the outcome to disk.

    Command-line arguments are the ``Trainer`` argparse arguments plus:

    * ``--trainer_method``: ``'fit'`` (default), ``'test'`` or ``'fit_test'``.
    * ``--tmpdir``: directory that receives the ``ddp.result`` file.

    ``gpus`` defaults to 2 and ``distributed_backend`` to ``"ddp"``.

    For a recognised method, a dict with the keys ``status``, ``method`` and
    ``result`` is written with ``torch.save`` to ``<tmpdir>/ddp.result``.
    ``result`` holds the return value of ``trainer.test(model)``, or ``None``
    when only ``fit`` ran. Any other method value runs nothing and writes
    nothing.
    """
    seed_everything(1234)

    # extend the Trainer's own CLI with the options specific to this script
    cli = Trainer.add_argparse_args(ArgumentParser(add_help=False))
    cli.add_argument('--trainer_method', default='fit')
    cli.add_argument('--tmpdir')
    cli.set_defaults(gpus=2, distributed_backend="ddp")
    args = cli.parse_args()

    model = EvalModelTemplate()
    trainer = Trainer.from_argparse_args(args)

    method = args.trainer_method
    if method not in ('fit', 'test', 'fit_test'):
        # unrecognised method: no training, no testing, no result file
        return

    if method in ('fit', 'fit_test'):
        trainer.fit(model)

    # only the test-running methods produce a payload; plain 'fit' stores None
    outcome = trainer.test(model) if method in ('test', 'fit_test') else None

    summary = {'status': 'complete', 'method': method, 'result': outcome}
    torch.save(summary, os.path.join(args.tmpdir, 'ddp.result'))
|
||||||
|
|
||||||
|
|
||||||
|
# Only run when executed as a script, so importing this module stays
# free of side effects.
if __name__ == "__main__":
    main()
|
Loading…
Reference in New Issue