refactored dataloader process hook (#3139)

William Falcon 2020-08-24 21:53:56 -04:00 committed by GitHub
parent 229b87655a
commit f064d74be8
4 changed files with 12 additions and 14 deletions

@@ -22,3 +22,6 @@ class Accelerator(object):
 
     def validation_step_end(self, output):
         return output
+
+    def process_dataloader(self, dataloader):
+        return dataloader
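With this hook in place, a backend only has to override process_dataloader when it needs to wrap the loader; everything else falls through to the identity default above. A minimal sketch of the override pattern, assuming nothing beyond the base class shown in this hunk (WrappingBackend is an illustrative name, not a library class):

class Accelerator:
    # identity default, as introduced in the hunk above
    def process_dataloader(self, dataloader):
        return dataloader


class WrappingBackend(Accelerator):
    """Illustrative backend that rewraps the loader before the trainer iterates it."""

    def process_dataloader(self, dataloader):
        # a real backend returns a device-bound loader here, e.g. the TPU
        # backend's ParallelLoader wrapping in the next file; this stub just
        # shows that the trainer iterates whatever the hook hands back
        return ({"wrapped": batch} for batch in dataloader)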

@@ -27,6 +27,7 @@ try:
     import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
+    import torch_xla.distributed.parallel_loader as xla_pl
 except ImportError:
     XLA_AVAILABLE = False
 else:
@@ -139,6 +140,12 @@ class TPUBackend(Accelerator):
         output = self.trainer.model.test_step(*args)
         return output
 
+    def process_dataloader(self, dataloader):
+        device = xm.xla_device(self.trainer.tpu_id)
+        dataloader = xla_pl.ParallelLoader(dataloader, [device])
+        dataloader = dataloader.per_device_loader(device)
+        return dataloader
+
     def to_device(self, batch):
         """
         Transfers the data to the TPU.
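For reference, the hook body above is the same call sequence the trainer loops used inline before this commit: build a ParallelLoader over the target XLA device, then take its per-device iterator. A standalone sketch of those calls, guarded so it degrades to the plain loader when torch_xla is not installed:

import torch
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=4)

try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as xla_pl
except ImportError:
    device_loader = loader  # no TPU runtime available; keep the plain loader
else:
    device = xm.xla_device()
    # ParallelLoader preloads batches onto the XLA device in the background;
    # per_device_loader(device) is the iterator the training loop consumes
    device_loader = xla_pl.ParallelLoader(loader, [device]).per_device_loader(device)

for batch in device_loader:
    pass  # batches arrive already placed on the chosen device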

@@ -260,10 +260,7 @@ class TrainerEvaluationLoopMixin(ABC):
             dl_outputs = []
 
             # on TPU we have to wrap it under the ParallelLoader
-            if self.use_tpu:
-                device = xm.xla_device(self.tpu_id)
-                dataloader = xla_pl.ParallelLoader(dataloader, [device])
-                dataloader = dataloader.per_device_loader(device)
+            dataloader = self.accelerator_backend.process_dataloader(dataloader)
 
             # each dataloader has a max num batches
             dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
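Backends that do not override the hook inherit the identity implementation from the first hunk, so for non-TPU runs this call hands back the very same loader object and the removed use_tpu branch is not needed. A tiny self-contained check of that assumption (the class names here are illustrative stand-ins, not the library's):

class Accelerator:
    # identity default, mirroring the first hunk of this commit
    def process_dataloader(self, dataloader):
        return dataloader


class GPUBackendStub(Accelerator):
    # stand-in for a backend that keeps the inherited no-op behaviour
    pass


loader = [1, 2, 3]  # any iterable stands in for a real DataLoader here
assert GPUBackendStub().process_dataloader(loader) is loader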

@@ -427,15 +427,6 @@ class TrainerTrainLoopMixin(ABC):
 
             self.run_training_teardown()
 
-    def prepare_train_loop_dataloader(self, train_dataloader):
-        # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu:
-            device = xm.xla_device(self.tpu_id)
-            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
-            train_dataloader = train_dataloader.per_device_loader(device)
-
-        return train_dataloader
-
     def run_on_epoch_start_hook(self, model):
         # Epoch start events
         with self.profiler.profile('on_epoch_start'):
@@ -464,7 +455,7 @@ class TrainerTrainLoopMixin(ABC):
         self.run_on_epoch_start_hook(model)
 
         # modify dataloader if needed (ddp, etc...)
-        train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)
+        train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)
 
         # bookkeeping
         num_optimizers = len(self._get_optimizers_iterable())
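The net effect in both loops is the same: the trainer no longer branches on self.use_tpu or imports torch_xla itself; it always routes the loader through whichever backend is active. A simplified, self-contained sketch of that delegation (the Trainer and backend classes below are stand-ins for illustration, not the library code):

class CPUBackend:
    def process_dataloader(self, dataloader):
        return dataloader  # identity, like the base Accelerator


class UppercasingBackend:
    """Stand-in for a backend that wraps the loader, TPU-style."""

    def process_dataloader(self, dataloader):
        return (str(batch).upper() for batch in dataloader)


class Trainer:
    def __init__(self, accelerator_backend):
        self.accelerator_backend = accelerator_backend

    def run_training_epoch(self, train_dataloader):
        # single call site: no `if self.use_tpu:` branch, no xla imports here
        train_dataloader = self.accelerator_backend.process_dataloader(train_dataloader)
        for batch in train_dataloader:
            print(batch)


Trainer(CPUBackend()).run_training_epoch(["batch a", "batch b"])
Trainer(UppercasingBackend()).run_training_epoch(["batch a", "batch b"])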