fixed amp bug

William Falcon 2019-07-24 14:23:52 -04:00
parent 9e187574de
commit b20a122e9c
3 changed files with 16 additions and 14 deletions

View File

@@ -19,7 +19,7 @@ import tqdm
 from pytorch_lightning.root_module.memory import get_gpu_memory_map
 from pytorch_lightning.root_module.model_saving import TrainerIO
 from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel, LightningDataParallel
-from pytorch_lightning.utils.debugging import ForkedPdb
+from pytorch_lightning.utils.debugging import IncompatibleArgumentsException

 try:
     from apex import amp
@@ -466,7 +466,7 @@ class Trainer(TrainerIO):
             m = f'amp level {self.amp_level} with DataParallel is not supported. ' \
                 f'See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227. ' \
                 f'We recommend you switch to ddp if you want to use amp'
-            raise Exception(m)
+            raise IncompatibleArgumentsException(m)

         # run through amp wrapper
         if self.use_amp:
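
For orientation, a minimal sketch of the guard this raise appears to live in. Only the message and the raise are confirmed by the hunk above; the enclosing condition and the use_dp attribute are assumptions:

from pytorch_lightning.utils.debugging import IncompatibleArgumentsException

def check_amp_configuration(trainer):
    # Hypothetical mirror of the surrounding guard; trainer.use_dp is assumed,
    # only the message text and the raise appear in the hunk above.
    if trainer.use_dp and trainer.use_amp:
        m = f'amp level {trainer.amp_level} with DataParallel is not supported. ' \
            f'See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227. ' \
            f'We recommend you switch to ddp if you want to use amp'
        raise IncompatibleArgumentsException(m)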

View File

@@ -12,4 +12,8 @@ class ForkedPdb(pdb.Pdb):
             sys.stdin = open('/dev/stdin')
             pdb.Pdb.interaction(self, *args, **kwargs)
         finally:
-            sys.stdin = _stdin
+            sys.stdin = _stdin
+
+
+class IncompatibleArgumentsException(Exception):
+    pass
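
The new class is intentionally empty: subclassing Exception is enough to give configuration errors their own type, so callers can catch them without string-matching messages. A hedged usage sketch (the configure helper is illustrative, not from this diff):

from pytorch_lightning.utils.debugging import IncompatibleArgumentsException

def configure(use_dp, use_amp):
    # Hypothetical stand-in for the Trainer's validation logic.
    if use_dp and use_amp:
        raise IncompatibleArgumentsException('amp with DataParallel is not supported')

try:
    configure(use_dp=True, use_amp=True)
except IncompatibleArgumentsException as err:
    # Misconfiguration is now distinguishable from genuine runtime failures.
    print(f'incompatible trainer arguments: {err}')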

View File

@@ -4,6 +4,7 @@ from pytorch_lightning.examples.new_project_templates.lightning_module_template
 from argparse import Namespace
 from test_tube import Experiment
 from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.utils.debugging import IncompatibleArgumentsException
 import numpy as np
 import warnings
 import torch
@@ -90,17 +91,14 @@ def test_amp_gpu_dp():
         warnings.warn('test_amp_gpu_dp cannot run. Rerun on a node with 2+ GPUs to run this test')
         return

-    try:
-        trainer_options = dict(
-            max_nb_epochs=1,
-            gpus=[0, 1],
-            distributed_backend='dp',
-            use_amp=True
-        )
-    except Exception as e:
-        assert 'https://github.com/NVIDIA/apex/issues/227' in str(e)
-
-    run_gpu_model_test(trainer_options)
+    trainer_options = dict(
+        max_nb_epochs=1,
+        gpus=[0, 1],
+        distributed_backend='dp',
+        use_amp=True
+    )
+    with pytest.raises(IncompatibleArgumentsException):
+        run_gpu_model_test(trainer_options)


 def test_multi_gpu_model_ddp():
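
Why this fixes the test: the removed version wrapped only the dict(...) construction in try/except, and building a dict never raises, so the except branch was dead code and the real exception from run_gpu_model_test escaped uncaught. pytest.raises inverts that: the test now passes only if the expected exception is actually raised. A self-contained illustration of the pattern (the helper is hypothetical):

import pytest

class IncompatibleArgumentsException(Exception):
    """Stand-in for the class added in this commit."""

def run_misconfigured():
    raise IncompatibleArgumentsException('amp + dp is not supported')

def test_misconfiguration_is_reported():
    # Fails if the call does NOT raise the expected exception type.
    with pytest.raises(IncompatibleArgumentsException):
        run_misconfigured()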