diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 5643882a34..94edfe5354 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.multiprocessing as mp +from torch.nn import Module from torch.utils.data import DataLoader import pytorch_lightning as pl @@ -118,6 +119,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin): def setup(self) -> None: self.create_mp_queue() + def _setup_model(self, model: Module) -> Module: + return model + def create_mp_queue(self): self.start_method = "fork" smp = mp.get_context(self.start_method)