From 64fc0d42579c69ad1849b115622460483727b8af Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 25 Oct 2021 17:14:32 +0530 Subject: [PATCH] Add method to TPUSpawn plugin to override how models are setup (#10039) --- pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 5643882a34..94edfe5354 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.multiprocessing as mp +from torch.nn import Module from torch.utils.data import DataLoader import pytorch_lightning as pl @@ -118,6 +119,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin): def setup(self) -> None: self.create_mp_queue() + def _setup_model(self, model: Module) -> Module: + return model + def create_mp_queue(self): self.start_method = "fork" smp = mp.get_context(self.start_method)