refactored dataloader process hook (#3139)
This commit is contained in:
parent 229b87655a
commit f064d74be8
@@ -22,3 +22,6 @@ class Accelerator(object):
     def validation_step_end(self, output):
         return output
 
+    def process_dataloader(self, dataloader):
+        return dataloader
+

@@ -27,6 +27,7 @@ try:
     import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
+    import torch_xla.distributed.parallel_loader as xla_pl
 except ImportError:
     XLA_AVAILABLE = False
 else:

@@ -139,6 +140,12 @@ class TPUBackend(Accelerator):
         output = self.trainer.model.test_step(*args)
         return output
 
+    def process_dataloader(self, dataloader):
+        device = xm.xla_device(self.trainer.tpu_id)
+        dataloader = xla_pl.ParallelLoader(dataloader, [device])
+        dataloader = dataloader.per_device_loader(device)
+        return dataloader
+
     def to_device(self, batch):
         """
         Transfers the data to the TPU.

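Taken together, the hunks above are the hook this commit is named after: the base Accelerator returns the dataloader untouched, while TPUBackend wraps it in a ParallelLoader bound to the XLA device. A minimal standalone sketch of that pattern (class names and the bare xm.xla_device() call here are illustrative, not the exact Lightning code):

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl


class Accelerator:
    """Default backend: hand the dataloader back unchanged (CPU/GPU case)."""

    def process_dataloader(self, dataloader):
        return dataloader


class TPUAccelerator(Accelerator):
    """TPU backend: wrap the loader so batches are pre-loaded onto the XLA device."""

    def process_dataloader(self, dataloader):
        device = xm.xla_device()  # illustrative; the real hook uses trainer.tpu_id
        loader = xla_pl.ParallelLoader(dataloader, [device])
        return loader.per_device_loader(device)
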
@@ -260,10 +260,7 @@ class TrainerEvaluationLoopMixin(ABC):
             dl_outputs = []
 
             # on TPU we have to wrap it under the ParallelLoader
-            if self.use_tpu:
-                device = xm.xla_device(self.tpu_id)
-                dataloader = xla_pl.ParallelLoader(dataloader, [device])
-                dataloader = dataloader.per_device_loader(device)
+            dataloader = self.accelerator_backend.process_dataloader(dataloader)
 
             # each dataloader has a max num batches
             dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

@@ -427,15 +427,6 @@ class TrainerTrainLoopMixin(ABC):
 
         self.run_training_teardown()
 
-    def prepare_train_loop_dataloader(self, train_dataloader):
-        # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu:
-            device = xm.xla_device(self.tpu_id)
-            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
-            train_dataloader = train_dataloader.per_device_loader(device)
-
-        return train_dataloader
-
     def run_on_epoch_start_hook(self, model):
         # Epoch start events
         with self.profiler.profile('on_epoch_start'):

@@ -464,7 +455,7 @@ class TrainerTrainLoopMixin(ABC):
         self.run_on_epoch_start_hook(model)
 
         # modify dataloader if needed (ddp, etc...)
-        train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)
+        train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)
 
         # bookkeeping
         num_optimizers = len(self._get_optimizers_iterable())
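The net effect at the call sites (last three hunks) is that the training and evaluation loops no longer branch on self.use_tpu; they delegate to whichever backend is active. A hedged usage sketch, with the loop body abbreviated and the method name illustrative rather than copied from the trainer:

def run_training_epoch(self):
    # device-specific wrapping now lives behind the accelerator hook
    train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)

    for batch_idx, batch in enumerate(train_dataloader):
        ...  # unchanged training-step logic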