From c973245ba160519612b5e7e397793efde16902c9 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Tue, 14 May 2019 06:11:16 -0400 Subject: [PATCH] fixed error with shorter batch cycles --- pytorch_lightning/models/trainer.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index bbfa5b7316..dd3aed2b6e 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -122,21 +122,21 @@ class Trainer(TrainerIO): self.tqdm_metrics = {} # determine number of training batches - nb_tng_batches = self.model.nb_batches(self.tng_dataloader) - self.nb_tng_batches = int(nb_tng_batches * self.train_percent_check) + self.nb_tng_batches = self.model.nb_batches(self.tng_dataloader) + self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check) # determine number of validation batches - nb_val_batches = self.model.nb_batches(self.val_dataloader) - nb_val_batches = int(nb_val_batches * self.val_percent_check) - nb_val_batches = max(1, nb_val_batches) - self.nb_val_batches = nb_val_batches + self.nb_val_batches = self.model.nb_batches(self.val_dataloader) + self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check) + self.nb_val_batches = max(1, self.nb_val_batches) + self.nb_val_batches = self.nb_val_batches # determine number of test batches - nb_test_batches = self.model.nb_batches(self.test_dataloader) - self.nb_test_batches = int(nb_test_batches * self.test_percent_check) + self.nb_test_batches = self.model.nb_batches(self.test_dataloader) + self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check) # determine when to check validation - self.val_check_batch = int(nb_tng_batches * self.val_check_interval) + self.val_check_batch = int(self.nb_tng_batches * self.val_check_interval) def __add_tqdm_metrics(self, metrics): for k, v in metrics.items():