diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst index 9329fd45c6..fd46d4e4ca 100644 --- a/docs/source/introduction_guide.rst +++ b/docs/source/introduction_guide.rst @@ -278,6 +278,9 @@ Doing it in the `prepare_data` method ensures that when you have multiple GPUs you won't overwrite the data. This is a contrived example but it gets more complicated with things like NLP or Imagenet. +`prepare_data` gets called on the `LOCAL_RANK=0` GPU per node. If your nodes share a file system, +set `Trainer(prepare_data_per_node=False)` and it will be code from node=0, gpu=0 only. + In general fill these methods with the following: .. testcode:: @@ -535,16 +538,21 @@ will cause all sorts of issues. To solve this problem, move the download code to the `prepare_data` method in the LightningModule. In this method we do all the preparation we need to do once (instead of on every gpu). +`prepare_data` can be called in two ways, once per node or only on the root node (`Trainer(prepare_data_per_node=False)`). + .. testcode:: class LitMNIST(LightningModule): def prepare_data(self): + # download only + MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) + MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) + + def setup(self, stage): # transform transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) - - # download - mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform) - mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform) + MNIST(os.getcwd(), train=True, download=False, transform=transform) + MNIST(os.getcwd(), train=False, download=False, transform=transform) # train/val split mnist_train, mnist_val = random_split(mnist_train, [55000, 5000]) diff --git a/pl_examples/domain_templates/computer_vision_fine_tuning.py b/pl_examples/domain_templates/computer_vision_fine_tuning.py index e72d7a9f03..973c9f09c5 100644 --- a/pl_examples/domain_templates/computer_vision_fine_tuning.py +++ b/pl_examples/domain_templates/computer_vision_fine_tuning.py @@ -307,12 +307,11 @@ class TransferLearningModel(pl.LightningModule): def prepare_data(self): """Download images and prepare images datasets.""" - - # 1. Download the images download_and_extract_archive(url=DATA_URL, download_root=self.dl_path, remove_finished=True) + def setup(self, stage: str): data_path = Path(self.dl_path).joinpath('cats_and_dogs_filtered') # 2. Load the data + preprocessing & data augmentation diff --git a/pl_examples/models/lightning_template.py b/pl_examples/models/lightning_template.py index dba605f5ca..391d8d99f0 100644 --- a/pl_examples/models/lightning_template.py +++ b/pl_examples/models/lightning_template.py @@ -141,10 +141,13 @@ class LightningTemplateModel(LightningModule): return [optimizer], [scheduler] def prepare_data(self): - transform = transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.5,), (1.0,))]) - self.mnist_train = MNIST(self.data_root, train=True, download=True, transform=transform) - self.mnist_test = MNIST(self.data_root, train=False, download=True, transform=transform) + MNIST(self.data_root, train=True, download=True, transform=transforms.ToTensor()) + MNIST(self.data_root, train=False, download=True, transform=transforms.ToTensor()) + + def setup(self, stage): + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))]) + self.mnist_train = MNIST(self.data_root, train=True, download=False, transform=transform) + self.mnist_test = MNIST(self.data_root, train=False, download=False, transform=transform) def train_dataloader(self): log.info('Training data loader called.') diff --git a/pytorch_lightning/core/__init__.py b/pytorch_lightning/core/__init__.py index 83fff7d862..e0aa8cf30d 100644 --- a/pytorch_lightning/core/__init__.py +++ b/pytorch_lightning/core/__init__.py @@ -267,11 +267,12 @@ allow for this: >>> class LitModel(pl.LightningModule): ... def prepare_data(self): ... # download - ... mnist_train = MNIST(os.getcwd(), train=True, download=True, - ... transform=transforms.ToTensor()) - ... mnist_test = MNIST(os.getcwd(), train=False, download=True, - ... transform=transforms.ToTensor()) + ... MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()) + ... MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()) ... + ... def setup(self, stage): + ... mnist_train = MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor()) + ... mnist_test = MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor()) ... # train/val split ... mnist_train, mnist_val = random_split(mnist_train, [55000, 5000]) ... diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index e4c5e67d85..a4f52711f9 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -22,7 +22,7 @@ class ModelHooks(Module): Called at the beginning of fit and test. Args: - step: either 'fit' or 'test' + stage: either 'fit' or 'test' """ def teardown(self, stage: str): diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index e5c470e3cf..df163aaa10 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1290,7 +1290,8 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod def prepare_data(self) -> None: """ Use this to download and prepare data. - In distributed (GPU, TPU), this will only be called once. + In distributed (GPU, TPU), this will only be called once on the local_rank=0 of each node. + To call this on only the root=0 of the main node, use `Trainer(prepare_data_per_node=False)` This is called before requesting the dataloaders: .. code-block:: python diff --git a/tests/base/model_template.py b/tests/base/model_template.py index 1347b38ee9..44dba72270 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -104,7 +104,7 @@ class EvalModelTemplate( return nll def prepare_data(self): - _ = TrialMNIST(root=self.data_root, train=True, download=True) + TrialMNIST(root=self.data_root, train=True, download=True) @staticmethod def get_default_hparams(continue_training: bool = False, hpc_exp_number: int = 0) -> dict: diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py index 16ea2038c1..d2bafe6d7c 100644 --- a/tests/callbacks/test_callbacks.py +++ b/tests/callbacks/test_callbacks.py @@ -69,7 +69,7 @@ def test_trainer_callback_system(tmpdir): self.on_test_start_called = False self.on_test_end_called = False - def setup(self, trainer, step: str): + def setup(self, trainer, stage: str): assert isinstance(trainer, Trainer) self.setup_called = True