[HOTFIX] ModelCheckpoint - Don't increase current_epoch and global_step if not trained (#4291)
* add two tests w/wo tempdir
* resolve flake8
* this test is failing
* update bug report
* resolve bug and add test
* remove bug_report
* resolve flake8
* resolve bug
* resolve pep8
* resolve pep8

Co-authored-by: Teddy Koker <teddy.koker@gmail.com>
parent d24cf56d0a
commit 3abfec8962
@@ -49,6 +49,9 @@ class CheckpointConnector:
     def __init__(self, trainer):
         self.trainer = trainer
 
+        # used to validate checkpointing logic
+        self.has_trained = False
+
     def restore_weights(self, model: LightningModule):
         """
         We attempt to restore weights in this order:
@@ -246,9 +249,19 @@ class CheckpointConnector:
         Return:
             structured dictionary
         """
+        current_epoch = self.trainer.current_epoch
+        global_step = self.trainer.global_step
+        has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step
+
+        global_step += 1
+        if self.has_trained:
+            if not has_reached_max_steps:
+                current_epoch += 1
 
         checkpoint = {
-            'epoch': self.trainer.current_epoch + 1,
-            'global_step': self.trainer.global_step + 1,
+            'epoch': current_epoch,
+            'global_step': global_step,
             'pytorch-lightning_version': pytorch_lightning.__version__,
         }
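For readability, here is a condensed, standalone sketch of the counter logic the hunk above adds to `dump_checkpoint` (illustration only; `checkpoint_counters` is not a Lightning function, and the real code reads these values from `self.trainer`):

# Illustration only: mirrors the counter bookkeeping in the hunk above.
def checkpoint_counters(current_epoch, global_step, has_trained, max_steps=None):
    """Return the (epoch, global_step) pair that would be stored in a checkpoint."""
    has_reached_max_steps = bool(max_steps) and max_steps <= global_step

    global_step += 1
    if has_trained:
        if not has_reached_max_steps:
            current_epoch += 1
    return current_epoch, global_step

# The stored epoch only advances when training actually ran:
assert checkpoint_counters(0, 1, has_trained=True) == (1, 2)
assert checkpoint_counters(0, 1, has_trained=False) == (0, 2)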
@@ -460,6 +460,8 @@ class Trainer(
     def train(self):
         self.run_sanity_check(self.get_model())
 
+        self.checkpoint_connector.has_trained = False
+
         # enable train mode
         model = self.get_model()
         model.train()
@@ -535,6 +535,7 @@ class TrainLoop:
         dataloader_idx = 0
+        should_check_val = False
         for batch_idx, (batch, is_last_batch) in train_dataloader:
 
             self.trainer.batch_idx = batch_idx
 
             # ------------------------------------
@@ -602,6 +603,8 @@ class TrainLoop:
             # progress global step according to grads progress
             self.increment_accumulated_grad_global_step()
 
+        self.trainer.checkpoint_connector.has_trained = True
+
         # log epoch metrics
         self.trainer.logger_connector.log_train_epoch_end_metrics(
             epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
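Taken together, the Trainer and TrainLoop hunks above give `has_trained` a simple lifecycle: it is reset to `False` when `train()` starts, flipped to `True` once training has actually run, and consulted when a checkpoint is dumped. A minimal sketch with a stand-in class (illustration only; `FakeCheckpointConnector` and its `dump` method are not Lightning APIs):

# Illustration only: the lifecycle of the has_trained flag introduced above.
class FakeCheckpointConnector:
    def __init__(self):
        # used to validate checkpointing logic
        self.has_trained = False

    def dump(self, current_epoch, global_step):
        # same counter rule as the CheckpointConnector hunk (max_steps omitted here)
        global_step += 1
        if self.has_trained:
            current_epoch += 1
        return {'epoch': current_epoch, 'global_step': global_step}


connector = FakeCheckpointConnector()

connector.has_trained = False      # Trainer.train(): reset before the epoch loop starts
assert connector.dump(0, 1) == {'epoch': 0, 'global_step': 2}

connector.has_trained = True       # TrainLoop: flipped once training has actually run
assert connector.dump(0, 1) == {'epoch': 1, 'global_step': 2}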
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import os.path as osp
+import pytorch_lightning as pl
 from distutils.version import LooseVersion
 from unittest.mock import MagicMock, Mock
@@ -30,6 +32,7 @@ from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
+from tests.base import EvalModelTemplate, BoringModel
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -472,7 +475,8 @@ def test_model_checkpoint_save_last_warning(tmpdir, caplog, max_epochs, should_v
     model.validation_step = None
     trainer = Trainer(
         default_root_dir=tmpdir,
-        checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, save_top_k=0, save_last=save_last),
+        checkpoint_callback=ModelCheckpoint(monitor='early_stop_on', filepath=tmpdir,
+                                            save_top_k=0, save_last=save_last),
         max_epochs=max_epochs,
     )
     trainer.fit(model)
@@ -542,7 +546,7 @@ def test_checkpointing_with_nan_as_first(tmpdir, mode):
     assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1
 
 
-def test_checkpoint_within_callbacks_list(tmpdir):
+def test_checkpoint_repeated_strategy(tmpdir):
     """
     This test validates that the checkpoint can be called when provided to callacks list
     """
@@ -572,6 +576,159 @@ def test_checkpoint_within_callbacks_list(tmpdir):
     trainer.fit(model)
     assert os.listdir(tmpdir) == ['epoch=00.ckpt']
 
+    def get_last_checkpoint():
+        ckpts = os.listdir(tmpdir)
+        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
+        num_ckpts = len(ckpts_map) - 1
+        return ckpts_map[num_ckpts]
+
+    for idx in range(1, 5):
+        # load from checkpoint
+        chk = get_last_checkpoint()
+        model = BoringModel.load_from_checkpoint(chk)
+        trainer = pl.Trainer(max_epochs=1,
+                             limit_train_batches=2,
+                             limit_val_batches=2,
+                             limit_test_batches=2,
+                             resume_from_checkpoint=chk)
+        trainer.fit(model)
+        trainer.test(model)
+
+        assert str(os.listdir(tmpdir)) == "['epoch=00.ckpt']"
+
+
+def test_checkpoint_repeated_strategy_tmpdir(tmpdir):
+    """
+    This test validates that the checkpoint can be called when provided to callacks list
+    """
+
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=os.path.join(tmpdir, "{epoch:02d}"))
+
+    class ExtendedBoringModel(BoringModel):
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"val_loss": loss}
+
+    model = ExtendedBoringModel()
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        callbacks=[checkpoint_callback])
+
+    trainer.fit(model)
+    assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
+    path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
+    assert sorted(os.listdir(path_to_lightning_logs)) == sorted(['version_0'])
+
+    def get_last_checkpoint():
+        ckpts = os.listdir(tmpdir)
+        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
+        num_ckpts = len(ckpts_map) - 1
+        return ckpts_map[num_ckpts]
+
+    for idx in range(1, 5):
+
+        # load from checkpoint
+        chk = get_last_checkpoint()
+        model = BoringModel.load_from_checkpoint(chk)
+        trainer = pl.Trainer(default_root_dir=tmpdir,
+                             max_epochs=1,
+                             limit_train_batches=2,
+                             limit_val_batches=2,
+                             limit_test_batches=2,
+                             resume_from_checkpoint=chk)
+
+        trainer.fit(model)
+        trainer.test(model)
+        assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
+        assert sorted(os.listdir(path_to_lightning_logs)) == sorted([f'version_{i}' for i in range(idx + 1)])
+
+
+def test_checkpoint_repeated_strategy_extended(tmpdir):
+    """
+    This test validates checkpoint can be called several times without
+    increasing internally its global step if nothing run.
+    """
+
+    os.environ['PL_DEV_DEBUG'] = '1'
+
+    class ExtendedBoringModel(BoringModel):
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            return {"val_loss": loss}
+
+    model = ExtendedBoringModel()
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+    trainer = pl.Trainer(default_root_dir=tmpdir,
+                         max_epochs=1,
+                         limit_train_batches=2,
+                         limit_val_batches=2,
+                         limit_test_batches=2,
+                         )
+
+    assert trainer.checkpoint_connector.has_trained is not True
+    assert trainer.current_epoch == 0
+    trainer.fit(model)
+    assert trainer.checkpoint_connector.has_trained is True
+    assert trainer.global_step == 2
+    assert trainer.current_epoch == 0
+    trainer.test(model)
+    assert trainer.current_epoch == 0
+    assert str(os.listdir(osp.join(tmpdir, 'lightning_logs'))) == "['version_0']"
+
+    def get_last_checkpoint():
+        logs_dir = osp.join(tmpdir, 'lightning_logs')
+        versions = os.listdir(logs_dir)
+        versions.sort()
+
+        last_version = versions[-1]
+        ckpt_dir = osp.join(logs_dir, last_version, "checkpoints")
+
+        ckpts = os.listdir(ckpt_dir)
+        ckpts.sort()
+
+        return osp.join(ckpt_dir, ckpts[-1])
+
+    def assert_checkpoint_content():
+        chk = pl_load(get_last_checkpoint())
+        assert chk["epoch"] == 1
+        assert chk["global_step"] == 2
+
+    assert_checkpoint_content()
+
+    for idx in range(1, 5):
+        # load from checkpoint
+        chk = get_last_checkpoint()
+        assert_checkpoint_content()
+        model = BoringModel.load_from_checkpoint(chk)
+        trainer = pl.Trainer(default_root_dir=tmpdir,
+                             max_epochs=1,
+                             limit_train_batches=2,
+                             limit_val_batches=2,
+                             limit_test_batches=2,
+                             resume_from_checkpoint=chk)
+        assert trainer.checkpoint_connector.has_trained is not True
+        assert trainer.global_step == 0
+        trainer.test(model)
+        assert trainer.global_step == 2
+        trainer.fit(model)
+        assert trainer.global_step == 2
+        assert trainer.checkpoint_connector.has_trained is not True
+        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
+        assert sorted(os.listdir(lightning_logs_path)) == [f"version_{i}" for i in range(idx + 1)]
+
+
 @pytest.mark.parametrize(
     'filepath, dirpath, filename',
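For reference, the invariant these new tests pin down can also be checked by hand on any saved checkpoint; a sketch assuming a Lightning ~1.0 checkpoint file at a placeholder path:

# Illustration only: inspect the counters stored in a saved checkpoint.
from pytorch_lightning.utilities.cloud_io import load as pl_load

# 'path/to/last.ckpt' is a placeholder, not a path produced by these tests.
ckpt = pl_load('path/to/last.ckpt')

# With the fix, one trained epoch of two batches stores epoch == 1 and
# global_step == 2, and re-saving without further training keeps them unchanged.
print(ckpt['epoch'], ckpt['global_step'], ckpt['pytorch-lightning_version'])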