group fit data links
This commit is contained in:
parent
eb12f58edf
commit
be0438bb47
|
@ -981,16 +981,8 @@ class Trainer(
|
|||
if hasattr(model, 'hparams'):
|
||||
parsing.clean_namespace(model.hparams)
|
||||
|
||||
# if a datamodule comes in as the second arg, then fix it for the user
|
||||
if isinstance(train_dataloader, LightningDataModule):
|
||||
datamodule = train_dataloader
|
||||
train_dataloader = None
|
||||
|
||||
self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
|
||||
|
||||
# set up the passed in dataloaders (if needed)
|
||||
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
|
||||
self.__attach_datamodule(model, datamodule, 'fit')
|
||||
# links data to the trainer
|
||||
self.attach_data(model, train_dataloader, val_dataloaders)
|
||||
|
||||
# check that model is configured correctly
|
||||
self.config_validator.verify_loop_configurations(model)
|
||||
|
@ -1045,6 +1037,18 @@ class Trainer(
|
|||
# used for testing or when we need to know that training succeeded
|
||||
return results or 1
|
||||
|
||||
def attach_data(self, model, train_dataloader, val_dataloaders):
|
||||
# if a datamodule comes in as the second arg, then fix it for the user
|
||||
if isinstance(train_dataloader, LightningDataModule):
|
||||
datamodule = train_dataloader
|
||||
train_dataloader = None
|
||||
|
||||
self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
|
||||
|
||||
# set up the passed in dataloaders (if needed)
|
||||
self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
|
||||
self.__attach_datamodule(model, datamodule, 'fit')
|
||||
|
||||
def select_accelerator(self):
|
||||
# SLURM ddp
|
||||
use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
|
||||
|
|
Loading…
Reference in New Issue