# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for saving, pickling and restoring models and trainer state from checkpoints."""
import glob
import logging as log
import os
import pickle
from copy import deepcopy
from typing import Generic, Mapping, TypeVar
from unittest import mock

import cloudpickle
import pytest
import torch
import torch.nn.functional as F

import tests_pytorch.helpers.pipelines as tpipes
import tests_pytorch.helpers.utils as tutils
from lightning_lite import seed_everything
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.trainer.states import TrainerFn
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel
from tests_pytorch.loops.test_loops import CustomException


class ModelTrainerPropertyParity(Callback):
    """Assert that trainer and LightningModule agree on `global_step` and `current_epoch`.

    The check runs at train start/end and at the start/end of every training batch.
    """

    def _check_properties(self, trainer, pl_module):
        assert trainer.global_step == pl_module.global_step
        assert trainer.current_epoch == pl_module.current_epoch

    def on_train_start(self, trainer, pl_module):
        self._check_properties(trainer, pl_module)

    def on_train_batch_start(self, trainer, pl_module, *args, **kwargs):
        self._check_properties(trainer, pl_module)

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self._check_properties(trainer, pl_module)

    def on_train_end(self, trainer, pl_module):
        self._check_properties(trainer, pl_module)


class ValTestLossBoringModel(BoringModel):
    """BoringModel that saves its hyperparameters and logs `val_loss` / `test_loss`.

    `batch_size` is not used by the model itself; it only ends up in the saved hyperparameters.
    """

    def __init__(self, batch_size=4):
        super().__init__()
        self.save_hyperparameters()

    def validation_step(self, batch, batch_idx):
        out = super().validation_step(batch, batch_idx)
        self.log("val_loss", out["x"])
        return out

    def test_step(self, batch, batch_idx):
        out = super().test_step(batch, batch_idx)
        self.log("test_loss", out["y"])
        return out


T = TypeVar("T")


class GenericParentValTestLossBoringModel(Generic[T], ValTestLossBoringModel):
    """`typing.Generic` parent, used to check checkpoint loading works through generic class hierarchies."""

    def __init__(self, batch_size: int = 4):
        super().__init__(batch_size=batch_size)


class GenericValTestLossBoringModel(GenericParentValTestLossBoringModel[int]):
    pass


class CustomClassificationModelDP(ClassificationModel):
    """ClassificationModel whose eval steps return raw logits/targets.

    Validation accuracy is computed and logged in `validation_step_end`. This model is used by the DP tests.
    """

    def _step(self, batch):
        x, y = batch
        logits = self(x)
        return {"logits": logits, "y": y}

    def training_step(self, batch, batch_idx):
        out = self._step(batch)
        loss = F.cross_entropy(out["logits"], out["y"])
        return loss

    def validation_step(self, batch, batch_idx):
        return self._step(batch)

    def test_step(self, batch, batch_idx):
        return self._step(batch)

    def validation_step_end(self, outputs):
        self.log("val_acc", self.valid_acc(outputs["logits"], outputs["y"]))


def test_model_properties_fit_ckpt_path(tmpdir):
    """Test that properties like `current_epoch` and `global_step` in model and trainer are always the same."""
    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
    trainer_args = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=False,
        callbacks=[checkpoint_callback, ModelTrainerPropertyParity()],  # this performs the assertions
    )
    trainer = Trainer(**trainer_args)
    trainer.fit(model)

    trainer_args.update(max_epochs=2)
    trainer = Trainer(**trainer_args)
    trainer.fit(model, ckpt_path=str(tmpdir / "last.ckpt"))


