From a52be5bb0790113a69d9965f2e9ebd84f40acb06 Mon Sep 17 00:00:00 2001
From: chaton
Date: Tue, 16 Feb 2021 22:02:25 +0000
Subject: [PATCH] [Hot Fix] Ensure process_dataloader is called when tpu_cores > 1 to use Parallel DataLoader (#6015)

* hotfix for tpu

* update changelog

* Update CHANGELOG.md

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Sean Naren
---
 CHANGELOG.md                                       |  3 +++
 pytorch_lightning/accelerators/accelerator.py      | 11 +----------
 .../plugins/training_type/training_type_plugin.py  | 12 ++++++++++--
 pytorch_lightning/trainer/trainer.py               |  2 +-
 pytorch_lightning/trainer/training_loop.py         |  2 +-
 5 files changed, 16 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8b542474ef..13c11163bb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -276,6 +276,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed passing wrong strings for scheduler interval doesn't throw an error ([#5923](https://github.com/PyTorchLightning/pytorch-lightning/pull/5923))
 
 
+- Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))
+
+
 ## [1.1.8] - 2021-02-08
 
 ### Fixed
diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py
index 4f4b10e273..893456f403 100644
--- a/pytorch_lightning/accelerators/accelerator.py
+++ b/pytorch_lightning/accelerators/accelerator.py
@@ -11,11 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader
 
 from pytorch_lightning.core import LightningModule
 from pytorch_lightning.plugins.precision import (
@@ -226,14 +225,6 @@ class Accelerator(object):
         args[0] = batch
         return self.training_type_plugin.predict(*args)
 
-    def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
-        """Wraps the dataloader if necessary
-
-        Args:
-            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
-        """
-        return dataloader
-
     def backward(
         self,
         closure_loss: torch.Tensor,
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index db0e390c4b..74f5837afc 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -11,13 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from abc import ABC, abstractmethod
-from typing import Any, Optional, TYPE_CHECKING, Union
+from typing import Any, Iterable, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch.nn import Module
 from torch.optim import Optimizer
+from torch.utils.data import DataLoader
 
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.base import unwrap_lightning_module
@@ -144,3 +144,11 @@ class TrainingTypePlugin(Plugin, ABC):
 
     def on_save(self, checkpoint: dict) -> dict:
         return checkpoint
+
+    def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
+        """Wraps the dataloader if necessary
+
+        Args:
+            dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
+        """
+        return dataloader
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 72236b1589..db04734d2f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -668,7 +668,7 @@ class Trainer(
         for dataloader_idx, dataloader in enumerate(dataloaders):
             # bookkeeping
             dl_outputs = []
-            dataloader = self.accelerator_backend.process_dataloader(dataloader)
+            dataloader = self.training_type_plugin.process_dataloader(dataloader)
             dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
 
             for batch_idx, batch in enumerate(dataloader):
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index f7f44625a3..1640afe97f 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -502,7 +502,7 @@ class TrainLoop:
 
     def run_training_epoch(self):
        # modify dataloader if needed (ddp, etc...)
-        train_dataloader = self.trainer.accelerator_backend.process_dataloader(self.trainer.train_dataloader)
+        train_dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
 
         # track epoch output
         epoch_output = [[] for _ in range(self.num_optimizers)]
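
Note on the hook this patch relocates: with `process_dataloader` now living on `TrainingTypePlugin`, a plugin can override it to hand the loops a device-parallel loader, which is what the subject line refers to for `tpu_cores > 1`. The sketch below is illustrative only and is not the `TPUSpawn` implementation touched by this PR; it assumes a TPU environment with `torch_xla` installed, and the class name `SketchTPUPlugin` plus the wrapping details are assumptions for the example.

# Minimal sketch, assuming torch_xla is available on a TPU host.
# SketchTPUPlugin is NOT the actual TPUSpawn plugin; the remaining abstract
# methods of TrainingTypePlugin are intentionally omitted here.
from typing import Iterable, Union

from torch.utils.data import DataLoader

from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin


class SketchTPUPlugin(TrainingTypePlugin):

    def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        # Import lazily so that merely defining the class does not require torch_xla.
        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.parallel_loader as xla_pl

        device = xm.xla_device()
        # Wrap the plain DataLoader so batches are preloaded onto this process's
        # TPU core, then return the per-device iterator the loops will consume.
        loader = xla_pl.ParallelLoader(dataloader, [device])
        return loader.per_device_loader(device)

Because the trainer and training loop now call `self.training_type_plugin.process_dataloader(...)` (the last two hunks above) instead of the accelerator's former no-op, an override like this is reached in distributed TPU mode as well.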