lightning/pytorch_lightning/accelerators/tpu_backend.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

import torch
import torch.multiprocessing as mp

from pytorch_lightning import _logger as log
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.accelerators.base_backend import Accelerator

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    import torch_xla.distributed.parallel_loader as xla_pl
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True


class TPUBackend(Accelerator):

    def __init__(self, trainer):
        super().__init__(trainer)
        self.start_method = None
        self.mp_queue = None

    def setup(self, model):
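        """Validate XLA availability and create the multiprocessing queue used to return results."""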
        rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

        if not XLA_AVAILABLE:
            raise MisconfigurationException('PyTorch XLA not installed.')

        # see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
        self.start_method = 'fork'

        # pass in a state q
        smp = mp.get_context(self.start_method)
        self.mp_queue = smp.SimpleQueue()

        self.trainer.model = model

    def teardown(self):
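        """Collect results from the spawned processes and restore the best/last weights in the main process."""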
        model = self.trainer.model

        # restore main state with best weights
        best_path = self.mp_queue.get()
        results = self.mp_queue.get()
        last_path = self.mp_queue.get()

        # transfer back the best path to the trainer
        self.trainer.checkpoint_callback.best_model_path = best_path
        # todo, also pass the best score

        # load last weights
        if last_path and not self.trainer.testing:
            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt)

        self.trainer.model = model

        # when training completes, load the weights back in main process
        self.__load_weights_on_main_process()
        return results

    def train(self):
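        """Train on a single TPU core if a core id was requested, otherwise spawn one process per core."""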
        model = self.trainer.model

        # train
        if self.trainer.tpu_id is not None:
            self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue)
        else:
            xmp.spawn(
                self.tpu_train_in_process,
                args=(model, self.trainer, self.mp_queue),
                nprocs=self.trainer.tpu_cores,
                start_method=self.start_method
            )

    def __load_weights_on_main_process(self):
        model = self.trainer.model

        # load weights if not interrupted
        if self.trainer.on_colab_kaggle and not self.trainer.testing:
            self.trainer.load_spawn_weights(model)

        self.trainer.model = model

    def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None):
        """
        Here we are inside each individual process
        """
        if not trainer:
            trainer = self.trainer

        trainer.call_setup_hook(model)

        # setup TPU training
        self.__setup_tpu_training(model, trainer)

        # set up training routine
        self.trainer.setup_training(model)

        # train or test
        results = self.trainer.train_or_test()

        # save weights at the end of training
        self.__save_end_of_training_weights(model, trainer)

        # persist info in spawn
        trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

    def training_step(self, args):
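        """Move the batch to the XLA device and run the model's ``training_step``."""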
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.training_step(*args)
        return output

    def validation_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.validation_step(*args)
        return output

    def test_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.test_step(*args)
        return output

    def process_dataloader(self, dataloader):
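        """Wrap the dataloader in an XLA ``ParallelLoader`` so batches are pre-loaded onto the TPU device."""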
        device = xm.xla_device(self.trainer.tpu_id)
        dataloader = xla_pl.ParallelLoader(dataloader, [device])
        dataloader = dataloader.per_device_loader(device)
        return dataloader

    def to_device(self, batch):
        """
        Transfers the data to the TPU.

        The target core is taken from ``self.trainer.tpu_id``; if it is ``None``,
        the default XLA device is used.

        Args:
            batch: A tensor or collection of tensors.

        Return:
            the tensor on the TPU device.

        See Also:
            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
        """
        if not XLA_AVAILABLE:
            raise MisconfigurationException(
                'Requested to transfer batch to TPU but XLA is not available.'
                ' Are you sure this machine has TPUs?'
            )
        device = xm.xla_device(self.trainer.tpu_id)
        return self.batch_to_device(batch, device)

    def __save_end_of_training_weights(self, model: LightningModule, trainer):
        # when training ends on these platforms dump weights to get out of the main process
        if trainer.on_colab_kaggle:
            rank_zero_warn('cleaning up... please do not interrupt')
            trainer.save_spawn_weights(model)

    def __setup_tpu_training(self, model: LightningModule, trainer):
        # use the default device from the process
        # tpu_device = xm.xla_device()

        # if given an ordinal device, use this as the device
        if trainer.tpu_id is not None:
            tpu_device = xm.xla_device(trainer.tpu_id)
        else:
            tpu_device = xm.xla_device()

        # track the device and move model to it
        trainer._device = tpu_device
        model.to(trainer._device)

        # get the appropriate tpu ranks
        trainer.tpu_local_core_rank = xm.get_local_ordinal()
        trainer.tpu_global_core_rank = xm.get_ordinal()

        # avoid duplicating progress bar
        if trainer.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
            trainer.progress_bar_callback.disable()

        trainer.global_rank = trainer.tpu_local_core_rank
        rank_zero_only.rank = trainer.global_rank

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
        trainer.optimizers = optimizers
        trainer.lr_schedulers = lr_schedulers
        trainer.optimizer_frequencies = optimizer_frequencies

        # init 16 bit for TPU
        if trainer.precision == 16:
            os.environ['XLA_USE_BF16'] = str(1)

        log.info(f'INIT TPU local core: {trainer.tpu_local_core_rank},'
                 f' global rank: {trainer.tpu_global_core_rank}'
                 f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')

    def backward(self, closure_loss, optimizer, opt_idx):
        model_ref = self.trainer.get_model()

        # do backward pass
        model_ref.backward(self, closure_loss, optimizer, opt_idx)

        # detach after backward
        closure_loss = closure_loss.detach()

        return closure_loss

    def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
        model_ref = self.trainer.get_model()
        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)

        # model hook
        model_ref.optimizer_step(
            self.trainer.current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            lambda_closure,
            on_tpu=True,
            using_lbfgs=is_lbfgs
        )

    def clip_gradients(self, optimizer):
        # apply clip gradients
        # TODO: separate TPU case from here
        self._clip_gradients(optimizer)
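

# Rough usage sketch (hypothetical, for illustration only -- in practice the Trainer
# creates and drives this backend internally during ``fit``):
#
#     backend = TPUBackend(trainer)
#     backend.setup(model)
#     backend.train()
#     results = backend.teardown()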