[HOTFIX] ModelCheckpoint - Don't increase current_epoch and global_step if not trained (#4291)

* add two tests w/wo tempdir

* resolve flake8

* this test is failing

* update bug report

* resolve bug and add test

* remove bug_report

* resolve flake8

* resolve bug

* resolve pep8

* resolve pep8

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
chaton 2020-10-23 11:17:50 +01:00 committed by GitHub
parent d24cf56d0a
commit 3abfec8962
4 changed files with 179 additions and 4 deletions
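At a glance, the hotfix threads a single `has_trained` flag through the trainer: it is created on the checkpoint connector, reset to False when train() starts, flipped to True only once a training epoch has actually run, and consulted when the checkpoint metadata is assembled, so a run that merely restores a checkpoint and evaluates no longer advances current_epoch/global_step. A minimal standalone sketch of that lifecycle (the Toy* classes are illustrative stand-ins, not the Lightning code changed below):

class ToyCheckpointConnector:
    def __init__(self):
        # used to validate checkpointing logic
        self.has_trained = False


class ToyTrainer:
    def __init__(self, max_epochs=1, restored_epoch=0):
        self.max_epochs = max_epochs
        self.restored_epoch = restored_epoch
        self.checkpoint_connector = ToyCheckpointConnector()

    def fit(self):
        # reset at the start of train(), as in the Trainer hunk below
        self.checkpoint_connector.has_trained = False
        for _ in range(self.restored_epoch, self.max_epochs):
            # ... one training epoch would run here ...
            # the training loop flips the flag only when an epoch actually ran
            self.checkpoint_connector.has_trained = True

    def test(self):
        # evaluation never touches the flag
        pass


fresh = ToyTrainer()
fresh.fit()                                              # one epoch runs
assert fresh.checkpoint_connector.has_trained is True

resumed = ToyTrainer(restored_epoch=1)                   # restored at max_epochs: nothing left to train
resumed.fit()
resumed.test()
assert resumed.checkpoint_connector.has_trained is False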

pytorch_lightning/trainer/connectors/checkpoint_connector.py

@@ -49,6 +49,9 @@ class CheckpointConnector:
    def __init__(self, trainer):
        self.trainer = trainer

        # used to validate checkpointing logic
        self.has_trained = False

    def restore_weights(self, model: LightningModule):
        """
        We attempt to restore weights in this order:
@@ -246,9 +249,19 @@ class CheckpointConnector:
        Return:
            structured dictionary
        """
        current_epoch = self.trainer.current_epoch
        global_step = self.trainer.global_step
        has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step

        global_step += 1
        if self.has_trained:
            if not has_reached_max_steps:
                current_epoch += 1

        checkpoint = {
-            'epoch': self.trainer.current_epoch + 1,
-            'global_step': self.trainer.global_step + 1,
+            'epoch': current_epoch,
+            'global_step': global_step,
            'pytorch-lightning_version': pytorch_lightning.__version__,
        }
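Concretely, with the configuration used by the new tests below (one epoch, limit_train_batches=2), the first real training run writes epoch == 1 and global_step == 2 into the checkpoint, while a later run that restores this checkpoint but never trains keeps has_trained at False and therefore re-saves the same values instead of bumping them again. A simplified sketch of that gating (illustrative only: the connector code above also special-cases max_steps and orders the step bump differently):

def checkpoint_metadata(current_epoch, global_step, has_trained):
    """Return the epoch/global_step bookkeeping stored in a checkpoint (simplified)."""
    if has_trained:
        # a training epoch ran: record where the *next* run should resume
        return {'epoch': current_epoch + 1, 'global_step': global_step + 1}
    # nothing trained since the restore: keep the restored bookkeeping as-is
    return {'epoch': current_epoch, 'global_step': global_step}


# first fit (illustrative save-time values after two optimizer steps)
print(checkpoint_metadata(current_epoch=0, global_step=1, has_trained=True))
# -> {'epoch': 1, 'global_step': 2}, matching chk["epoch"] == 1 and chk["global_step"] == 2 in the tests

# resumed run that never trains because max_epochs is already reached
print(checkpoint_metadata(current_epoch=1, global_step=2, has_trained=False))
# -> {'epoch': 1, 'global_step': 2}: unchanged, which is exactly what the new tests assert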

pytorch_lightning/trainer/trainer.py

@@ -460,6 +460,8 @@ class Trainer(
    def train(self):
        self.run_sanity_check(self.get_model())

        self.checkpoint_connector.has_trained = False

        # enable train mode
        model = self.get_model()
        model.train()

pytorch_lightning/trainer/training_loop.py

@@ -535,6 +535,7 @@ class TrainLoop:
        dataloader_idx = 0
        should_check_val = False

        for batch_idx, (batch, is_last_batch) in train_dataloader:
            self.trainer.batch_idx = batch_idx

            # ------------------------------------
@@ -602,6 +603,8 @@
            # progress global step according to grads progress
            self.increment_accumulated_grad_global_step()

        self.trainer.checkpoint_connector.has_trained = True

        # log epoch metrics
        self.trainer.logger_connector.log_train_epoch_end_metrics(
            epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers

tests/checkpointing/test_model_checkpoint.py

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import pytorch_lightning as pl
from distutils.version import LooseVersion
from unittest.mock import MagicMock, Mock
@@ -30,6 +32,7 @@ from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from tests.base import EvalModelTemplate, BoringModel
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -472,7 +475,8 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
    model.validation_step = None
    trainer = Trainer(
        default_root_dir=tmpdir,
-        checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last),
+        checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir,
+                                            save_top_k=0, save_last=save_last),
        max_epochs=max_epochs,
    )
    trainer.fit(model)
@@ -542,7 +546,7 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode):
    assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1


-def test_checkpoint_within_callbacks_list(tmpdir):
+def test_checkpoint_repeated_strategy(tmpdir):
    """
    This test validates that training can be resumed from the last checkpoint repeatedly
    without creating additional checkpoint files when no further training occurs.
    """
@@ -572,6 +576,159 @@ def test_checkpoint_within_callbacks_list(tmpdir):
    trainer.fit(model)
    assert os.listdir(tmpdir) == ['epoch=00.ckpt']

    def get_last_checkpoint():
        ckpts = os.listdir(tmpdir)
        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
        num_ckpts = len(ckpts_map) - 1
        return ckpts_map[num_ckpts]

    for idx in range(1, 5):
        # load from checkpoint
        chk = get_last_checkpoint()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(max_epochs=1,
                             limit_train_batches=2,
                             limit_val_batches=2,
                             limit_test_batches=2,
                             resume_from_checkpoint=chk)
        trainer.fit(model)
        trainer.test(model)

        assert str(os.listdir(tmpdir)) == "['epoch=00.ckpt']"


def test_checkpoint_repeated_strategy_tmpdir(tmpdir):
    """
    This test validates that repeatedly resuming from the last checkpoint creates a new
    lightning_logs version per run but does not create additional checkpoint files.
    """
    os.environ['PL_DEV_DEBUG'] = '1'

    checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=os.path.join(tmpdir, "{epoch:02d}"))

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

    model = ExtendedBoringModel()
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint_callback])
    trainer.fit(model)
    assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
    path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
    assert sorted(os.listdir(path_to_lightning_logs)) == sorted(['version_0'])

    def get_last_checkpoint():
        ckpts = os.listdir(tmpdir)
        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
        num_ckpts = len(ckpts_map) - 1
        return ckpts_map[num_ckpts]

    for idx in range(1, 5):
        # load from checkpoint
        chk = get_last_checkpoint()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(default_root_dir=tmpdir,
                             max_epochs=1,
                             limit_train_batches=2,
                             limit_val_batches=2,
                             limit_test_batches=2,
                             resume_from_checkpoint=chk)
        trainer.fit(model)
        trainer.test(model)

        assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
        assert sorted(os.listdir(path_to_lightning_logs)) == sorted([f'version_{i}' for i in range(idx + 1)])


def test_checkpoint_repeated_strategy_extended(tmpdir):
    """
    This test validates that checkpointing can be triggered several times without
    internally increasing the global step if no training has run.
    """
    os.environ['PL_DEV_DEBUG'] = '1'

    class ExtendedBoringModel(BoringModel):

        def validation_step(self, batch, batch_idx):
            output = self.layer(batch)
            loss = self.loss(batch, output)
            return {"val_loss": loss}

    model = ExtendedBoringModel()
    model.validation_step_end = None
    model.validation_epoch_end = None

    trainer = pl.Trainer(default_root_dir=tmpdir,
                         max_epochs=1,
                         limit_train_batches=2,
                         limit_val_batches=2,
                         limit_test_batches=2,
                         )

    assert trainer.checkpoint_connector.has_trained is not True
    assert trainer.current_epoch == 0

    trainer.fit(model)
    assert trainer.checkpoint_connector.has_trained is True
    assert trainer.global_step == 2
    assert trainer.current_epoch == 0

    trainer.test(model)
    assert trainer.current_epoch == 0
    assert str(os.listdir(osp.join(tmpdir, 'lightning_logs'))) == "['version_0']"

    def get_last_checkpoint():
        logs_dir = osp.join(tmpdir, 'lightning_logs')
        versions = os.listdir(logs_dir)
        versions.sort()
        last_version = versions[-1]
        ckpt_dir = osp.join(logs_dir, last_version, "checkpoints")
        ckpts = os.listdir(ckpt_dir)
        ckpts.sort()
        return osp.join(ckpt_dir, ckpts[-1])

    def assert_checkpoint_content():
        chk = pl_load(get_last_checkpoint())
        assert chk["epoch"] == 1
        assert chk["global_step"] == 2

    assert_checkpoint_content()

    for idx in range(1, 5):
        # load from checkpoint
        chk = get_last_checkpoint()
        assert_checkpoint_content()
        model = BoringModel.load_from_checkpoint(chk)
        trainer = pl.Trainer(default_root_dir=tmpdir,
                             max_epochs=1,
                             limit_train_batches=2,
                             limit_val_batches=2,
                             limit_test_batches=2,
                             resume_from_checkpoint=chk)

        assert trainer.checkpoint_connector.has_trained is not True
        assert trainer.global_step == 0

        trainer.test(model)
        assert trainer.global_step == 2

        trainer.fit(model)
        assert trainer.global_step == 2
        assert trainer.checkpoint_connector.has_trained is not True

        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
        assert sorted(os.listdir(lightning_logs_path)) == [f"version_{i}" for i in range(idx + 1)]


@pytest.mark.parametrize(
    'filepath, dirpath, filename',
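For contrast, a compressed illustration of the failure mode this hotfix removes: before the change, the metadata was bumped on every save (the removed current_epoch + 1 / global_step + 1 lines in the connector hunk), so repeated resume-and-save cycles without any training drifted upward. The helpers below are illustrative only, assuming the restored values feed straight back into the next save:

def save_before_fix(epoch, step):
    # pre-fix bookkeeping: always advanced, even when nothing trained
    return epoch + 1, step + 1


def save_after_fix(epoch, step, has_trained):
    # post-fix bookkeeping: only advanced when training actually ran (simplified)
    return (epoch + 1, step + 1) if has_trained else (epoch, step)


epoch, step = 1, 2                                   # contents of the first real checkpoint
for _ in range(4):                                   # four resume/save cycles with no training,
    epoch, step = save_before_fix(epoch, step)       # as in the loops of the tests above
print(epoch, step)                                   # 5 6 -- the drift this PR fixes

epoch, step = 1, 2
for _ in range(4):
    epoch, step = save_after_fix(epoch, step, has_trained=False)
print(epoch, step)                                   # 1 2 -- matches assert_checkpoint_content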