From f064d74be8678d11f47228b9d23b091f23f6ec38 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Mon, 24 Aug 2020 21:53:56 -0400
Subject: [PATCH] refactored dataloader process hook (#3139)

---
 pytorch_lightning/accelerators/base_backend.py |  3 +++
 pytorch_lightning/accelerators/tpu_backend.py  |  7 +++++++
 pytorch_lightning/trainer/evaluation_loop.py   |  5 +----
 pytorch_lightning/trainer/training_loop.py     | 11 +----------
 4 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/accelerators/base_backend.py b/pytorch_lightning/accelerators/base_backend.py
index 268f7ff5d2..f0f7918103 100644
--- a/pytorch_lightning/accelerators/base_backend.py
+++ b/pytorch_lightning/accelerators/base_backend.py
@@ -22,3 +22,6 @@ class Accelerator(object):
 
     def validation_step_end(self, output):
         return output
+
+    def process_dataloader(self, dataloader):
+        return dataloader
diff --git a/pytorch_lightning/accelerators/tpu_backend.py b/pytorch_lightning/accelerators/tpu_backend.py
index 1522e51afe..1e807cd6fa 100644
--- a/pytorch_lightning/accelerators/tpu_backend.py
+++ b/pytorch_lightning/accelerators/tpu_backend.py
@@ -27,6 +27,7 @@ try:
     import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
+    import torch_xla.distributed.parallel_loader as xla_pl
 except ImportError:
     XLA_AVAILABLE = False
 else:
@@ -139,6 +140,12 @@ class TPUBackend(Accelerator):
         output = self.trainer.model.test_step(*args)
         return output
 
+    def process_dataloader(self, dataloader):
+        device = xm.xla_device(self.trainer.tpu_id)
+        dataloader = xla_pl.ParallelLoader(dataloader, [device])
+        dataloader = dataloader.per_device_loader(device)
+        return dataloader
+
     def to_device(self, batch):
         """
         Transfers the data to the TPU.
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index e7dca236b7..2911bff94f 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -260,10 +260,7 @@ class TrainerEvaluationLoopMixin(ABC):
             dl_outputs = []
 
             # on TPU we have to wrap it under the ParallelLoader
-            if self.use_tpu:
-                device = xm.xla_device(self.tpu_id)
-                dataloader = xla_pl.ParallelLoader(dataloader, [device])
-                dataloader = dataloader.per_device_loader(device)
+            dataloader = self.accelerator_backend.process_dataloader(dataloader)
 
             # each dataloader has a max num batches
             dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 617db687cd..ae5e7a2f30 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -427,15 +427,6 @@ class TrainerTrainLoopMixin(ABC):
 
         self.run_training_teardown()
 
-    def prepare_train_loop_dataloader(self, train_dataloader):
-        # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu:
-            device = xm.xla_device(self.tpu_id)
-            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
-            train_dataloader = train_dataloader.per_device_loader(device)
-
-        return train_dataloader
-
     def run_on_epoch_start_hook(self, model):
         # Epoch start events
         with self.profiler.profile('on_epoch_start'):
@@ -464,7 +455,7 @@ class TrainerTrainLoopMixin(ABC):
         self.run_on_epoch_start_hook(model)
 
         # modify dataloader if needed (ddp, etc...)
-        train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)
+        train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)
 
         # bookkeeping
         num_optimizers = len(self._get_optimizers_iterable())
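The net effect of the patch: the training and evaluation loops no longer branch on `self.use_tpu` or touch XLA directly; they call a single `process_dataloader` hook, and each accelerator decides whether the loader needs wrapping. Below is a minimal, self-contained sketch of that hook pattern using only stock PyTorch; the `TaggingAccelerator` class and the toy dataset are hypothetical stand-ins, not code from this patch (the real TPU override returns `ParallelLoader(...).per_device_loader(device)` instead).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class Accelerator:
    # Mirrors the base backend in this patch: the default hook is a pass-through.
    def process_dataloader(self, dataloader):
        return dataloader


class TaggingAccelerator(Accelerator):
    # Hypothetical backend standing in for TPUBackend: it overrides the hook to
    # "wrap" the loader. The real override builds an XLA ParallelLoader; here we
    # only tag the loader so the example runs on CPU.
    def process_dataloader(self, dataloader):
        dataloader.wrapped_by_accelerator = True
        return dataloader


# The loops call the hook unconditionally, with no device-specific branching:
dataset = TensorDataset(torch.randn(8, 2))
loader = DataLoader(dataset, batch_size=4)

backend = TaggingAccelerator()
loader = backend.process_dataloader(loader)

for batch in loader:
    pass  # iterate as usual; the backend decided how the loader was prepared

print(getattr(loader, "wrapped_by_accelerator", False))  # True
```

Keeping the hook on the accelerator keeps `torch_xla` imports and device-specific wrapping out of `training_loop.py` and `evaluation_loop.py`, so supporting another backend only requires overriding `process_dataloader`.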