[Bug-Fix]: keep the `current_epoch` and `global_step` properties in sync between model and trainer (#3785)
* make current_epoch and global_step the same as the trainer's after model restore
* remove assignment here
* test
* minor modification
* Update pytorch_lightning/core/lightning.py: type check, better clarity
  Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
* Update pytorch_lightning/core/lightning.py: type check, better clarity
  Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
* comments for current_epoch and global_step properties
* Update tests/models/test_restore.py
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* update comments according to the changes made
* Update tests/models/test_restore.py
* add current_epoch, global_step to jit ignore list
* Add comments to CHANGELOG
* Update CHANGELOG.md
* Update tests/models/test_restore.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent 6ac0958166
commit 7d47ed178b
CHANGELOG.md
@@ -103,6 +103,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed aggregation of metrics ([#3517](https://github.com/PyTorchLightning/pytorch-lightning/pull/3517))
 
+- Fixed `current_epoch` and `global_step` properties mismatch between `Trainer` and `LightningModule` ([#3785](https://github.com/PyTorchLightning/pytorch-lightning/pull/3785))
+
 ## [0.9.0] - YYYY-MM-DD
 
 ### Added
pytorch_lightning/core/lightning.py
@@ -68,6 +68,8 @@ class LightningModule(
         "example_input_array",
         "hparams",
         "on_gpu",
+        "current_epoch",
+        "global_step",
     ] + DeviceDtypeModuleMixin.__jit_unused_properties__
 
     def __init__(self, *args, **kwargs):
@@ -79,12 +81,6 @@
 
         self.exp_save_path = None
 
-        #: The current epoch
-        self.current_epoch = 0
-
-        #: Total training batches seen across all epochs
-        self.global_step = 0
-
         self.loaded_optimizer_states_dict = {}
 
         #: Pointer to the trainer object
@@ -121,6 +117,16 @@
     def example_input_array(self) -> Any:
         return self._example_input_array
 
+    @property
+    def current_epoch(self) -> int:
+        """The current epoch"""
+        return self.trainer.current_epoch if self.trainer else 0
+
+    @property
+    def global_step(self) -> int:
+        """Total training batches seen across all epochs"""
+        return self.trainer.global_step if self.trainer else 0
+
     @example_input_array.setter
     def example_input_array(self, example: Any) -> None:
         self._example_input_array = example
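For illustration only, outside the diff: a minimal standalone sketch of the delegation pattern these properties now follow, where a module-like object reads its counters from whatever trainer-like object it is attached to and falls back to 0 when detached. `ToyTrainer` and `ToyModule` are made-up names for this sketch, not Lightning classes.

```python
from typing import Optional


class ToyTrainer:
    """Stand-in for the trainer; owns the real counters."""

    def __init__(self) -> None:
        self.current_epoch = 0
        self.global_step = 0


class ToyModule:
    """Stand-in for the module; only mirrors the trainer."""

    def __init__(self) -> None:
        # set when a trainer attaches the module (e.g. during fit)
        self.trainer: Optional[ToyTrainer] = None

    @property
    def current_epoch(self) -> int:
        # fall back to 0 while no trainer is attached
        return self.trainer.current_epoch if self.trainer else 0

    @property
    def global_step(self) -> int:
        return self.trainer.global_step if self.trainer else 0


module, trainer = ToyModule(), ToyTrainer()
assert module.current_epoch == 0           # detached: falls back to 0
module.trainer = trainer
trainer.current_epoch, trainer.global_step = 3, 120
assert module.current_epoch == 3           # always mirrors the trainer
assert module.global_step == 120
```

The point of the change is that the trainer stays the single source of truth for training progress, so the module cannot drift out of sync after a checkpoint restore.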
pytorch_lightning/trainer/training_loop.py
@@ -213,8 +213,7 @@ class TrainLoop:
         except Exception:
             pass
 
-        # update training progress in trainer and model
-        model.current_epoch = epoch
+        # update training progress in trainer
         self.trainer.current_epoch = epoch
 
         # changing gradient according accumulation_scheduler
@@ -520,7 +519,6 @@
         should_check_val = False
         for batch_idx, (batch, is_last_batch) in train_dataloader:
             self.trainer.batch_idx = batch_idx
-            model.global_step = self.trainer.global_step
 
             # ------------------------------------
             # TRAINING_STEP + TRAINING_STEP_END
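The loop stops writing `model.current_epoch` and `model.global_step` because, after the change above, those attributes are getter-only properties on the module. As a standalone reminder of the underlying Python behavior (a sketch, not Lightning code), assigning to a property that defines no setter raises `AttributeError`:

```python
class StubTrainer:
    """Stands in for the real trainer; only carries the counter."""
    global_step = 7


class StubModule:
    def __init__(self, trainer) -> None:
        self.trainer = trainer

    @property
    def global_step(self) -> int:
        # getter only: the value is always read through the trainer
        return self.trainer.global_step


module = StubModule(StubTrainer())
print(module.global_step)    # 7, mirrored from the trainer
try:
    module.global_step = 99  # no setter defined, so this is rejected
except AttributeError as err:
    print("cannot assign:", err)
```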
tests/models/test_restore.py
@@ -10,11 +10,52 @@ import torch
 
 import tests.base.develop_pipelines as tpipes
 import tests.base.develop_utils as tutils
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.callbacks import ModelCheckpoint
 from tests.base import EvalModelTemplate, GenericEvalModelTemplate
 
 
+class ModelTrainerPropertyParity(Callback):
+
+    def _check_properties(self, trainer, pl_module):
+        assert trainer.global_step == pl_module.global_step
+        assert trainer.current_epoch == pl_module.current_epoch
+
+    def on_train_start(self, trainer, pl_module):
+        self._check_properties(trainer, pl_module)
+
+    def on_train_batch_start(self, trainer, pl_module, *args, **kwargs):
+        self._check_properties(trainer, pl_module)
+
+    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
+        self._check_properties(trainer, pl_module)
+
+    def on_epoch_end(self, trainer, pl_module):
+        self._check_properties(trainer, pl_module)
+
+    def on_train_end(self, trainer, pl_module):
+        self._check_properties(trainer, pl_module)
+
+
+def test_resume_from_checkpoint(tmpdir):
+    """ Test that properties like `current_epoch` and `global_step`
+    in model and trainer are always the same. """
+    model = EvalModelTemplate()
+    checkpoint_callback = ModelCheckpoint(filepath=tmpdir, monitor="early_stop_on", save_last=True)
+    trainer_args = dict(
+        default_root_dir=tmpdir,
+        max_epochs=2,
+        logger=False,
+        early_stop_callback=False,
+        checkpoint_callback=checkpoint_callback,
+        callbacks=[ModelTrainerPropertyParity()]  # this performs the assertions
+    )
+    trainer = Trainer(**trainer_args)
+    trainer.fit(model)
+    trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt"))
+    trainer.fit(model)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_running_test_pretrained_model_distrib_dp(tmpdir):
     """Verify `test()` on pretrained model."""
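Beyond the test suite, the guarantee is easy to check from user code. A hedged sketch, assuming the 0.x-era Trainer arguments used elsewhere in this diff and a throwaway `TinyModel` defined only for the example:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer


class TinyModel(LightningModule):
    """A minimal module, just enough to run fit()."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
        return DataLoader(data, batch_size=8)


model = TinyModel()
trainer = Trainer(max_epochs=2, logger=False, checkpoint_callback=False)
trainer.fit(model)

# after the fix, the module simply mirrors the trainer's bookkeeping
assert model.current_epoch == trainer.current_epoch
assert model.global_step == trainer.global_step
```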