diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index d15b86f02e..f4a58902b2 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -7,6 +7,9 @@ from pytorch_lightning.root_module.model_saving import TrainerIO
 from torch.optim.lr_scheduler import MultiStepLR
 import pdb
 
+
+
+
 try:
     from apex import amp
     APEX_AVAILABLE = True
@@ -83,6 +86,7 @@ class Trainer(TrainerIO):
             use_amp = True
         self.use_amp = use_amp and APEX_AVAILABLE
         if self.use_amp:
+            self.amp_handle = amp.init(enabled=True)
             print('using 16bit precision')
     def __determine_data_use_amount(self, train_percent_check, val_percent_check,
                                     test_percent_check, overfit_pct):
@@ -219,14 +223,14 @@ class Trainer(TrainerIO):
         # filter out the weights that were done on gpu so we can load on good old cpus
         self.optimizers = model.configure_optimizers()
 
-        if self.use_amp:
-            # An example
-            self.model, optimizer = amp.initialize(
-                self.model, self.optimizers[0], opt_level="O2",
-                keep_batchnorm_fp32=True, loss_scale="dynamic"
-            )
-            self.optimizers[0] = optimizer
-            model.trainer = self
+        # if self.use_amp:
+            # An example
+            # self.model, optimizer = amp.initialize(
+            #     self.model, self.optimizers[0], opt_level="O2",
+            #     keep_batchnorm_fp32=True, loss_scale="dynamic"
+            # )
+            # self.optimizers[0] = optimizer
+            # model.trainer = self
 
         # add lr schedulers
         if self.lr_scheduler_milestones is not None:
@@ -370,7 +374,7 @@ class Trainer(TrainerIO):
         # backward pass
         if self.use_amp:
             for optimizer in self.optimizers:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                with self.amp_handle.scale_loss(loss, optimizer) as scaled_loss:
                     scaled_loss.backward()
         else:
             loss.backward()
diff --git a/setup.py b/setup.py
index 684b309963..699a89329b 100755
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@ from setuptools import setup, find_packages
 # http://blog.ionelmc.ro/2014/05/25/python-packaging/
 setup(
     name="pytorch-lightning",
-    version='0.1.dev1721',
+    version='0.1.dev1722',
     description="The Keras for ML researchers using PyTorch",
     author="William Falcon",
     author_email="waf2107@columbia.edu",