[Bug-Fix]: keep the `current_epoch` and `global_step` properties in sync between model and trainer (#3785)

* make `current_epoch` and `global_step` match the trainer's values, including after model restore.

* remove the now-redundant assignments in the training loop

* test

* minor modification

* Update pytorch_lightning/core/lightning.py

type check for better clarity

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* Update pytorch_lightning/core/lightning.py

type check for better clarity

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>

* comments for current_epoch and global_step properties

* Update tests/models/test_restore.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* update comments according to the changes made

* Update tests/models/test_restore.py

* add current_epoch, global_step to jit ignore list

* Add comments to CHANGELOG

* Update CHANGELOG.md

* Update tests/models/test_restore.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
commit 7d47ed178b (parent 6ac0958166)
Nrupatunga, 2020-10-05 20:40:40 +05:30, committed by GitHub
4 changed files with 57 additions and 10 deletions

CHANGELOG.md

@@ -103,6 +103,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed aggregation of metrics ([#3517](https://github.com/PyTorchLightning/pytorch-lightning/pull/3517))
+- Fixed `current_epoch` and `global_step` properties mismatch between `Trainer` and `LightningModule` ([#3785](https://github.com/PyTorchLightning/pytorch-lightning/pull/3785))

 ## [0.9.0] - YYYY-MM-DD

 ### Added

pytorch_lightning/core/lightning.py

@@ -68,6 +68,8 @@ class LightningModule(
         "example_input_array",
         "hparams",
         "on_gpu",
+        "current_epoch",
+        "global_step",
     ] + DeviceDtypeModuleMixin.__jit_unused_properties__

     def __init__(self, *args, **kwargs):
@@ -79,12 +81,6 @@
         self.exp_save_path = None
-
-        #: The current epoch
-        self.current_epoch = 0
-
-        #: Total training batches seen across all epochs
-        self.global_step = 0

         self.loaded_optimizer_states_dict = {}

         #: Pointer to the trainer object
@@ -121,6 +117,16 @@
     def example_input_array(self) -> Any:
         return self._example_input_array

+    @property
+    def current_epoch(self) -> int:
+        """The current epoch"""
+        return self.trainer.current_epoch if self.trainer else 0
+
+    @property
+    def global_step(self) -> int:
+        """Total training batches seen across all epochs"""
+        return self.trainer.global_step if self.trainer else 0
+
     @example_input_array.setter
     def example_input_array(self, example: Any) -> None:
         self._example_input_array = example

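An aside, not part of the diff: the two names also join `__jit_unused_properties__` above, likely because the properties read `self.trainer`, which TorchScript cannot script, so scripting must skip them. A minimal sketch of the resulting behavior, using the repo's `EvalModelTemplate` test helper and `fast_dev_run` purely for brevity:

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate  # repo-internal test helper

    model = EvalModelTemplate()

    # No trainer attached yet -> both properties fall back to 0.
    assert model.current_epoch == 0
    assert model.global_step == 0

    # Once fit() attaches the trainer, the model mirrors its counters,
    # so the two can no longer drift apart after a checkpoint restore.
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model)
    assert model.current_epoch == trainer.current_epoch
    assert model.global_step == trainer.global_step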
pytorch_lightning/trainer/training_loop.py

@@ -213,8 +213,7 @@ class TrainLoop:
             except Exception:
                 pass

-        # update training progress in trainer and model
-        model.current_epoch = epoch
+        # update training progress in trainer
         self.trainer.current_epoch = epoch

         # changing gradient according accumulation_scheduler
@@ -520,7 +519,6 @@
         should_check_val = False
         for batch_idx, (batch, is_last_batch) in train_dataloader:
             self.trainer.batch_idx = batch_idx
-            model.global_step = self.trainer.global_step

             # ------------------------------------
             # TRAINING_STEP + TRAINING_STEP_END

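Why dropping the per-epoch and per-batch sync above is safe: the model no longer stores its own copies, so every counter update the loop makes on the trainer is immediately visible through the model's properties. A rough illustrative sketch (not from the commit), manually wiring up the `trainer` attribute the way `Trainer.fit()` normally does:

    from pytorch_lightning import Trainer
    from tests.base import EvalModelTemplate

    model = EvalModelTemplate()
    trainer = Trainer(max_epochs=2)
    model.trainer = trainer            # normally done inside Trainer.fit()

    trainer.current_epoch = 1          # the loop mutates only the trainer ...
    trainer.global_step += 1
    assert model.current_epoch == 1    # ... and the model reflects it instantly
    assert model.global_step == trainer.global_step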
tests/models/test_restore.py

@@ -10,11 +10,52 @@ import torch

 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.callbacks import ModelCheckpoint
 from tests.base import EvalModelTemplate, GenericEvalModelTemplate


+class ModelTrainerPropertyParity(Callback):
+
+    def _check_properties(self, trainer, pl_module):
+        assert trainer.global_step == pl_module.global_step
+        assert trainer.current_epoch == pl_module.current_epoch
+
+    def on_train_start(self, trainer, pl_module):
+        self._check_properties(trainer, pl_module)
+
+    def on_train_batch_start(self, trainer, pl_module, *args, **kwargs):
+        self._check_properties(trainer, pl_module)
+
+    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
+        self._check_properties(trainer, pl_module)
+
+    def on_epoch_end(self, trainer, pl_module):
+        self._check_properties(trainer, pl_module)
+
+    def on_train_end(self, trainer, pl_module):
+        self._check_properties(trainer, pl_module)
+
+
+def test_resume_from_checkpoint(tmpdir):
+    """ Test that properties like `current_epoch` and `global_step`
+    in model and trainer are always the same. """
+    model = EvalModelTemplate()
+    checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor="early_stop_on", save_last=True)
+    trainer_args = dict(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        logger=False,
+        early_stop_callback=False,
+        checkpoint_callback=checkpoint_callback,
+        callbacks=[ModelTrainerPropertyParity()],  # this performs the assertions
+    )
+    trainer = Trainer(**trainer_args)
+    trainer.fit(model)
+
+    trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt"))
+    trainer.fit(model)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_running_test_pretrained_model_distrib_dp(tmpdir):
     """Verify `test()` on pretrained model."""