@RunIf(sklearn=True)
def test_trainer_properties_restore_ckpt_path(tmpdir):
    """Test that required trainer properties are set correctly when resuming from checkpoint in different phases."""

    class CustomClassifModel(ClassificationModel):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]

    model = CustomClassifModel()
    dm = ClassifDataModule()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)
    trainer_args = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        limit_predict_batches=2,
        logger=False,
        callbacks=[checkpoint_callback],
        num_sanity_val_steps=0,
    )
    trainer = Trainer(**trainer_args)
    trainer.fit(model, datamodule=dm)

    resume_ckpt = str(tmpdir / "last.ckpt")
    # raw checkpoint contents that the restored model/trainer are compared against
    state_dict = torch.load(resume_ckpt)

    trainer_args.update({"max_epochs": 3, "enable_checkpointing": False, "callbacks": []})

    class CustomClassifModel(CustomClassifModel):
        def _is_equal(self, a, b):
            """Recursively compare tensors, mappings (over the keys of ``b``) and plain values."""
            if isinstance(a, torch.Tensor):
                return torch.all(torch.eq(a, b))

            if isinstance(a, Mapping):
                return all(self._is_equal(a.get(k, None), b.get(k, None)) for k in b.keys())

            return a == b

        def _check_optimizers(self):
            return all(
                self._is_equal(optimizer.state_dict(), state)
                for optimizer, state in zip(self.trainer.optimizers, state_dict["optimizer_states"])
            )

        def _check_schedulers(self):
            return all(
                self._is_equal(config.scheduler.state_dict(), state)
                for config, state in zip(self.trainer.lr_scheduler_configs, state_dict["lr_schedulers"])
            )

        def _check_model_state_dict(self):
            # Compare the parameter/buffer *values* with the checkpoint. Zipping two state dicts directly
            # would iterate over their keys and only compare the parameter names, not the weights.
            return self._is_equal(self.state_dict(), state_dict["state_dict"])

        def _test_on_val_test_predict_start(self):
            assert self.trainer.current_epoch == state_dict["epoch"]
            assert self.trainer.global_step == 0
            assert self._check_model_state_dict()

        def on_train_start(self):
            assert self.trainer.current_epoch == state_dict["epoch"] + 1
            assert self.trainer.global_step == state_dict["global_step"]
            assert self._check_model_state_dict()
            assert self._check_optimizers()
            assert self._check_schedulers()

        def on_validation_start(self):
            # only check for a standalone `validate` call, not the validation loop run inside `fit`
            if self.trainer.state.fn == TrainerFn.VALIDATING:
                self._test_on_val_test_predict_start()

        def on_test_start(self):
            self._test_on_val_test_predict_start()

    # NOTE: no predict hook is overridden, so the `predict` run only checks that restoring does not fail
    for fn in ("fit", "validate", "test", "predict"):
        model = CustomClassifModel()
        dm = ClassifDataModule()
        trainer = Trainer(**trainer_args)
        trainer_fn = getattr(trainer, fn)
        trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)


def test_correct_step_and_epoch(tmpdir):
    """Test that epoch, global step and stepped-batch counters are saved to and restored from a checkpoint."""
    model = BoringModel()
    first_max_epochs = 2
    train_batches = 2
    trainer = Trainer(
        default_root_dir=tmpdir, max_epochs=first_max_epochs, limit_train_batches=train_batches, limit_val_batches=0
    )
    assert trainer.current_epoch == 0
    assert trainer.global_step == 0

    trainer.fit(model)
    assert trainer.current_epoch == first_max_epochs
    assert trainer.global_step == first_max_epochs * train_batches

    # save checkpoint after loop ends, training end called, epoch count increased
    ckpt_path = str(tmpdir / "model.ckpt")
    trainer.save_checkpoint(ckpt_path)

    ckpt = torch.load(ckpt_path)
    assert ckpt["epoch"] == first_max_epochs
    assert ckpt["global_step"] == first_max_epochs * train_batches

    max_epochs = first_max_epochs + 2
    trainer = Trainer(
        default_root_dir=tmpdir, max_epochs=max_epochs, limit_train_batches=train_batches, limit_val_batches=0
    )
    # the ckpt state is not loaded at this point
    assert trainer.current_epoch == 0
    assert trainer.global_step == 0

    class TestModel(BoringModel):
        def on_train_start(self) -> None:
            assert self.trainer.current_epoch == first_max_epochs
            assert self.trainer.global_step == first_max_epochs * train_batches
            assert self.trainer.fit_loop.epoch_loop._batches_that_stepped == first_max_epochs * train_batches

    trainer.fit(TestModel(), ckpt_path=ckpt_path)
    assert trainer.current_epoch == max_epochs
    assert trainer.global_step == max_epochs * train_batches
    assert trainer.fit_loop.epoch_loop._batches_that_stepped == max_epochs * train_batches


