ref: group prepare data hook (6) (#3212)
* group prepare data hook
* group prepare data hook
* group prepare data hook
* group prepare data hook
* group prepare data hook
* group prepare data hook
* group prepare data hook
parent be0438bb47
commit 464a0e7bb1
@@ -982,7 +982,7 @@ class Trainer(
         parsing.clean_namespace(model.hparams)

         # links data to the trainer
-        self.attach_data(model, train_dataloader, val_dataloaders)
+        self.attach_data(model, train_dataloader, val_dataloaders, datamodule)

         # check that model is configured correctly
         self.config_validator.verify_loop_configurations(model)
@@ -990,13 +990,8 @@ class Trainer(
         # hook
         self.call_hook('on_fit_start', model)

-        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
-        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
-        if self.can_prepare_data():
-            if self.datamodule is not None:
-                self.datamodule.prepare_data()
-            model.prepare_data()
-            self._is_data_prepared = True
+        # hook
+        self.prepare_data(model)

         # Run auto batch size scaling
         if self.auto_scale_batch_size:
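Aside: call_hook('on_fit_start', model) above fans a hook out to the trainer's callbacks and to the model. A rough sketch of that dispatch pattern, purely for orientation; this is not Lightning's actual implementation, and the class and attribute names are assumed:

# Illustrative call_hook-style dispatcher; names and structure are assumptions,
# not code from this repository.
class HookDispatcherSketch:
    def __init__(self, callbacks, model):
        self.callbacks = callbacks  # objects that may define hook methods
        self.model = model          # the LightningModule-like object

    def call_hook(self, hook_name, *args, **kwargs):
        # let every callback react first
        for cb in self.callbacks:
            fn = getattr(cb, hook_name, None)
            if callable(fn):
                fn(*args, **kwargs)
        # then run the model's own hook, if it defines one, and return its output
        model_fn = getattr(self.model, hook_name, None)
        return model_fn(*args, **kwargs) if callable(model_fn) else None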
@@ -1013,18 +1008,17 @@ class Trainer(
         # set testing if set in environ
         self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

-        # choose accelerator
+        # -------------------------
+        # TRAIN
+        # -------------------------
         self.accelerator_backend = self.select_accelerator()
-
-        # setup accelerator
         self.accelerator_backend.setup(model)
-
-        # train!
         results = self.accelerator_backend.train()
-
-        # teardown accelerator
         self.accelerator_backend.teardown()

+        # -------------------------
+        # POST-Training
+        # -------------------------
         # hook
         self.call_hook('on_fit_end')

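Note: the TRAIN block above leans on every accelerator backend exposing the same small lifecycle, select_accelerator() -> setup(model) -> train() -> teardown(). A minimal sketch of that interface with assumed names, not the project's actual base class:

# Minimal sketch of the lifecycle assumed by the calls above; illustrative only.
class AcceleratorBackendSketch:
    def __init__(self, trainer):
        self.trainer = trainer
        self.model = None

    def setup(self, model):
        # place the model on the right device(s) and apply any wrapping (DDP/DP)
        self.model = model

    def train(self):
        # run the training loop and return its results
        raise NotImplementedError

    def teardown(self):
        # release devices / process groups and undo any wrapping
        pass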
@@ -1037,7 +1031,16 @@ class Trainer(
         # used for testing or when we need to know that training succeeded
         return results or 1

-    def attach_data(self, model, train_dataloader, val_dataloaders):
+    def prepare_data(self, model):
+        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
+        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
+        if self.can_prepare_data():
+            if self.datamodule is not None:
+                self.datamodule.prepare_data()
+            model.prepare_data()
+            self._is_data_prepared = True
+
+    def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(train_dataloader, LightningDataModule):
             datamodule = train_dataloader
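With the reworked attach_data, a LightningDataModule can arrive either as the datamodule keyword or, thanks to the isinstance check above, in the train_dataloader position. A small usage sketch; the RandomDataModule class and its random tensors are assumptions for illustration, not part of this commit:

# Usage sketch; the datamodule below is illustrative only.
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class RandomDataModule(pl.LightningDataModule):
    def prepare_data(self):
        # download / write data to disk; the trainer guards this call so it
        # is not repeated (see can_prepare_data further down)
        pass

    def setup(self, stage=None):
        # build datasets; runs on every process
        self.train_set = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=8)


# Both spellings end up in attach_data with a datamodule:
#   trainer.fit(model, datamodule=RandomDataModule())  # keyword
#   trainer.fit(model, RandomDataModule())             # positional; fixed up by the isinstance check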
@@ -1059,10 +1062,7 @@ class Trainer(

         use_ddp_spawn = self.use_ddp and self.distributed_backend in ['ddp_cpu', 'ddp_spawn']

-        # -------------------
-        # route training mode
-        # -------------------
-        # DDP2 (cluster only)
+        # choose the appropriate accelerator backend
         if self.use_ddp2:
             accelerator_backend = DDP2Backend(self)

@@ -1072,15 +1072,12 @@ class Trainer(
         elif use_torchelastic_ddp:
             accelerator_backend = DDPBackend(self, mode='torchelastic_ddp')

-        # regular ddp using .spawn
         elif use_ddp_spawn:
             accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes)

-        # ddp
         elif self.distributed_backend == 'ddp':
             accelerator_backend = DDPBackend(self, mode='ddp')

-        # dp
         elif self.use_dp:
             accelerator_backend = DataParallelBackend(self)

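Taken together, the two hunks above reduce select_accelerator to a single if/elif routing table over the distributed flags. A condensed sketch of that routing shape, with backend objects replaced by labels and most flag handling simplified; not the full decision logic:

# Condensed sketch of the routing pattern above; illustrative only.
def select_backend_label(trainer):
    if trainer.use_ddp2:
        return 'ddp2'
    elif trainer.use_ddp and trainer.distributed_backend in ('ddp_cpu', 'ddp_spawn'):
        return 'ddp_spawn'
    elif trainer.distributed_backend == 'ddp':
        return 'ddp'
    elif trainer.use_dp:
        return 'dp'
    # remaining cases fall through to a default backend (elided in this sketch)
    return 'single_device'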
@@ -1098,7 +1095,6 @@ class Trainer(

         return accelerator_backend

-
    def can_prepare_data(self):
         should_call_dm_prepare_data = True
         if self.datamodule is not None and self.is_overridden('prepare_data', self.datamodule):
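The hunk above ends inside can_prepare_data, where is_overridden('prepare_data', self.datamodule) decides whether the datamodule actually customizes the hook. A standalone sketch of how such a check can be written; this is not Lightning's exact implementation:

# Standalone is_overridden-style check; illustrative only.
def is_overridden_sketch(method_name, instance, base_class):
    if instance is None:
        return False
    sub_fn = getattr(type(instance), method_name, None)
    base_fn = getattr(base_class, method_name, None)
    if sub_fn is None or base_fn is None:
        return False
    # if the instance's class still inherits the base implementation unchanged,
    # the hook was not overridden
    return sub_fn is not base_fn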