fixed amp bug
This commit is contained in:
parent
9e187574de
commit
b20a122e9c
|
@ -19,7 +19,7 @@ import tqdm
|
|||
from pytorch_lightning.root_module.memory import get_gpu_memory_map
|
||||
from pytorch_lightning.root_module.model_saving import TrainerIO
|
||||
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDistributedDataParallel, LightningDataParallel
|
||||
from pytorch_lightning.utils.debugging import ForkedPdb
|
||||
from pytorch_lightning.utils.debugging import IncompatibleArgumentsException
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
@ -466,7 +466,7 @@ class Trainer(TrainerIO):
|
|||
m = f'amp level {self.amp_level} with DataParallel is not supported. ' \
|
||||
f'See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227. ' \
|
||||
f'We recommend you switch to ddp if you want to use amp'
|
||||
raise Exception(m)
|
||||
raise IncompatibleArgumentsException(m)
|
||||
|
||||
# run through amp wrapper
|
||||
if self.use_amp:
|
||||
|
|
|
@ -12,4 +12,8 @@ class ForkedPdb(pdb.Pdb):
|
|||
sys.stdin = open('/dev/stdin')
|
||||
pdb.Pdb.interaction(self, *args, **kwargs)
|
||||
finally:
|
||||
sys.stdin = _stdin
|
||||
sys.stdin = _stdin
|
||||
|
||||
|
||||
class IncompatibleArgumentsException(Exception):
|
||||
pass
|
|
@ -4,6 +4,7 @@ from pytorch_lightning.examples.new_project_templates.lightning_module_template
|
|||
from argparse import Namespace
|
||||
from test_tube import Experiment
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.utils.debugging import IncompatibleArgumentsException
|
||||
import numpy as np
|
||||
import warnings
|
||||
import torch
|
||||
|
@ -90,17 +91,14 @@ def test_amp_gpu_dp():
|
|||
warnings.warn('test_amp_gpu_dp cannot run. Rerun on a node with 2+ GPUs to run this test')
|
||||
return
|
||||
|
||||
try:
|
||||
trainer_options = dict(
|
||||
max_nb_epochs=1,
|
||||
gpus=[0, 1],
|
||||
distributed_backend='dp',
|
||||
use_amp=True
|
||||
)
|
||||
except Exception as e:
|
||||
assert 'https://github.com/NVIDIA/apex/issues/227' in str(e)
|
||||
|
||||
run_gpu_model_test(trainer_options)
|
||||
trainer_options = dict(
|
||||
max_nb_epochs=1,
|
||||
gpus=[0, 1],
|
||||
distributed_backend='dp',
|
||||
use_amp=True
|
||||
)
|
||||
with pytest.raises(IncompatibleArgumentsException):
|
||||
run_gpu_model_test(trainer_options)
|
||||
|
||||
|
||||
def test_multi_gpu_model_ddp():
|
||||
|
|
Loading…
Reference in New Issue