From 1c7d477d038d95fd5bfe26befe6648e01bb46bb2 Mon Sep 17 00:00:00 2001
From: William Falcon <waf2107@columbia.edu>
Date: Mon, 13 May 2019 21:52:02 -0400
Subject: [PATCH] added option to change default tensor

---
 pytorch_lightning/models/trainer.py | 22 +++++++++-------------
 setup.py                            |  2 +-
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index f4a58902b2..d15b86f02e 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -7,9 +7,6 @@ from pytorch_lightning.root_module.model_saving import TrainerIO
 from torch.optim.lr_scheduler import MultiStepLR
 import pdb
 
-
-
-
 try:
     from apex import amp
     APEX_AVAILABLE = True
@@ -86,7 +83,6 @@ class Trainer(TrainerIO):
             use_amp = True
         self.use_amp = use_amp and APEX_AVAILABLE
         if self.use_amp:
-            self.amp_handle = amp.init(enabled=True)
             print('using 16bit precision')
 
     def __determine_data_use_amount(self, train_percent_check, val_percent_check, test_percent_check, overfit_pct):
@@ -223,14 +219,14 @@ class Trainer(TrainerIO):
         # filter out the weights that were done on gpu so we can load on good old cpus
         self.optimizers = model.configure_optimizers()
 
-        # if self.use_amp:
-        #     An example
-        #     self.model, optimizer = amp.initialize(
-        #         self.model, self.optimizers[0], opt_level="O2",
-        #         keep_batchnorm_fp32=True, loss_scale="dynamic"
-        #     )
-        #     self.optimizers[0] = optimizer
-        #     model.trainer = self
+        if self.use_amp:
+            # An example
+            self.model, optimizer = amp.initialize(
+                self.model, self.optimizers[0], opt_level="O2",
+                keep_batchnorm_fp32=True, loss_scale="dynamic"
+            )
+            self.optimizers[0] = optimizer
+            model.trainer = self
 
         # add lr schedulers
         if self.lr_scheduler_milestones is not None:
@@ -374,7 +370,7 @@ class Trainer(TrainerIO):
         # backward pass
         if self.use_amp:
             for optimizer in self.optimizers:
-                with self.amp_handle.scale_loss(loss, optimizer) as scaled_loss:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
                     scaled_loss.backward()
         else:
             loss.backward()
diff --git a/setup.py b/setup.py
index 699a89329b..59c20b2690 100755
--- a/setup.py
+++ b/setup.py
@@ -7,7 +7,7 @@ from setuptools import setup, find_packages
 # http://blog.ionelmc.ro/2014/05/25/python-packaging/
 setup(
     name="pytorch-lightning",
-    version='0.1.dev1722',
+    version='0.1.dev1723',
     description="The Keras for ML researchers using PyTorch",
     author="William Falcon",
     author_email="waf2107@columbia.edu",
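
For reference, this change moves off Apex's deprecated handle-based API
(amp.init() and self.amp_handle.scale_loss(...)) to the amp.initialize() /
amp.scale_loss() entry points. Below is a minimal standalone sketch of that
pattern, assuming NVIDIA Apex is installed and a CUDA device is available;
the model, optimizer, and batch are hypothetical stand-ins, not code from
this repository.

    import torch
    from apex import amp

    model = torch.nn.Linear(10, 2).cuda()      # hypothetical model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # One-time setup: amp.initialize() patches the model and optimizer for
    # mixed precision (O2 keeps batchnorm in fp32, loss scaling is dynamic),
    # replacing the older amp.init() handle-based setup.
    model, optimizer = amp.initialize(
        model, optimizer, opt_level="O2",
        keep_batchnorm_fp32=True, loss_scale="dynamic"
    )

    x = torch.randn(4, 10).cuda()              # hypothetical batch
    loss = model(x).sum()

    # Per-step backward: route the loss through amp.scale_loss() so dynamic
    # loss scaling can protect fp16 gradients from underflow/overflow.
    optimizer.zero_grad()
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()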