Add method to TPUSpawn plugin to override how models are setup (#10039)

This commit is contained in:
Kaushik B 2021-10-25 17:14:32 +05:30 committed by GitHub
parent e94dcf6936
commit 64fc0d4257
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 0 deletions

View File

@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch
import torch.multiprocessing as mp
from torch.nn import Module
from torch.utils.data import DataLoader
import pytorch_lightning as pl
@ -118,6 +119,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
def setup(self) -> None:
self.create_mp_queue()
def _setup_model(self, model: Module) -> Module:
return model
def create_mp_queue(self):
self.start_method = "fork"
smp = mp.get_context(self.start_method)