[Hot Fix] Ensure process_dataloader is called when tpu_cores > 1 to use Parallel DataLoader (#6015)

* hotfix for tpu

* update changelog

* Update CHANGELOG.md

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
Authored by chaton on 2021-02-16 22:02:25 +00:00, committed by GitHub
Parent: c9fde04947
Commit: a52be5bb07
5 changed files with 16 additions and 14 deletions
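
Context for the change: before this commit, the training and evaluation loops asked the accelerator to process the dataloader, and `Accelerator.process_dataloader` simply returned it unchanged, so the `TPUSpawn` override that swaps the `DataLoader` for a parallel, per-device loader was never reached when `tpu_cores > 1`. The hook now lives on `TrainingTypePlugin` and the loops call the plugin directly. A simplified before/after of the call site (see the diffs below for the exact lines):

# Before: the call stopped at the Accelerator, whose implementation was a no-op,
# so plugin-level wrapping (e.g. TPU parallel loading) never happened.
dataloader = trainer.accelerator_backend.process_dataloader(dataloader)

# After: the loops ask the training type plugin, whose subclasses can wrap the loader.
dataloader = trainer.training_type_plugin.process_dataloader(dataloader)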

CHANGELOG.md

@@ -276,6 +276,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed passing wrong strings for scheduler interval doesn't throw an error ([#5923](https://github.com/PyTorchLightning/pytorch-lightning/pull/5923))
+- Fixed missing `process_dataloader` call for `TPUSpawn` when in distributed mode ([#6015](https://github.com/PyTorchLightning/pytorch-lightning/pull/6015))
## [1.1.8] - 2021-02-08
### Fixed

pytorch_lightning/accelerators/accelerator.py

@@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Callable, Iterable, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
from torch.optim import Optimizer
-from torch.utils.data import DataLoader
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.precision import (
@@ -226,14 +225,6 @@ class Accelerator(object):
args[0] = batch
return self.training_type_plugin.predict(*args)
-def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
-    """Wraps the dataloader if necessary
-    Args:
-        dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
-    """
-    return dataloader
def backward(
self,
closure_loss: torch.Tensor,

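The no-op above is deleted outright rather than turned into a pass-through. For comparison only (this is not part of the commit, and it assumes the `Accelerator` keeps its `training_type_plugin` reference), a delegating version of the removed entry point would look roughly like this:

# Hypothetical alternative, not in this commit: keep the Accelerator entry point
# but forward to the plugin so overrides such as TPUSpawn's still take effect.
def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
    return self.training_type_plugin.process_dataloader(dataloader)
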
pytorch_lightning/plugins/training_type/training_type_plugin.py

@@ -11,13 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from abc import ABC, abstractmethod
-from typing import Any, Optional, TYPE_CHECKING, Union
+from typing import Any, Iterable, Optional, TYPE_CHECKING, Union
import torch
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import unwrap_lightning_module
@@ -144,3 +144,11 @@ class TrainingTypePlugin(Plugin, ABC):
def on_save(self, checkpoint: dict) -> dict:
return checkpoint
+def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
+    """Wraps the dataloader if necessary
+    Args:
+        dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader`
+    """
+    return dataloader
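
With the hook defined on the training type plugin, subclasses that need a wrapped loader can override it and actually be called from the loops. A rough sketch of a TPU-spawn-style override, assuming torch_xla's `ParallelLoader` API; the real `TPUSpawnPlugin` body may differ, and the abstract methods of `TrainingTypePlugin` are omitted here:

from typing import Iterable, Union

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl
from torch.utils.data import DataLoader

from pytorch_lightning.plugins import TrainingTypePlugin


class TPUSpawnLikePlugin(TrainingTypePlugin):  # sketch only: abstract methods omitted
    def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]:
        # Wrap the raw DataLoader so each TPU core pulls batches from its own
        # per-device queue instead of iterating the host DataLoader directly.
        device = xm.xla_device()
        loader = xla_pl.ParallelLoader(dataloader, [device])
        return loader.per_device_loader(device)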

pytorch_lightning/trainer/trainer.py

@@ -668,7 +668,7 @@ class Trainer(
for dataloader_idx, dataloader in enumerate(dataloaders):
# bookkeeping
dl_outputs = []
-dataloader = self.accelerator_backend.process_dataloader(dataloader)
+dataloader = self.training_type_plugin.process_dataloader(dataloader)
dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
for batch_idx, batch in enumerate(dataloader):

pytorch_lightning/trainer/training_loop.py

@@ -502,7 +502,7 @@ class TrainLoop:
def run_training_epoch(self):
# modify dataloader if needed (ddp, etc...)
-train_dataloader = self.trainer.accelerator_backend.process_dataloader(self.trainer.train_dataloader)
+train_dataloader = self.trainer.training_type_plugin.process_dataloader(self.trainer.train_dataloader)
# track epoch output
epoch_output = [[] for _ in range(self.num_optimizers)]
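
A quick CPU-only way to check the new wiring is to spy on the plugin hook during a short fit. This is only a sketch, not a test from this commit: `TinyModel` is illustrative, and it assumes the Trainer exposes `training_type_plugin` before `fit()` (as the call sites above suggest) and that the same plugin instance is used during the run.

import torch
from torch.utils.data import DataLoader, TensorDataset
from unittest import mock

import pytorch_lightning as pl


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


def test_training_type_plugin_process_dataloader_called():
    dataset = TensorDataset(torch.randn(8, 4), torch.randn(8, 1))
    loader = DataLoader(dataset, batch_size=4)
    trainer = pl.Trainer(max_epochs=1, limit_train_batches=2, progress_bar_refresh_rate=0)
    plugin = trainer.training_type_plugin
    # Spy on (rather than replace) the hook so training still runs normally.
    with mock.patch.object(plugin, "process_dataloader", wraps=plugin.process_dataloader) as spy:
        trainer.fit(TinyModel(), loader)
    spy.assert_called()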