refactored dataloader process hook (#3139)
parent 229b87655a
commit f064d74be8

@@ -22,3 +22,6 @@ class Accelerator(object):
 
     def validation_step_end(self, output):
         return output
+
+    def process_dataloader(self, dataloader):
+        return dataloader

@@ -27,6 +27,7 @@ try:
     import torch_xla
     import torch_xla.core.xla_model as xm
     import torch_xla.distributed.xla_multiprocessing as xmp
+    import torch_xla.distributed.parallel_loader as xla_pl
 except ImportError:
     XLA_AVAILABLE = False
 else:

@@ -139,6 +140,12 @@ class TPUBackend(Accelerator):
         output = self.trainer.model.test_step(*args)
         return output
 
+    def process_dataloader(self, dataloader):
+        device = xm.xla_device(self.trainer.tpu_id)
+        dataloader = xla_pl.ParallelLoader(dataloader, [device])
+        dataloader = dataloader.per_device_loader(device)
+        return dataloader
+
     def to_device(self, batch):
         """
         Transfers the data to the TPU.

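For context on what the new TPUBackend.process_dataloader does: torch_xla's ParallelLoader preloads batches onto the XLA device, and per_device_loader(device) hands back the iterator bound to that single device. Below is a standalone sketch of the same wrapping outside Lightning, guarded because torch_xla is only installed on TPU hosts; the dataset and loader here are placeholders, not part of the PR.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    try:
        import torch_xla.core.xla_model as xm
        import torch_xla.distributed.parallel_loader as xla_pl
        XLA_AVAILABLE = True
    except ImportError:
        XLA_AVAILABLE = False

    # Placeholder dataset/loader; any torch DataLoader works here.
    loader = DataLoader(TensorDataset(torch.arange(8.0)), batch_size=2)

    if XLA_AVAILABLE:
        device = xm.xla_device()
        # Same wrapping the TPU backend performs in process_dataloader():
        # ParallelLoader feeds batches to the XLA device, per_device_loader
        # returns the iterator for that device.
        loader = xla_pl.ParallelLoader(loader, [device]).per_device_loader(device)

    for (batch,) in loader:
        print(batch)
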
@@ -260,10 +260,7 @@ class TrainerEvaluationLoopMixin(ABC):
             dl_outputs = []
 
             # on TPU we have to wrap it under the ParallelLoader
-            if self.use_tpu:
-                device = xm.xla_device(self.tpu_id)
-                dataloader = xla_pl.ParallelLoader(dataloader, [device])
-                dataloader = dataloader.per_device_loader(device)
+            dataloader = self.accelerator_backend.process_dataloader(dataloader)
 
             # each dataloader has a max num batches
             dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

@@ -427,15 +427,6 @@ class TrainerTrainLoopMixin(ABC):
 
         self.run_training_teardown()
 
-    def prepare_train_loop_dataloader(self, train_dataloader):
-        # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu:
-            device = xm.xla_device(self.tpu_id)
-            train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
-            train_dataloader = train_dataloader.per_device_loader(device)
-
-        return train_dataloader
-
     def run_on_epoch_start_hook(self, model):
         # Epoch start events
         with self.profiler.profile('on_epoch_start'):

@@ -464,7 +455,7 @@ class TrainerTrainLoopMixin(ABC):
         self.run_on_epoch_start_hook(model)
 
         # modify dataloader if needed (ddp, etc...)
-        train_dataloader = self.prepare_train_loop_dataloader(self.train_dataloader)
+        train_dataloader = self.accelerator_backend.process_dataloader(self.train_dataloader)
 
         # bookkeeping
         num_optimizers = len(self._get_optimizers_iterable())
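
Taken together, the hunks above replace the scattered TPU-specific wrapping with a single accelerator hook: the base class passes the dataloader through untouched, the TPU backend wraps it, and the train/eval loops call the hook unconditionally. A minimal, self-contained sketch of that pattern with simplified stand-ins (not the actual Lightning classes):

    class Accelerator:
        """Stand-in for the base backend: the hook is a pass-through."""
        def process_dataloader(self, dataloader):
            return dataloader


    class WrappingBackend(Accelerator):
        """Hypothetical backend that needs a wrapped, device-bound iterator
        (the role xla_pl.ParallelLoader plays for the real TPU backend)."""
        def process_dataloader(self, dataloader):
            return iter(list(dataloader))


    def run_loop(backend, dataloader):
        # Trainer side after this commit: call the hook unconditionally
        # instead of branching on self.use_tpu.
        dataloader = backend.process_dataloader(dataloader)
        return [batch for batch in dataloader]


    if __name__ == "__main__":
        data = [1, 2, 3]
        assert run_loop(Accelerator(), data) == [1, 2, 3]
        assert run_loop(WrappingBackend(), data) == [1, 2, 3]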