added test for no dist sampler
This commit is contained in:
parent
5c21683566
commit
8064a77aa7
|
@ -19,7 +19,7 @@ import tqdm
|
|||
from pytorch_lightning.root_module.memory import get_gpu_memory_map
|
||||
from pytorch_lightning.root_module.model_saving import TrainerIO
|
||||
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel, LightningDataParallel
|
||||
from pytorch_lightning.utils.debugging import IncompatibleArgumentsException
|
||||
from pytorch_lightning.utils.debugging import MisconfigurationException
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
@ -392,7 +392,7 @@ class Trainer(TrainerIO):
|
|||
dist_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
||||
dataloader = Dataloader(dataset, sampler=dist_sampler)
|
||||
'''
|
||||
raise Exception(msg)
|
||||
raise MisconfigurationException(msg)
|
||||
|
||||
# -----------------------------
|
||||
# MODEL TRAINING
|
||||
|
@ -467,7 +467,7 @@ class Trainer(TrainerIO):
|
|||
m = f'amp level {self.amp_level} with DataParallel is not supported. ' \
|
||||
f'See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227. ' \
|
||||
f'We recommend you switch to ddp if you want to use amp'
|
||||
raise IncompatibleArgumentsException(m)
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
# run through amp wrapper
|
||||
if self.use_amp:
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import pdb
|
||||
import sys
|
||||
|
||||
class IncompatibleArgumentsException(Exception):
|
||||
class MisconfigurationException(Exception):
|
||||
pass
|
|
@ -1,10 +1,11 @@
|
|||
import pytest
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.examples.new_project_templates.lightning_module_template import LightningTemplateModel
|
||||
from pytorch_lightning.testing_models.lm_test_module import LightningTestModel
|
||||
from argparse import Namespace
|
||||
from test_tube import Experiment
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
|
||||
from pytorch_lightning.utils.debugging import IncompatibleArgumentsException
|
||||
from pytorch_lightning.utils.debugging import MisconfigurationException
|
||||
import numpy as np
|
||||
import warnings
|
||||
import torch
|
||||
|
@ -33,7 +34,8 @@ def test_cpu_model():
|
|||
val_percent_check=0.4
|
||||
)
|
||||
|
||||
run_gpu_model_test(trainer_options, on_gpu=False)
|
||||
model, hparams = get_model()
|
||||
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
|
||||
|
||||
|
||||
def test_all_features_cpu_model():
|
||||
|
@ -54,7 +56,8 @@ def test_all_features_cpu_model():
|
|||
val_percent_check=0.4
|
||||
)
|
||||
|
||||
run_gpu_model_test(trainer_options, on_gpu=False)
|
||||
model, hparams = get_model()
|
||||
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
|
||||
|
||||
|
||||
def test_early_stopping_cpu_model():
|
||||
|
@ -77,7 +80,8 @@ def test_early_stopping_cpu_model():
|
|||
val_percent_check=0.4
|
||||
)
|
||||
|
||||
run_gpu_model_test(trainer_options, on_gpu=False)
|
||||
model, hparams = get_model()
|
||||
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
|
||||
|
||||
|
||||
def test_single_gpu_model():
|
||||
|
@ -88,6 +92,7 @@ def test_single_gpu_model():
|
|||
if not torch.cuda.is_available():
|
||||
warnings.warn('test_single_gpu_model cannot run. Rerun on a GPU node to run this test')
|
||||
return
|
||||
model, hparams = get_model()
|
||||
|
||||
trainer_options = dict(
|
||||
progress_bar=False,
|
||||
|
@ -97,7 +102,7 @@ def test_single_gpu_model():
|
|||
gpus=[0]
|
||||
)
|
||||
|
||||
run_gpu_model_test(trainer_options)
|
||||
run_gpu_model_test(trainer_options, model, hparams)
|
||||
|
||||
|
||||
def test_multi_gpu_model_dp():
|
||||
|
@ -111,7 +116,7 @@ def test_multi_gpu_model_dp():
|
|||
if not torch.cuda.device_count() > 1:
|
||||
warnings.warn('test_multi_gpu_model_dp cannot run. Rerun on a node with 2+ GPUs to run this test')
|
||||
return
|
||||
|
||||
model, hparams = get_model()
|
||||
trainer_options = dict(
|
||||
progress_bar=False,
|
||||
max_nb_epochs=1,
|
||||
|
@ -120,7 +125,7 @@ def test_multi_gpu_model_dp():
|
|||
gpus=[0, 1]
|
||||
)
|
||||
|
||||
run_gpu_model_test(trainer_options)
|
||||
run_gpu_model_test(trainer_options, model, hparams)
|
||||
|
||||
|
||||
def test_amp_gpu_dp():
|
||||
|
@ -134,15 +139,15 @@ def test_amp_gpu_dp():
|
|||
if not torch.cuda.device_count() > 1:
|
||||
warnings.warn('test_amp_gpu_dp cannot run. Rerun on a node with 2+ GPUs to run this test')
|
||||
return
|
||||
|
||||
model, hparams = get_model()
|
||||
trainer_options = dict(
|
||||
max_nb_epochs=1,
|
||||
gpus='0, 1', # test init with gpu string
|
||||
distributed_backend='dp',
|
||||
use_amp=True
|
||||
)
|
||||
with pytest.raises(IncompatibleArgumentsException):
|
||||
run_gpu_model_test(trainer_options)
|
||||
with pytest.raises(MisconfigurationException):
|
||||
run_gpu_model_test(trainer_options, model, hparams)
|
||||
|
||||
|
||||
def test_multi_gpu_model_ddp():
|
||||
|
@ -158,7 +163,7 @@ def test_multi_gpu_model_ddp():
|
|||
return
|
||||
|
||||
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
|
||||
|
||||
model, hparams = get_model()
|
||||
trainer_options = dict(
|
||||
progress_bar=False,
|
||||
max_nb_epochs=1,
|
||||
|
@ -168,7 +173,7 @@ def test_multi_gpu_model_ddp():
|
|||
distributed_backend='ddp'
|
||||
)
|
||||
|
||||
run_gpu_model_test(trainer_options)
|
||||
run_gpu_model_test(trainer_options, model, hparams)
|
||||
|
||||
|
||||
def test_amp_gpu_ddp():
|
||||
|
@ -185,6 +190,7 @@ def test_amp_gpu_ddp():
|
|||
|
||||
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
|
||||
|
||||
model, hparams = get_model()
|
||||
trainer_options = dict(
|
||||
progress_bar=True,
|
||||
max_nb_epochs=1,
|
||||
|
@ -193,14 +199,10 @@ def test_amp_gpu_ddp():
|
|||
use_amp=True
|
||||
)
|
||||
|
||||
run_gpu_model_test(trainer_options)
|
||||
run_gpu_model_test(trainer_options, model, hparams)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# UTILS
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
def run_gpu_model_test(trainer_options, on_gpu=True):
|
||||
def test_ddp_sampler_error():
|
||||
"""
|
||||
Make sure DDP + AMP work
|
||||
:return:
|
||||
|
@ -212,8 +214,34 @@ def run_gpu_model_test(trainer_options, on_gpu=True):
|
|||
warnings.warn('test_amp_gpu_ddp cannot run. Rerun on a node with 2+ GPUs to run this test')
|
||||
return
|
||||
|
||||
os.environ['MASTER_PORT'] = str(np.random.randint(12000, 19000, 1)[0])
|
||||
|
||||
hparams = get_hparams()
|
||||
model = LightningTestModel(hparams, force_remove_distributed_sampler=True)
|
||||
|
||||
trainer_options = dict(
|
||||
progress_bar=True,
|
||||
max_nb_epochs=1,
|
||||
gpus=[0, 1],
|
||||
distributed_backend='ddp',
|
||||
use_amp=True
|
||||
)
|
||||
|
||||
with pytest.raises(MisconfigurationException):
|
||||
run_gpu_model_test(trainer_options, model, hparams)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
# UTILS
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
def run_gpu_model_test(trainer_options, model, hparams, on_gpu=True):
|
||||
"""
|
||||
Make sure DDP + AMP work
|
||||
:return:
|
||||
"""
|
||||
|
||||
save_dir = init_save_dir()
|
||||
model, hparams = get_model()
|
||||
|
||||
# exp file to get meta
|
||||
exp = get_exp(False)
|
||||
|
@ -243,8 +271,7 @@ def run_gpu_model_test(trainer_options, on_gpu=True):
|
|||
clear_save_dir()
|
||||
|
||||
|
||||
def get_model():
|
||||
# set up model with these hyperparams
|
||||
def get_hparams():
|
||||
root_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
hparams = Namespace(**{'drop_prob': 0.2,
|
||||
'batch_size': 32,
|
||||
|
@ -254,6 +281,12 @@ def get_model():
|
|||
'data_root': os.path.join(root_dir, 'mnist'),
|
||||
'out_features': 10,
|
||||
'hidden_dim': 1000})
|
||||
return hparams
|
||||
|
||||
|
||||
def get_model():
|
||||
# set up model with these hyperparams
|
||||
hparams = get_hparams()
|
||||
model = LightningTemplateModel(hparams)
|
||||
|
||||
return model, hparams
|
||||
|
|
Loading…
Reference in New Issue