diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 24236a6eb4..88130e1b59 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -981,16 +981,8 @@ class Trainer(
         if hasattr(model, 'hparams'):
             parsing.clean_namespace(model.hparams)
 
-        # if a datamodule comes in as the second arg, then fix it for the user
-        if isinstance(train_dataloader, LightningDataModule):
-            datamodule = train_dataloader
-            train_dataloader = None
-
-        self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
-
-        # set up the passed in dataloaders (if needed)
-        self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
-        self.__attach_datamodule(model, datamodule, 'fit')
+        # links data to the trainer
+        self.attach_data(model, train_dataloader, val_dataloaders)
 
         # check that model is configured correctly
         self.config_validator.verify_loop_configurations(model)
@@ -1045,6 +1037,18 @@ class Trainer(
         # used for testing or when we need to know that training succeeded
         return results or 1
 
+    def attach_data(self, model, train_dataloader, val_dataloaders):
+        # if a datamodule comes in as the second arg, then fix it for the user
+        if isinstance(train_dataloader, LightningDataModule):
+            datamodule = train_dataloader
+            train_dataloader = None
+
+        self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
+
+        # set up the passed in dataloaders (if needed)
+        self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
+        self.__attach_datamodule(model, datamodule, 'fit')
+
     def select_accelerator(self):
         # SLURM ddp
         use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
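
For context, here is a minimal standalone sketch of the argument re-routing that the extracted `attach_data` helper performs; the `FakeDataModule` class and the free-function re-implementation below are illustrative stand-ins only, not the Lightning API:

```python
# Sketch (assumed stand-ins, not the real Lightning classes): shows how a datamodule
# passed in the train_dataloader position is detected and moved to the datamodule slot.

class FakeDataModule:
    """Stand-in for pytorch_lightning.LightningDataModule."""


def attach_data(train_dataloader=None, val_dataloaders=None, datamodule=None):
    # if a datamodule comes in as the second arg, then fix it for the user
    if isinstance(train_dataloader, FakeDataModule):
        datamodule = train_dataloader
        train_dataloader = None
    return train_dataloader, val_dataloaders, datamodule


dm = FakeDataModule()
assert attach_data(train_dataloader=dm) == (None, None, dm)              # datamodule re-routed
assert attach_data(train_dataloader="loader") == ("loader", None, None)  # dataloaders pass through
```

Behavior is unchanged by the refactor; `fit()` simply delegates this dispatch to `attach_data` so the same logic lives in one place.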