From 4693276494e791be59a92b4230f654d58b639669 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Mon, 13 May 2019 20:40:07 -0400
Subject: [PATCH] added option to change default tensor

---
 pytorch_lightning/models/trainer.py | 26 +++++++++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 16f2288229..ee47791ef6 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -6,6 +6,11 @@ import traceback
 from pytorch_lightning.root_module.model_saving import TrainerIO
 from torch.optim.lr_scheduler import MultiStepLR
 
+try:
+    from apex import amp
+    APEX_AVAILABLE = True
+except ModuleNotFoundError:
+    APEX_AVAILABLE = False
 
 
 class Trainer(TrainerIO):
@@ -73,6 +78,10 @@ class Trainer(TrainerIO):
         self.__determine_data_use_amount(train_percent_check, val_percent_check,
                                          test_percent_check, overfit_pct)
         print('gpu available: {}, used: {}'.format(torch.cuda.is_available(), self.on_gpu))
 
+        # apex test
+        use_amp = True
+        self.use_amp = use_amp and APEX_AVAILABLE
+
     def __determine_data_use_amount(self, train_percent_check, val_percent_check,
                                     test_percent_check, overfit_pct):
         """
         Use less data for debugging purposes
@@ -207,6 +216,15 @@ class Trainer(TrainerIO):
         # filter out the weights that were done on gpu so we can load on good old cpus
         self.optimizers = model.configure_optimizers()
 
+        if self.use_amp:
+            # An example
+            self.model, optimizer = amp.initialize(
+                self.model, self.optimizers[0], opt_level="O2",
+                keep_batchnorm_fp32=True, loss_scale="dynamic"
+            )
+            self.optimizers[0] = optimizer
+            model.trainer = self
+
         # add lr schedulers
         if self.lr_scheduler_milestones is not None:
             for optimizer in self.optimizers:
@@ -347,7 +365,13 @@ class Trainer(TrainerIO):
             self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
 
         # backward pass
-        loss.backward()
+        if self.use_amp:
+            for optimizer in self.optimizers:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+        else:
+            loss.backward()
+
         self.batch_loss_value += loss.item()
 
         # gradient update with accumulated gradients
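
For reference, the three AMP pieces this patch wires into the trainer (the guarded apex import, amp.initialize over the model and first optimizer, and amp.scale_loss around the backward pass) compose as below. This is a minimal sketch, assuming apex is installed and a CUDA device is available; the toy Linear model, SGD optimizer, and random batch are illustrative stand-ins, not part of the patch.

    # Minimal sketch of the Apex AMP pattern used in the patch above.
    # Assumes `apex` is installed and a CUDA device is available; the
    # Linear model and random batch are illustrative stand-ins.
    import torch
    import torch.nn as nn

    try:
        from apex import amp
        APEX_AVAILABLE = True
    except ModuleNotFoundError:
        APEX_AVAILABLE = False

    model = nn.Linear(32, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    if APEX_AVAILABLE:
        # Same call as the patch: cast the model to mixed precision (O2),
        # keep batchnorm in fp32, and use dynamic loss scaling.
        model, optimizer = amp.initialize(
            model, optimizer, opt_level="O2",
            keep_batchnorm_fp32=True, loss_scale="dynamic"
        )

    x = torch.randn(8, 32).cuda()
    y = torch.randint(0, 2, (8,)).cuda()
    loss = nn.functional.cross_entropy(model(x), y)

    if APEX_AVAILABLE:
        # Scale the loss before backward so fp16 gradients do not underflow.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()

    optimizer.step()
    optimizer.zero_grad()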
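
One design note on the backward-pass hunk: amp.scale_loss pairs the loss with the optimizer whose gradients it will scale, and the scaled backward should run once per loss. Because the patch loops over self.optimizers inside the backward pass, a configuration with more than one optimizer would call backward repeatedly on the same graph; the single-optimizer case shown in the sketch is the path the patch exercises.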