group fit data links

This commit is contained in:
William Falcon 2020-08-26 21:34:55 -04:00
parent eb12f58edf
commit be0438bb47
1 changed files with 14 additions and 10 deletions

View File

@ -981,16 +981,8 @@ class Trainer(
if hasattr(model, 'hparams'):
parsing.clean_namespace(model.hparams)
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloader, LightningDataModule):
datamodule = train_dataloader
train_dataloader = None
self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
# set up the passed in dataloaders (if needed)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
self.__attach_datamodule(model, datamodule, 'fit')
# links data to the trainer
self.attach_data(model, train_dataloader, val_dataloaders)
# check that model is configured correctly
self.config_validator.verify_loop_configurations(model)
@ -1045,6 +1037,18 @@ class Trainer(
# used for testing or when we need to know that training succeeded
return results or 1
def attach_data(self, model, train_dataloader, val_dataloaders):
# if a datamodule comes in as the second arg, then fix it for the user
if isinstance(train_dataloader, LightningDataModule):
datamodule = train_dataloader
train_dataloader = None
self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
# set up the passed in dataloaders (if needed)
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
self.__attach_datamodule(model, datamodule, 'fit')
def select_accelerator(self):
# SLURM ddp
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks