From 734b28ed2dcd0feb23b44744a3d3d40de0b20a08 Mon Sep 17 00:00:00 2001
From: Shunsuke Hidaka
Date: Wed, 5 Feb 2020 19:15:51 +0900
Subject: [PATCH] Set warnings : Unify epoch numbers to be zero-based : #675 (#786)

* [update] : #675 : set warnings

* [fix] : #675 : remove white space
---
 pytorch_lightning/callbacks/pt_callbacks.py | 9 ++++++++-
 pytorch_lightning/trainer/training_loop.py  | 2 ++
 2 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py
index 12eae50909..ccca95d076 100644
--- a/pytorch_lightning/callbacks/pt_callbacks.py
+++ b/pytorch_lightning/callbacks/pt_callbacks.py
@@ -174,6 +174,8 @@ class EarlyStopping(Callback):
 
     def on_train_end(self, logs=None):
         if self.stopped_epoch > 0 and self.verbose > 0:
+            warnings.warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
+                          ' but will start from "0" in v0.8.0.', DeprecationWarning)
             log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping')
 
 
@@ -374,6 +376,7 @@ class GradientAccumulationScheduler(Callback):
 
     Args:
         scheduling (dict): scheduling in format {epoch: accumulation_factor}
+        warning:: Epochs indexing starts from "1" until v0.6.x, but will start from "0" in v0.8.0.
 
     Example::
 
@@ -394,6 +397,8 @@ class GradientAccumulationScheduler(Callback):
             raise TypeError("All epoches and accumulation factor must be integers")
 
         minimal_epoch = min(scheduling.keys())
+        warnings.warn('Epochs indexing of `scheduling` starts from "1" until v0.6.x,'
+                      ' but will start from "0" in v0.8.0.', DeprecationWarning)
         if minimal_epoch < 1:
             msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
             raise IndexError(msg)
@@ -404,7 +409,9 @@ class GradientAccumulationScheduler(Callback):
         self.epochs = sorted(scheduling.keys())
 
     def on_epoch_begin(self, epoch, trainer):
-        epoch += 1  # indexing epochs from 1
+        # indexing epochs from 1 (until v0.6.x)
+        # In v0.8.0, `epoch += 1` should be removed.
+        epoch += 1
         for i in reversed(range(len(self.epochs))):
             if epoch >= self.epochs[i]:
                 trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index e1f90308a7..a1a5676494 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -281,6 +281,8 @@ class TrainerTrainLoopMixin(ABC):
         pass
 
     def train(self):
+        warnings.warn('Displayed epoch numbers in the progress bar start from "1" until v0.6.x,'
+                      ' but will start from "0" in v0.8.0.', DeprecationWarning)
         model = self.get_model()
         # run all epochs
        for epoch in range(self.current_epoch, self.max_epochs):
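
Usage sketch (illustrative, not part of the patch): with this change, constructing
GradientAccumulationScheduler with the 1-based `scheduling` keys that v0.6.x still expects
also emits a DeprecationWarning announcing the switch to 0-based keys in v0.8.0. The import
path mirrors the file modified above; the warning-capture scaffolding and the example
schedule values are assumptions for demonstration only.

    import warnings

    from pytorch_lightning.callbacks.pt_callbacks import GradientAccumulationScheduler

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")  # DeprecationWarning is hidden by default
        # scheduling maps epoch -> accumulation factor; keys are 1-based until v0.6.x
        scheduler = GradientAccumulationScheduler(scheduling={1: 1, 5: 2, 10: 8})

    for w in caught:
        if issubclass(w.category, DeprecationWarning):
            print(w.message)  # explains that `scheduling` keys become 0-based in v0.8.0

Once the 0-based indexing lands in v0.8.0, the same schedule would presumably be written
as {0: 1, 4: 2, 9: 8}, and the `epoch += 1` shim removed in on_epoch_begin() above would
no longer be applied.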