From bc0278252e6c1b31496ff9de7b0576da1625d7e3 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Wed, 26 Jun 2019 18:04:29 -0400
Subject: [PATCH] removed self.model refs

---
 pytorch_lightning/models/trainer.py | 33 +++++++++++++++++++++--------
 1 file changed, 24 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 50634c2e37..1f82dd1ed3 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -293,11 +293,13 @@ class Trainer(TrainerIO):
             for lr_scheduler in self.lr_schedulers:
                 lr_scheduler.step()
 
-            self.model.current_epoch = epoch_nb
+            model = self.model.module if self.data_parallel else self.model
+            model.current_epoch = epoch_nb
 
             # hook
             if self.__is_function_implemented('on_epoch_start'):
-                self.model.on_epoch_start()
+                model = self.model.module if self.data_parallel else self.model
+                model.on_epoch_start()
 
             self.current_epoch = epoch_nb
             self.total_batches = self.nb_tng_batches + self.nb_val_batches
@@ -310,7 +312,9 @@ class Trainer(TrainerIO):
             for batch_nb, data_batch in enumerate(self.tng_dataloader):
                 self.batch_nb = batch_nb
                 self.global_step += 1
-                self.model.global_step = self.global_step
+
+                model = self.model.module if self.data_parallel else self.model
+                model.global_step = self.global_step
 
                 # stop when the flag is changed or we've gone past the amount requested in the batches
                 self.total_batch_nb += 1
@@ -340,7 +344,10 @@ class Trainer(TrainerIO):
                     # count items in memory
                     # nb_params, nb_tensors = count_mem_items()
 
-                    metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
+                    if self.data_parallel:
+                        metrics = self.model.module.update_tng_log_metrics(self.__tng_tqdm_dic)
+                    else:
+                        metrics = self.model.update_tng_log_metrics(self.__tng_tqdm_dic)
 
                     # add gpu memory
                     if self.on_gpu:
@@ -349,7 +356,9 @@ class Trainer(TrainerIO):
 
                     # add norms
                     if self.track_grad_norm > 0:
-                        grad_norm_dic = self.model.module.grad_norm(self.track_grad_norm)
+                        model = self.model.module if self.data_parallel else self.model
+                        grad_norm_dic = model.grad_norm(self.track_grad_norm)
+
                         metrics.update(grad_norm_dic)
 
                     # log metrics
@@ -358,7 +367,8 @@ class Trainer(TrainerIO):
 
                 # hook
                 if self.__is_function_implemented('on_batch_end'):
-                    self.model.on_batch_end()
+                    model = self.model.module if self.data_parallel else self.model
+                    model.on_batch_end()
 
                 # end epoch early
                 if early_stop_epoch:
@@ -366,7 +376,8 @@ class Trainer(TrainerIO):
 
             # hook
             if self.__is_function_implemented('on_epoch_end'):
-                self.model.on_epoch_end()
+                model = self.model.module if self.data_parallel else self.model
+                model.on_epoch_end()
 
             # early stopping
             if self.enable_early_stop:
@@ -385,7 +396,9 @@ class Trainer(TrainerIO):
 
         # hook
         if self.__is_function_implemented('on_batch_start'):
-            response = self.model.on_batch_start(data_batch)
+            model = self.model.module if self.data_parallel else self.model
+            response = model.on_batch_start(data_batch)
+
            if response == -1:
                return -1
@@ -415,7 +428,9 @@ class Trainer(TrainerIO):
         loss.backward()
 
         if self.check_grad_nans:
-            for param in self.model.parameters():
+
+            model = self.model.module if self.data_parallel else self.model
+            for param in model.parameters():
                 print(param.grad.float().sum())
 
         self.batch_loss_value += loss.item()
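
Note: every hunk above re-derives `model` with the same conditional unwrap,
since `nn.DataParallel` hides the user's module behind `.module` and attribute
writes or hook calls on the wrapper would miss it. The patch leaves that
expression inlined at each call site; a follow-up could centralize it. Below is
a minimal, self-contained sketch of the idea (the `unwrap_model` helper is
hypothetical and not part of this patch), assuming the same boolean
`data_parallel` flag the Trainer uses:

    import torch.nn as nn

    def unwrap_model(model, data_parallel):
        """Return the underlying module when `model` is wrapped in nn.DataParallel."""
        # nn.DataParallel exposes the wrapped module under `.module`; hooks and
        # attribute writes (current_epoch, global_step, ...) must target it,
        # not the wrapper, which is what each hunk in this patch does inline.
        return model.module if data_parallel else model

    # usage, mirroring the call sites touched above:
    net = nn.Linear(4, 2)
    wrapped = nn.DataParallel(net)
    assert unwrap_model(wrapped, data_parallel=True) is net
    assert unwrap_model(net, data_parallel=False) is net

A helper like this would also let the check key off `isinstance(model,
nn.DataParallel)` instead of a separate flag, keeping the two from drifting
apart.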