Add method to TPUSpawn plugin to override how models are setup (#10039)
This commit is contained in:
parent
e94dcf6936
commit
64fc0d4257
|
@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
|
from torch.nn import Module
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
|
@ -118,6 +119,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
||||||
def setup(self) -> None:
|
def setup(self) -> None:
|
||||||
self.create_mp_queue()
|
self.create_mp_queue()
|
||||||
|
|
||||||
|
def _setup_model(self, model: Module) -> Module:
|
||||||
|
return model
|
||||||
|
|
||||||
def create_mp_queue(self):
|
def create_mp_queue(self):
|
||||||
self.start_method = "fork"
|
self.start_method = "fork"
|
||||||
smp = mp.get_context(self.start_method)
|
smp = mp.get_context(self.start_method)
|
||||||
|
|
Loading…
Reference in New Issue