Add method to TPUSpawn plugin to override how models are setup (#10039)
This commit is contained in:
parent
e94dcf6936
commit
64fc0d4257
|
@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from torch.nn import Module
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
@ -118,6 +119,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
def setup(self) -> None:
|
||||
self.create_mp_queue()
|
||||
|
||||
def _setup_model(self, model: Module) -> Module:
|
||||
return model
|
||||
|
||||
def create_mp_queue(self):
|
||||
self.start_method = "fork"
|
||||
smp = mp.get_context(self.start_method)
|
||||
|
|
Loading…
Reference in New Issue