# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
import torch.multiprocessing as mp

from pytorch_lightning import _logger as log
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.accelerators.base_backend import Accelerator

try:
    import torch_xla
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp
    import torch_xla.distributed.parallel_loader as xla_pl
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True


class TPUBackend(Accelerator):

    def __init__(self, trainer):
        super().__init__(trainer)
        self.start_method = None
        self.mp_queue = None

    def setup(self, model):
        rank_zero_info(f'training on {self.trainer.tpu_cores} TPU cores')

        if not XLA_AVAILABLE:
            raise MisconfigurationException('PyTorch XLA not installed.')

        # see: https://discuss.pytorch.org/t/segfault-with-multiprocessing-queue/81292/2
        self.start_method = 'fork'

        # pass in a state queue used to return results from the spawned processes
        smp = mp.get_context(self.start_method)
        self.mp_queue = smp.SimpleQueue()

        self.trainer.model = model

    def teardown(self):
        model = self.trainer.model

        # restore main state with best weights
        best_path = self.mp_queue.get()
        results = self.mp_queue.get()
        last_path = self.mp_queue.get()

        # transfer back the best path to the trainer
        self.trainer.checkpoint_callback.best_model_path = best_path
        # TODO: also pass the best score

        # load last weights
        if last_path and not self.trainer.testing:
            ckpt = torch.load(last_path, map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt)

        self.trainer.model = model

        # when training completes, load the weights back in the main process
        self.__load_weights_on_main_process()
        return results

    def train(self):
        model = self.trainer.model

        # train
        if self.trainer.tpu_id is not None:
            self.tpu_train_in_process(self.trainer.tpu_id, model, self.trainer, self.mp_queue)
        else:
            xmp.spawn(
                self.tpu_train_in_process,
                args=(model, self.trainer, self.mp_queue),
                nprocs=self.trainer.tpu_cores,
                start_method=self.start_method
            )

    def __load_weights_on_main_process(self):
        model = self.trainer.model

        # load weights if not interrupted
        if self.trainer.on_colab_kaggle and not self.trainer.testing:
            self.trainer.load_spawn_weights(model)

        self.trainer.model = model

    def tpu_train_in_process(self, tpu_core_idx: int, model: LightningModule, trainer=None, mp_queue=None):
        """
        Here we are inside each individual TPU process.
        """
        if not trainer:
            trainer = self.trainer

        trainer.call_setup_hook(model)

        # setup TPU training
        self.__setup_tpu_training(model, trainer)

        # set up the training routine
        self.trainer.setup_training(model)

        # train or test
        results = self.trainer.train_or_test()

        # save weights at the end of training
        self.__save_end_of_training_weights(model, trainer)

        # persist info in spawn
        trainer.transfer_distrib_spawn_state_on_fit_end(model, mp_queue, results)

    def training_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.training_step(*args)
        return output

    def validation_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.validation_step(*args)
        return output

    def test_step(self, args):
        batch = args[0]
        batch = self.to_device(batch)
        args[0] = batch
        output = self.trainer.model.test_step(*args)
        return output

    def process_dataloader(self, dataloader):
        device = xm.xla_device(self.trainer.tpu_id)
        dataloader = xla_pl.ParallelLoader(dataloader, [device])
        dataloader = dataloader.per_device_loader(device)
        return dataloader

    def to_device(self, batch):
        """
        Transfers the data to the TPU.

        The target core is taken from ``self.trainer.tpu_id``; if it is ``None``,
        the first available core is chosen.

        Args:
            batch: A tensor or collection of tensors.

        Return:
            the tensor on the TPU device.

        See Also:
            - :func:`~pytorch_lightning.utilities.apply_func.move_data_to_device`
        """
        if not XLA_AVAILABLE:
            raise MisconfigurationException(
                'Requested to transfer batch to TPU but XLA is not available.'
                ' Are you sure this machine has TPUs?'
            )
        device = xm.xla_device(self.trainer.tpu_id)
        return self.batch_to_device(batch, device)

    def __save_end_of_training_weights(self, model: LightningModule, trainer):
        # when training ends on these platforms, dump the weights to get them out of the main process
        if trainer.on_colab_kaggle:
            rank_zero_warn('cleaning up... please do not interrupt')
            trainer.save_spawn_weights(model)

    def __setup_tpu_training(self, model: LightningModule, trainer):
        # use the default device from the process
        # tpu_device = xm.xla_device()

        # if given an ordinal device, use this as the device
        if trainer.tpu_id is not None:
            tpu_device = xm.xla_device(trainer.tpu_id)
        else:
            tpu_device = xm.xla_device()

        # track the device and move model to it
        trainer._device = tpu_device
        model.to(trainer._device)

        # get the appropriate tpu ranks
        trainer.tpu_local_core_rank = xm.get_local_ordinal()
        trainer.tpu_global_core_rank = xm.get_ordinal()

        # avoid duplicating progress bar
        if trainer.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
            trainer.progress_bar_callback.disable()

        trainer.global_rank = trainer.tpu_local_core_rank
        rank_zero_only.rank = trainer.global_rank

        # CHOOSE OPTIMIZER
        # allow for lr schedulers as well
        optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model)
        trainer.optimizers = optimizers
        trainer.lr_schedulers = lr_schedulers
        trainer.optimizer_frequencies = optimizer_frequencies

        # init 16 bit for TPU
        if trainer.precision == 16:
            os.environ['XLA_USE_BF16'] = str(1)

        log.info(f'INIT TPU local core: {trainer.tpu_local_core_rank},'
                 f' global rank: {trainer.tpu_global_core_rank}'
                 f' with XLA_USE_BF16={os.environ.get("XLA_USE_BF16")}')

    def backward(self, closure_loss, optimizer, opt_idx):
        model_ref = self.trainer.get_model()

        # do backward pass
        model_ref.backward(self, closure_loss, optimizer, opt_idx)

        # detach after backward
        closure_loss = closure_loss.detach()

        return closure_loss

    def optimizer_step(self, optimizer, batch_idx, opt_idx, lambda_closure):
        model_ref = self.trainer.get_model()
        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)

        # model hook
        model_ref.optimizer_step(
            self.trainer.current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            lambda_closure,
            on_tpu=True,
            using_lbfgs=is_lbfgs
        )

    def clip_gradients(self, optimizer):
        # apply gradient clipping
        # TODO: separate the TPU case from here
        self._clip_gradients(optimizer)
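

# ---------------------------------------------------------------------------
# Illustrative usage (a minimal sketch, not part of this module): the backend
# above is selected internally by the Trainer when ``tpu_cores`` is passed;
# users never instantiate ``TPUBackend`` directly. ``MyLightningModule`` below
# is a hypothetical LightningModule subclass.
#
#     import pytorch_lightning as pl
#
#     model = MyLightningModule()
#     trainer = pl.Trainer(tpu_cores=8, precision=16)  # spawns one process per core
#     trainer.fit(model)
# ---------------------------------------------------------------------------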