ref: move prepare_data to data connector (#3307)
* ref: moved argparse code to central class
parent 3910ad0330
commit 7d57f8d407
@@ -24,6 +24,25 @@ class DataConnector(object):
     def __init__(self, trainer):
         self.trainer = trainer

+    def prepare_data(self, model):
+        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
+        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
+        if self.can_prepare_data():
+            if self.trainer.datamodule is not None:
+                self.trainer.datamodule.prepare_data()
+            model.prepare_data()
+            self.trainer._is_data_prepared = True
+
+    def can_prepare_data(self):
+        should_call_dm_prepare_data = True
+        if self.trainer.datamodule is not None and is_overridden('prepare_data', self.trainer.datamodule):
+            should_call_dm_prepare_data = not self.trainer.datamodule.has_prepared_data
+
+        if self.trainer.prepare_data_per_node:
+            return self.trainer.local_rank == 0 and should_call_dm_prepare_data
+        else:
+            return self.trainer.node_rank == 0 and self.trainer.local_rank == 0 and should_call_dm_prepare_data
+
     def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(train_dataloader, LightningDataModule):
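As context for the hunk above: `can_prepare_data` gates dataset preparation by rank. With `prepare_data_per_node=True`, the `local_rank == 0` process on every node prepares; otherwise only the single global-zero process (`node_rank == 0` and `local_rank == 0`) does. A minimal runnable sketch of just that gating rule, using a hypothetical `FakeTrainer` stand-in rather than the real Trainer (the datamodule dedup term is left out here):

# Stand-alone sketch of the rank-gating rule in
# DataConnector.can_prepare_data; FakeTrainer is a hypothetical stand-in.
class FakeTrainer:
    def __init__(self, prepare_data_per_node, node_rank, local_rank):
        self.prepare_data_per_node = prepare_data_per_node
        self.node_rank = node_rank
        self.local_rank = local_rank


def can_prepare_data(trainer):
    # per-node preparation: the local_rank == 0 process on every node prepares
    if trainer.prepare_data_per_node:
        return trainer.local_rank == 0
    # global preparation: only node_rank == 0 / local_rank == 0 prepares
    return trainer.node_rank == 0 and trainer.local_rank == 0


# per-node: local rank 0 on any node may prepare; other local ranks may not
assert can_prepare_data(FakeTrainer(True, node_rank=1, local_rank=0))
assert not can_prepare_data(FakeTrainer(True, node_rank=0, local_rank=1))
# global: a non-zero node never prepares
assert not can_prepare_data(FakeTrainer(False, node_rank=1, local_rank=0))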
@@ -944,7 +944,7 @@ class Trainer(
         self.call_hook('on_fit_start', model)

         # hook
-        self.prepare_data(model)
+        self.data_connector.prepare_data(model)

         # Run auto batch size scaling
         if self.auto_scale_batch_size:
@@ -1014,7 +1014,7 @@ class Trainer(
         self.call_hook('on_fit_start', model)

         # hook
-        self.prepare_data(model)
+        self.data_connector.prepare_data(model)

         # set testing if set in environ
         self.testing = os.environ.get('PL_TESTING_MODE', self.testing)
@@ -1056,15 +1056,6 @@ class Trainer(
         # check that model is configured correctly
         self.config_validator.verify_loop_configurations(model)

-    def prepare_data(self, model):
-        # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
-        # or in the case where each node needs to do its own manipulation in which case just local_rank=0
-        if self.can_prepare_data():
-            if self.datamodule is not None:
-                self.datamodule.prepare_data()
-            model.prepare_data()
-            self._is_data_prepared = True
-
     def select_accelerator(self):
         # SLURM ddp
         use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
@@ -1108,16 +1099,6 @@ class Trainer(

         return accelerator_backend

-    def can_prepare_data(self):
-        should_call_dm_prepare_data = True
-        if self.datamodule is not None and is_overridden('prepare_data', self.datamodule):
-            should_call_dm_prepare_data = not self.datamodule.has_prepared_data
-
-        if self.prepare_data_per_node:
-            return self.local_rank == 0 and should_call_dm_prepare_data
-        else:
-            return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data
-
     def setup_training(self, model: LightningModule):
         """Sanity check a few things before starting actual training.

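The two removals above complete the move: `prepare_data` and `can_prepare_data` leave the Trainer, which now only forwards to the connector it holds, while the connector keeps a back-reference so the moved methods can still read trainer state. A minimal sketch of that composition pattern with simplified, hypothetical classes (the real construction of `self.data_connector` is not shown in this diff; only the attribute name at the call site is):

# Sketch of the composition pattern: the trainer owns a connector that
# keeps a back-reference, so the moved method still reads trainer state.
class DataConnector:
    def __init__(self, trainer):
        self.trainer = trainer

    def prepare_data(self):
        print(f"preparing data on local_rank={self.trainer.local_rank}")


class Trainer:
    def __init__(self):
        self.local_rank = 0
        # hypothetical wiring; the attribute name matches the diff's call site
        self.data_connector = DataConnector(self)

    def fit(self):
        # the Trainer forwards instead of implementing prepare_data itself
        self.data_connector.prepare_data()


Trainer().fit()  # -> preparing data on local_rank=0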
@@ -23,26 +23,26 @@ def test_can_prepare_data(tmpdir):
     # local rank = 0 (True)
     trainer.prepare_data_per_node = True
     trainer.local_rank = 0
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()

     # local rank = 1 (False)
     trainer.local_rank = 1
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()

     # prepare_data_per_node = False (prepare across all nodes)
     # global rank = 0 (True)
     trainer.prepare_data_per_node = False
     trainer.node_rank = 0
     trainer.local_rank = 0
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()

     # global rank = 1 (False)
     trainer.node_rank = 1
     trainer.local_rank = 0
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()
     trainer.node_rank = 0
     trainer.local_rank = 1
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()

     # 2 dm
     # prepar per node = True
@@ -54,17 +54,17 @@ def test_can_prepare_data(tmpdir):
     # has been called
     # False
     dm._has_prepared_data = True
-    assert not trainer.can_prepare_data()
+    assert not trainer.data_connector.can_prepare_data()

     # has not been called
     # True
     dm._has_prepared_data = False
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()

     # is_overridden prepare data = False
     # True
     dm.prepare_data = None
-    assert trainer.can_prepare_data()
+    assert trainer.data_connector.can_prepare_data()


 def test_base_datamodule(tmpdir):
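The second test hunk exercises the datamodule dedup term of `can_prepare_data`: when an attached datamodule overrides `prepare_data` and has already run it (`has_prepared_data` is True), preparation is suppressed even on rank 0. A minimal sketch of just that term, with a hypothetical `FakeDM` standing in for a real `LightningDataModule`:

# Sketch of the datamodule dedup term in can_prepare_data: an already
# prepared datamodule suppresses a second prepare_data call.
class FakeDM:
    def __init__(self, has_prepared_data, overrides_prepare_data=True):
        self.has_prepared_data = has_prepared_data
        self.overrides_prepare_data = overrides_prepare_data


def should_call_dm_prepare_data(dm):
    # datamodule attached and prepare_data overridden -> call it only once
    if dm is not None and dm.overrides_prepare_data:
        return not dm.has_prepared_data
    return True


assert not should_call_dm_prepare_data(FakeDM(has_prepared_data=True))
assert should_call_dm_prepare_data(FakeDM(has_prepared_data=False))
assert should_call_dm_prepare_data(None)  # no datamodule attached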