ref: move prepare_data to data connector (#3307)

* ref: moved argparse code to central class

* ref: moved argparse code to central class

* ref: moved argparse code to central class
William Falcon 2020-09-01 14:59:09 -04:00 committed by GitHub
parent 3910ad0330
commit 7d57f8d407
3 changed files with 29 additions and 29 deletions


@@ -24,6 +24,25 @@ class DataConnector(object):
    def __init__(self, trainer):
        self.trainer = trainer

+    def prepare_data(self, model):
+        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
+        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
+        if self.can_prepare_data():
+            if self.trainer.datamodule is not None:
+                self.trainer.datamodule.prepare_data()
+            model.prepare_data()
+            self.trainer._is_data_prepared = True
+
+    def can_prepare_data(self):
+        should_call_dm_prepare_data = True
+        if self.trainer.datamodule is not None and is_overridden('prepare_data', self.trainer.datamodule):
+            should_call_dm_prepare_data = not self.trainer.datamodule.has_prepared_data
+
+        if self.trainer.prepare_data_per_node:
+            return self.trainer.local_rank == 0 and should_call_dm_prepare_data
+        else:
+            return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data

    def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloader, LightningDataModule):
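
A minimal plain-Python sketch of the rank gating implemented by can_prepare_data above (the helper name is hypothetical, not part of the library): with prepare_data_per_node enabled, the local_rank-0 process on every node prepares the data; otherwise only node_rank 0 / local_rank 0 does.

def should_prepare(prepare_data_per_node, node_rank, local_rank):
    # mirrors the branch in can_prepare_data, ignoring the datamodule guard
    if prepare_data_per_node:
        return local_rank == 0
    return node_rank == 0 and local_rank == 0

assert should_prepare(True, node_rank=1, local_rank=0)        # every node's rank-0 process prepares
assert not should_prepare(True, node_rank=1, local_rank=1)    # other local ranks never do
assert not should_prepare(False, node_rank=1, local_rank=0)   # per-node prep off: only node 0 prepares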


@@ -944,7 +944,7 @@ class Trainer(
        self.call_hook('on_fit_start', model)

        # hook
-        self.prepare_data(model)
+        self.data_connector.prepare_data(model)

        # Run auto batch size scaling
        if self.auto_scale_batch_size:
@@ -1014,7 +1014,7 @@ class Trainer(
        self.call_hook('on_fit_start', model)

        # hook
-        self.prepare_data(model)
+        self.data_connector.prepare_data(model)

        # set testing if set in environ
        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
@@ -1056,15 +1056,6 @@ class Trainer(
        # check that model is configured correctly
        self.config_validator.verify_loop_configurations(model)

-    def prepare_data(self, model):
-        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
-        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
-        if self.can_prepare_data():
-            if self.datamodule is not None:
-                self.datamodule.prepare_data()
-            model.prepare_data()
-            self._is_data_prepared = True
-
    def select_accelerator(self):
        # SLURM ddp
        use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
@@ -1108,16 +1099,6 @@ class Trainer(
        return accelerator_backend

-    def can_prepare_data(self):
-        should_call_dm_prepare_data = True
-        if self.datamodule is not None and is_overridden('prepare_data', self.datamodule):
-            should_call_dm_prepare_data = not self.datamodule.has_prepared_data
-
-        if self.prepare_data_per_node:
-            return self.local_rank == 0 and should_call_dm_prepare_data
-        else:
-            return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data
-
    def setup_training(self, model: LightningModule):
        """Sanity check a few things before starting actual training.

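A rough sketch of the delegation pattern this diff introduces (simplified, hypothetical class names, not the library's real implementation): the Trainer composes a connector object and forwards the hook to it instead of defining the logic inline.

class DataConnectorSketch:
    def __init__(self, trainer):
        self.trainer = trainer

    def prepare_data(self, model):
        # the real connector also applies the rank and datamodule guards shown above
        model.prepare_data()

class TrainerSketch:
    def __init__(self):
        self.data_connector = DataConnectorSketch(self)

    def fit(self, model):
        self.data_connector.prepare_data(model)  # delegated call, as in the Trainer diff above

Keeping the data-preparation logic in a dedicated connector keeps the Trainer itself thin and gives the behaviour a single testable home, which is what the test changes below exercise.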

@@ -23,26 +23,26 @@ def test_can_prepare_data(tmpdir):
    # local rank = 0 (True)
    trainer.prepare_data_per_node = True
    trainer.local_rank = 0
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()

    # local rank = 1 (False)
    trainer.local_rank = 1
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()

    # prepare_data_per_node = False (prepare across all nodes)
    # global rank = 0 (True)
    trainer.prepare_data_per_node = False
    trainer.node_rank = 0
    trainer.local_rank = 0
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()

    # global rank = 1 (False)
    trainer.node_rank = 1
    trainer.local_rank = 0
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()
    trainer.node_rank = 0
    trainer.local_rank = 1
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()

    # 2 dm
    # prepar per node = True
@@ -54,17 +54,17 @@ def test_can_prepare_data(tmpdir):
    # has been called
    # False
    dm._has_prepared_data = True
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()

    # has not been called
    # True
    dm._has_prepared_data = False
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()

    # is_overridden prepare data = False
    # True
    dm.prepare_data = None
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()

def test_base_datamodule(tmpdir):
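
The datamodule assertions above exercise the guard inside can_prepare_data; a rough plain-Python restatement of that rule (hypothetical helper, not the library's code) is that a datamodule which overrides prepare_data is only prepared while it has not already run.

def should_call_dm_prepare(dm_overrides_prepare_data, has_prepared_data):
    # mirrors the should_call_dm_prepare_data flag computed in the connector
    if dm_overrides_prepare_data:
        return not has_prepared_data
    return True

assert not should_call_dm_prepare(True, has_prepared_data=True)    # already prepared -> skip
assert should_call_dm_prepare(True, has_prepared_data=False)       # not yet prepared -> call it
assert should_call_dm_prepare(False, has_prepared_data=False)      # hook not overridden -> flag stays True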