# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Optional

import torch.multiprocessing as mp

import pytorch_lightning as pl
from pytorch_lightning.strategies.launchers.spawn import _FakeQueue, _SpawnLauncher, _SpawnOutput
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.apply_func import move_data_to_device
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_debug

if _TPU_AVAILABLE:
    import torch_xla.distributed.xla_multiprocessing as xmp
else:
    xm, xmp, MpDeviceLoader, rendezvous = [None] * 4


class _XLASpawnLauncher(_SpawnLauncher):
    r"""Spawns processes that run a given function in parallel on XLA supported hardware, and joins them all at the
    end.

    The main process in which this launcher is invoked creates N so-called worker processes (using the
    `torch_xla` :func:`xmp.spawn`) that run the given function.
    Worker processes have a rank that ranges from 0 to N - 1.

    Note:
        - This launcher requires all objects to be pickleable.
        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
    """

    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        """Spawns processes that run the given function in parallel.

        The function is allowed to have a return value. However, when all processes join, only the return value
        of worker process 0 gets returned from this `launch` method in the main process.

        Arguments:
            function: The entry point for all spawned processes.
            *args: Optional positional arguments to be passed to the given function.
            **kwargs: Optional keyword arguments to be passed to the given function. If a keyword argument named
                `trainer` is present and is an instance of :class:`~pytorch_lightning.trainer.trainer.Trainer`,
                a selected set of attributes from the trainer get restored in the main process after processes join.
                The `trainer` keyword argument will NOT be passed into the function.
""" trainer = kwargs.pop("trainer", None) context = mp.get_context(self._strategy.start_method or "fork") return_queue = context.SimpleQueue() xmp.spawn( self._wrapping_function, args=(trainer, function, args, kwargs, return_queue), **self._strategy.get_mp_spawn_kwargs() ) spawn_output = return_queue.get() if trainer is None: return spawn_output self._recover_results_in_main_process(spawn_output, trainer) return spawn_output.trainer_results def _wrapping_function( self, process_idx: int, trainer: Optional["pl.Trainer"], function: Callable, args: Any, kwargs: Any, return_queue: SimpleQueue, ) -> None: self._strategy._worker_setup(process_idx) results = function(*args, **kwargs) if trainer is not None: results = self._collect_rank_zero_results(trainer, results) if self._strategy.local_rank == 0: return_queue.put(move_data_to_device(results, "cpu")) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 self._strategy.barrier("end-process") # Ensure that the rank 0 process is the one exiting last # https://github.com/pytorch/xla/issues/2190#issuecomment-641665358 if self._strategy.local_rank == 0: time.sleep(2) def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_debug("Finalizing the TPU spawn environment.") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None # requires to compute the state_dict on all processes in case Metrics are present state_dict = trainer.lightning_module.state_dict() # save the last weights weights_path = None if trainer.state.fn == TrainerFn.FITTING: weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path) # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training if self._strategy.local_rank != 0: return None # adds the `callback_metrics` to the queue extra = _FakeQueue() if is_overridden("add_to_queue", trainer.lightning_module): # TODO: Remove the if in v1.7 trainer.lightning_module.add_to_queue(extra) self.add_to_queue(trainer, extra) return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)