Add method to TPUSpawn plugin to override how models are setup (#10039)

This commit is contained in:
Kaushik B 2021-10-25 17:14:32 +05:30 committed by GitHub
parent e94dcf6936
commit 64fc0d4257
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 0 deletions

View File

@ -20,6 +20,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
from torch.nn import Module
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import pytorch_lightning as pl import pytorch_lightning as pl
@ -118,6 +119,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
def setup(self) -> None: def setup(self) -> None:
self.create_mp_queue() self.create_mp_queue()
def _setup_model(self, model: Module) -> Module:
return model
def create_mp_queue(self): def create_mp_queue(self):
self.start_method = "fork" self.start_method = "fork"
smp = mp.get_context(self.start_method) smp = mp.get_context(self.start_method)