diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 88130e1b59..c7edc1c935 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -982,7 +982,7 @@ class Trainer( parsing.clean_namespace(model.hparams) # links data to the trainer - self.attach_data(model, train_dataloader, val_dataloaders) + self.attach_data(model, train_dataloader, val_dataloaders, datamodule) # check that model is configured correctly self.config_validator.verify_loop_configurations(model) @@ -990,13 +990,8 @@ class Trainer( # hook self.call_hook('on_fit_start', model) - # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 - # or in the case where each node needs to do its own manipulation in which case just local_rank=0 - if self.can_prepare_data(): - if self.datamodule is not None: - self.datamodule.prepare_data() - model.prepare_data() - self._is_data_prepared = True + # hook + self.prepare_data(model) # Run auto batch size scaling if self.auto_scale_batch_size: @@ -1013,18 +1008,17 @@ class Trainer( # set testing if set in environ self.testing = os.environ.get('PL_TESTING_MODE', self.testing) - # choose accelerator + # ------------------------- + # TRAIN + # ------------------------- self.accelerator_backend = self.select_accelerator() - - # setup accelerator self.accelerator_backend.setup(model) - - # train! results = self.accelerator_backend.train() - - # teardown accelerator self.accelerator_backend.teardown() + # ------------------------- + # POST-Training + # ------------------------- # hook self.call_hook('on_fit_end') @@ -1037,7 +1031,16 @@ class Trainer( # used for testing or when we need to know that training succeeded return results or 1 - def attach_data(self, model, train_dataloader, val_dataloaders): + def prepare_data(self, model): + # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 + # or in the case where each node needs to do its own manipulation in which case just local_rank=0 + if self.can_prepare_data(): + if self.datamodule is not None: + self.datamodule.prepare_data() + model.prepare_data() + self._is_data_prepared = True + + def attach_data(self, model, train_dataloader, val_dataloaders, datamodule): # if a datamodule comes in as the second arg, then fix it for the user if isinstance(train_dataloader, LightningDataModule): datamodule = train_dataloader @@ -1059,10 +1062,7 @@ class Trainer( use_ddp_spawn = self.use_ddp and self.distributed_backend in ['ddp_cpu', 'ddp_spawn'] - # ------------------- - # route training mode - # ------------------- - # DDP2 (cluster only) + # choose the appropriate accelerator backend if self.use_ddp2: accelerator_backend = DDP2Backend(self) @@ -1072,15 +1072,12 @@ class Trainer( elif use_torchelastic_ddp: accelerator_backend = DDPBackend(self, mode='torchelastic_ddp') - # regular ddp using .spawn elif use_ddp_spawn: accelerator_backend = DDPSpawnBackend(self, nprocs=self.num_processes) - # ddp elif self.distributed_backend == 'ddp': accelerator_backend = DDPBackend(self, mode='ddp') - # dp elif self.use_dp: accelerator_backend = DataParallelBackend(self) @@ -1098,7 +1095,6 @@ class Trainer( return accelerator_backend - def can_prepare_data(self): should_call_dm_prepare_data = True if self.datamodule is not None and self.is_overridden('prepare_data', self.datamodule):