def test_fit_twice(tmpdir):
    """Test that a second `fit` call on the same trainer continues from the last epoch up to the new `max_epochs`."""
    epochs = []

    class TestModel(BoringModel):
        def on_train_epoch_end(self, *_):
            epochs.append(self.current_epoch)

    trainer = Trainer(
        max_epochs=2,
        limit_train_batches=1,
        limit_val_batches=1,
        default_root_dir=tmpdir,
        logger=False,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    trainer.fit(TestModel())
    trainer.fit_loop.max_epochs = 4
    trainer.fit(TestModel())
    assert epochs == [0, 1, 2, 3]


def test_try_resume_from_non_existing_checkpoint(tmpdir):
    """Test that trying to resume from non-existing `ckpt_path` fails with an error."""
    model = BoringModel()
    trainer = Trainer()

    with pytest.raises(FileNotFoundError, match="Aborting training"):
        trainer.fit(model, ckpt_path=str(tmpdir / "non_existing.ckpt"))


class CaptureCallbacksBeforeTraining(Callback):
    """Deep-copy the trainer's callbacks in `on_fit_start` so their state can be inspected later."""

    # class-level default only; `on_fit_start` rebinds an instance attribute, so this list is never mutated
    callbacks = []

    def on_fit_start(self, trainer, pl_module):
        self.callbacks = deepcopy(trainer.callbacks)


@RunIf(sklearn=True)
def test_callbacks_state_fit_ckpt_path(tmpdir):
    """Test that resuming from a checkpoint restores callbacks that persist state."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    callback_capture = CaptureCallbacksBeforeTraining()

    def get_trainer_args():
        checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
        trainer_args = dict(
            default_root_dir=tmpdir,
            limit_train_batches=1,
            limit_val_batches=2,
            max_epochs=1,
            logger=False,
            callbacks=[checkpoint, callback_capture],
        )
        assert checkpoint.best_model_path == ""
        assert checkpoint.best_model_score is None
        return trainer_args

    # initial training
    trainer = Trainer(**get_trainer_args())
    trainer.fit(model, datamodule=dm)
    callbacks_before_resume = deepcopy(trainer.callbacks)

    # resumed training
    trainer = Trainer(**get_trainer_args())
    trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))

    assert len(callbacks_before_resume) == len(callback_capture.callbacks)

    for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
        if isinstance(before, ModelCheckpoint):
            for attribute in (
                "best_model_path",
                "best_model_score",
                "best_k_models",
                "kth_best_model_path",
                "kth_value",
                "last_model_path",
            ):
                assert getattr(before, attribute) == getattr(after, attribute)


@RunIf(sklearn=True)
def test_callbacks_references_fit_ckpt_path(tmpdir):
    """Test that resuming from a checkpoint sets references as expected."""
    dm = ClassifDataModule()
    model = ClassificationModel()
    args = {
        "default_root_dir": tmpdir,
        "max_steps": 1,
        "logger": False,
        "limit_val_batches": 2,
        "num_sanity_val_steps": 0,
    }

    # initial training
    checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    trainer = Trainer(**args, callbacks=[checkpoint])
    assert checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model, datamodule=dm)

    # resumed training
    new_checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    # pass in a new checkpoint object, which should take
    # precedence over the one in the last.ckpt file
    trainer = Trainer(**args, callbacks=[new_checkpoint])
    assert checkpoint is not new_checkpoint
    assert new_checkpoint is trainer.callbacks[-1] is trainer.checkpoint_callback
    trainer.fit(model, datamodule=dm, ckpt_path=str(tmpdir / "last.ckpt"))


@RunIf(min_cuda_gpus=2, sklearn=True)
def test_running_test_pretrained_model_distrib_dp(tmpdir):
    """Verify `test()` on pretrained model."""

    seed_everything(7)

    dm = ClassifDataModule()
    model = CustomClassificationModelDP(lr=0.1)

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=5,
        limit_val_batches=5,
        callbacks=[checkpoint],
        logger=logger,
        accelerator="gpu",
        devices=[0, 1],
        strategy="dp",
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # training finished successfully
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = CustomClassificationModelDP.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_model_prediction(pretrained_model, dataloader)


@RunIf(min_cuda_gpus=2, sklearn=True)
def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
    """Verify `test()` on pretrained model."""
    dm = ClassifDataModule()
    model = ClassificationModel()

    # exp file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # exp file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        accelerator="gpu",
        devices=[0, 1],
        strategy="ddp_spawn",
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    log.info(os.listdir(tutils.get_data_path(logger, path_dir=tmpdir)))

    # training finished successfully
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    # run test set
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)
    pretrained_model.cpu()

    dataloaders = dm.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    for dataloader in dataloaders:
        tpipes.run_model_prediction(pretrained_model, dataloader, min_acc=0.1)


@RunIf(sklearn=True)
def test_running_test_pretrained_model_cpu(tmpdir):
    """Verify test() on pretrained model."""
    seed_everything(1)

    dm = ClassifDataModule()
    model = ClassificationModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[checkpoint],
        logger=logger,
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # training finished successfully
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    pretrained_model = ClassificationModel.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model, datamodule=dm)

    # test we have good test accuracy
    tutils.assert_ok_model_acc(new_trainer, key="test_acc", thr=0.45)


@pytest.mark.parametrize("model_template", [ValTestLossBoringModel, GenericValTestLossBoringModel])
def test_load_model_from_checkpoint(tmpdir, model_template):
    """Test that a model (including a `Generic` subclass) restores its hparams and weights from a checkpoint."""
    model = model_template()

    trainer_options = dict(
        enable_progress_bar=False,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        callbacks=[ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_top_k=-1)],
        default_root_dir=tmpdir,
    )

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model)
    trainer.test(model)

    # training and testing finished successfully
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # load last checkpoint
    # NOTE(review): relies on the checkpoint filenames sorting by epoch lexicographically
    # (true for the default names with fewer than 10 epochs)
    last_checkpoint = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))[-1]

    # `ValTestLossBoringModel.__init__` calls `save_hyperparameters()`, so the checkpoint must contain hparams
    ckpt = torch.load(last_checkpoint)
    assert model_template.CHECKPOINT_HYPER_PARAMS_KEY in ckpt.keys(), "hyper_parameters missing from checkpoints"

    # Ensure that model can be correctly restored from checkpoint
    pretrained_model = model_template.load_from_checkpoint(last_checkpoint)

    # test that hparams loaded correctly
    for k, v in model.hparams.items():
        assert getattr(pretrained_model.hparams, k) == v

    # assert weights are the same (and that `zip` pairs up the same parameters)
    for (old_name, old_p), (new_name, new_p) in zip(model.named_parameters(), pretrained_model.named_parameters()):
        assert old_name == new_name
        assert torch.all(torch.eq(old_p, new_p)), "loaded weights are not the same as the saved weights"

    # Check `test` on pretrained model:
    new_trainer = Trainer(**trainer_options)
    new_trainer.test(pretrained_model)


@RunIf(min_cuda_gpus=2, sklearn=True)
def test_dp_resume(tmpdir):
    """Make sure a DP model can be saved as an HPC checkpoint and validated from it."""
    model = CustomClassificationModelDP(lr=0.1)
    dm = ClassifDataModule()

    trainer_options = dict(max_epochs=1, accelerator="gpu", devices=2, strategy="dp", default_root_dir=tmpdir)

    # get logger
    logger = tutils.get_default_logger(tmpdir)

    # logger file to get weights
    checkpoint = tutils.init_checkpoint_callback(logger)

    # add these to the trainer options
    trainer_options["logger"] = logger
    trainer_options["callbacks"] = [checkpoint]

    # fit model
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=dm)

    # track epoch before saving
    real_global_epoch = trainer.current_epoch

    # training finished successfully
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # ---------------------------
    # HPC LOAD/SAVE
    # ---------------------------
    # save
    # save logger to make sure we get all the metrics
    if logger:
        logger.finalize("finished")
    hpc_save_path = trainer._checkpoint_connector.hpc_save_path(tmpdir)
    trainer.save_checkpoint(hpc_save_path)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
    trainer_options["logger"] = new_logger
    trainer_options["callbacks"] = [ModelCheckpoint(dirpath=tmpdir)]
    trainer_options["limit_train_batches"] = 0.5
    trainer_options["limit_val_batches"] = 0.2
    trainer_options["max_epochs"] = 1
    new_trainer = Trainer(**trainer_options)

    class CustomModel(CustomClassificationModelDP):
        def __init__(self):
            super().__init__()
            self.on_train_start_called = False

        def on_train_start(self):
            # NOTE(review): only `validate` is called below, so this hook does not run in this test
            assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0

        def on_validation_start(self):
            dataloader = dm.val_dataloader()
            tpipes.run_model_prediction(self.trainer.lightning_module, dataloader=dataloader)

    # new model
    model = CustomModel()

    # validate new model which should load hpc weights
    new_trainer.validate(model, datamodule=dm, ckpt_path=hpc_save_path)

    # test freeze on gpu
    model.freeze()
    model.unfreeze()


def test_model_saving_loading(tmpdir):
    """Test that a model saved by the trainer and reloaded with its `hparams.yaml` gives identical predictions."""
    model = BoringModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
        default_root_dir=tmpdir,
    )
    trainer.fit(model)

    # training complete
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # make a prediction
    dataloaders = model.test_dataloader()
    if not isinstance(dataloaders, list):
        dataloaders = [dataloaders]

    batch = next(iter(dataloaders[0]))

    # generate preds before saving model
    model.eval()
    pred_before_saving = model(batch)

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, "hparams.yaml")
    model_2 = BoringModel.load_from_checkpoint(checkpoint_path=new_weights_path, hparams_file=hparams_path)
    model_2.eval()

    # make prediction
    # assert that both predictions are the same
    new_pred = model_2(batch)
    assert torch.all(torch.eq(pred_before_saving, new_pred)).item() == 1


@pytest.mark.parametrize("url_ckpt", [True, False])
def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Test loading a checkpoint that has extra parameters (vs. the model class) from a path or a URL.

    Loading must succeed with `strict=False` and raise with `strict=True`.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv("TORCH_HOME", tmpdir)

    model = BoringModel()
    # Extra layer
    model.c_d3 = torch.nn.Linear(32, 32)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    trainer.fit(model)

    # training complete
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), "hparams.yaml")
    ckpt_url = f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}"
    ckpt_path = ckpt_url if url_ckpt else new_weights_path

    BoringModel.load_from_checkpoint(checkpoint_path=ckpt_path, hparams_file=hparams_path, strict=False)

    with pytest.raises(RuntimeError, match=r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'):
        BoringModel.load_from_checkpoint(checkpoint_path=ckpt_path, hparams_file=hparams_path, strict=True)


@pytest.mark.parametrize("url_ckpt", [True, False])
def test_strict_model_load_less_params(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
    """Test loading a checkpoint that lacks some of the model's parameters from a path or a URL.

    Loading must succeed with `strict=False` and raise with `strict=True`.
    """
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv("TORCH_HOME", tmpdir)

    model = BoringModel()

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    trainer.fit(model)

    # training complete
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir), "hparams.yaml")
    ckpt_url = f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}"
    ckpt_path = ckpt_url if url_ckpt else new_weights_path

    class CurrentModel(BoringModel):
        def __init__(self):
            super().__init__()
            # extra layer that is not present in the saved checkpoint
            self.c_d3 = torch.nn.Linear(7, 7)

    CurrentModel.load_from_checkpoint(checkpoint_path=ckpt_path, hparams_file=hparams_path, strict=False)

    with pytest.raises(RuntimeError, match=r'Missing key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'):
        CurrentModel.load_from_checkpoint(checkpoint_path=ckpt_path, hparams_file=hparams_path, strict=True)


def test_model_pickle(tmpdir):
    """Test that a `BoringModel` can be serialized with both `pickle` and `cloudpickle`."""
    model = BoringModel()
    pickle.dumps(model)
    cloudpickle.dumps(model)


class ExceptionModel(BoringModel):
    """BoringModel that raises `CustomException` at training batch `stop_batch_idx` (-1 never matches)."""

    def __init__(self, stop_batch_idx):
        super().__init__()
        self.stop_batch_idx = stop_batch_idx

    def training_step(self, batch, batch_idx):
        if batch_idx == self.stop_batch_idx:
            raise CustomException()
        return super().training_step(batch, batch_idx)


class ShouldStopModel(ExceptionModel):
    """Variant of `ExceptionModel` that sets `trainer.should_stop` before the parent step raises."""

    def training_step(self, batch, batch_idx):
        if batch_idx == self.stop_batch_idx:
            # setting should_stop is treated differently to raising an exception.
            # checking both tests that this warning is raised in the correct loop
            self.trainer.should_stop = True
        return super().training_step(batch, batch_idx)


@pytest.mark.parametrize("stop_in_the_middle", (True, False))
@pytest.mark.parametrize("model_cls", (ExceptionModel, ShouldStopModel))
def test_restarting_mid_epoch_raises_warning(tmpdir, stop_in_the_middle, model_cls):
    """Test that a warning is raised if training is restarted from mid-epoch."""
    limit_train_batches = 8
    trainer_kwargs = {
        "default_root_dir": tmpdir,
        "limit_train_batches": limit_train_batches,
        "limit_val_batches": 0,
        "enable_progress_bar": False,
        "enable_model_summary": False,
    }
    trainer = Trainer(max_epochs=1, **trainer_kwargs)
    model = model_cls(limit_train_batches // 2 if stop_in_the_middle else -1)

    if stop_in_the_middle:
        with pytest.raises(CustomException):
            trainer.fit(model)
    else:
        trainer.fit(model)

    ckpt_path = str(tmpdir / "resume.ckpt")
    trainer.save_checkpoint(ckpt_path)

    trainer = Trainer(max_epochs=2, **trainer_kwargs)
    # never raise again when resuming
    model.stop_batch_idx = -1

    context_manager = pytest.warns if stop_in_the_middle else tutils.no_warning_call
    with context_manager(UserWarning, match="resuming from a checkpoint that ended"):
        trainer.fit(model, ckpt_path=ckpt_path)

    if stop_in_the_middle:
        # with fault-tolerant training enabled, resuming mid-epoch must not warn
        with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
            trainer = Trainer(max_epochs=2, **trainer_kwargs)
            with tutils.no_warning_call(UserWarning, match="resuming from a checkpoint that ended"):
                trainer.fit(model, ckpt_path=ckpt_path)