parent 79a3ff690b
commit a44881cd90
@@ -419,6 +419,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed `Strategy.on_tpu` property ([#11536](https://github.com/PyTorchLightning/pytorch-lightning/pull/11536))
 
 
+- Removed `FitLoop.current_epoch` getter and setter ([#11562](https://github.com/PyTorchLightning/pytorch-lightning/pull/11562))
+
+
 - Removed access to `_short_id` in `NeptuneLogger` ([#11517](https://github.com/PyTorchLightning/pytorch-lightning/pull/11517))
 
 
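For downstream code that touched the removed accessors, the migration is mechanical. A minimal sketch (assuming an already-constructed `Trainer`; `bump_epoch` is a hypothetical helper name, not part of the Lightning API):

```python
import pytorch_lightning as pl

def bump_epoch(trainer: "pl.Trainer") -> int:
    """Sketch of what `trainer.fit_loop.current_epoch += 1` used to do."""
    # the progress-tracking dataclass is now the single source of truth
    trainer.fit_loop.epoch_progress.current.completed += 1
    # the read-only Trainer property reflects the same counter (see the Trainer hunk below)
    return trainer.current_epoch
```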
@@ -366,16 +366,14 @@ class ModelCheckpoint(Callback):
         This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
         behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
         """
-        epoch = trainer.current_epoch
-        global_step = trainer.global_step
-
         self._validate_monitor_key(trainer)
 
         # track epoch when ckpt was last checked
+        global_step = trainer.global_step
         self._last_global_step_saved = global_step
 
         # what can be monitored
-        monitor_candidates = self._monitor_candidates(trainer, epoch=epoch, step=global_step)
+        monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=global_step)
 
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
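The hunk above drops the cached `epoch` local and reads `trainer.current_epoch` at its single call site, so the callback always sees what the loop's progress tracking currently reports. The same read-at-use pattern in a user callback (a sketch; `PrintingCallback` is a hypothetical example, but `Callback.on_validation_end` is the real hook signature):

```python
import pytorch_lightning as pl

class PrintingCallback(pl.Callback):
    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # read the properties at use time instead of caching them early;
        # both delegate to the fit loop's progress tracking
        print(f"epoch={trainer.current_epoch} step={trainer.global_step}")
```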
@@ -199,10 +199,7 @@ class LightningModule(
 
     @property
     def current_epoch(self) -> int:
-        """The current epoch in the Trainer.
-
-        If no Trainer is attached, this propery is 0.
-        """
+        """The current epoch in the ``Trainer``, or 0 if not attached."""
         return self.trainer.current_epoch if self.trainer else 0
 
     @property
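The docstring is collapsed to one line that also states the fallback. The property itself is a small delegate-with-default pattern; a self-contained illustration outside Lightning (hypothetical `Coordinator`/`Worker` names):

```python
from typing import Optional

class Coordinator:
    def __init__(self) -> None:
        self.current_epoch = 3

class Worker:
    def __init__(self, coordinator: Optional[Coordinator] = None) -> None:
        self.coordinator = coordinator

    @property
    def current_epoch(self) -> int:
        """The current epoch in the attached coordinator, or 0 if not attached."""
        return self.coordinator.current_epoch if self.coordinator else 0

assert Worker().current_epoch == 0
assert Worker(Coordinator()).current_epoch == 3
```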
@@ -206,6 +206,7 @@ class Loop(ABC, Generic[T]):
                 self._restarting = False
             except StopIteration:
                 break
+        self._restarting = False
 
         output = self.on_run_end()
         return output
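The added `self._restarting = False` after the `while` loop clears the flag even when `advance` raises `StopIteration` before a single full iteration completes, e.g. right after a restore. A condensed model of the control flow (a sketch; the real `Loop.run` also calls `on_run_start` and the `on_advance_*` hooks):

```python
class MiniLoop:
    def __init__(self) -> None:
        self._restarting = True  # e.g. state was just reloaded from a checkpoint

    def advance(self) -> None:
        raise StopIteration  # nothing left to do after the restore

    def run(self) -> None:
        while True:
            try:
                self.advance()
                self._restarting = False  # a full advance means we are past the restart
            except StopIteration:
                break
        # the fix: clear the flag even if no advance ever completed
        self._restarting = False

loop = MiniLoop()
loop.run()
assert not loop._restarting
```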
@@ -98,11 +98,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
 
     @property
     def done(self) -> bool:
-        """Returns whether the training should be stopped.
-
-        The criteria are that the number of steps reached the max steps, the last batch is reached or the trainer
-        signals to stop (e.g. by early stopping).
-        """
+        """Evaluates when to leave the loop."""
         return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop
 
     def connect(  # type: ignore[override]
@@ -56,16 +56,6 @@ class FitLoop(Loop[None]):
         self._is_fresh_start_epoch: bool = True
         self._outputs: _EPOCH_OUTPUTS_TYPE = []
 
-    @property
-    def current_epoch(self) -> int:
-        """Return the current epoch."""
-        return self.epoch_progress.current.completed
-
-    @current_epoch.setter
-    def current_epoch(self, value: int) -> None:
-        """Setter for the current epoch."""
-        self.epoch_progress.current.completed = value
-
     @property
     def global_step(self) -> int:
         """Returns the global step."""
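From here on, every internal read and write goes through `epoch_progress.current.completed` directly; the remaining hunks apply that substitution mechanically. For orientation, a stripped-down model of the progress object's shape (a sketch, not the real `pytorch_lightning.trainer.progress` classes):

```python
from dataclasses import dataclass, field

@dataclass
class Tracker:
    ready: int = 0
    started: int = 0
    processed: int = 0
    completed: int = 0

@dataclass
class Progress:
    total: Tracker = field(default_factory=Tracker)
    current: Tracker = field(default_factory=Tracker)

epoch_progress = Progress()
epoch_progress.current.completed += 1  # what the former setter used to hide
assert epoch_progress.current.completed == 1
```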
@@ -149,19 +139,15 @@ class FitLoop(Loop[None]):
 
     @property
     def done(self) -> bool:
-        """Evaluates when to leave the loop.
-
-        Returns True if trainer.should_stop was set (e.g. by early stopping) or if the maximum number of steps or epochs
-        is reached.
-        """
+        """Evaluates when to leave the loop."""
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
-        stop_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs)
+        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.completed, self.max_epochs)
 
         should_stop = False
         if self.trainer.should_stop:
             # early stopping
-            met_min_epochs = self.current_epoch >= self.min_epochs if self.min_epochs else True
+            met_min_epochs = self.epoch_progress.current.completed >= self.min_epochs if self.min_epochs else True
             met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
             if met_min_epochs and met_min_steps:
                 should_stop = True
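The shortened docstring drops the prose, but the predicate is unchanged: stop when max steps or max epochs is reached, or when `trainer.should_stop` is set and the configured minimums are met. The decision isolated as a pure function (a sketch with hypothetical names; `-1`/`None` disable a limit, mirroring `_is_max_limit_reached`):

```python
from typing import Optional

def fit_loop_done(
    global_step: int,
    completed_epochs: int,
    max_steps: int,             # -1 disables the limit
    max_epochs: Optional[int],  # None or -1 disables the limit
    should_stop: bool,
    min_steps: Optional[int],
    min_epochs: Optional[int],
) -> bool:
    stop_steps = max_steps != -1 and global_step >= max_steps
    stop_epochs = max_epochs is not None and max_epochs != -1 and completed_epochs >= max_epochs
    met_min_epochs = completed_epochs >= min_epochs if min_epochs else True
    met_min_steps = global_step >= min_steps if min_steps else True
    return stop_steps or stop_epochs or (should_stop and met_min_epochs and met_min_steps)

# early stopping requested, but min_epochs=3 not yet met -> keep training
assert not fit_loop_done(4, 2, -1, 10, True, None, 3)
```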
@@ -219,7 +205,7 @@ class FitLoop(Loop[None]):
             getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
         ):
             # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
+            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.completed)
 
         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
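The loop forwards the completed-epoch count to `DistributedSampler.set_epoch`, which reseeds the shuffle each epoch. Plain PyTorch usage of the same API, independent of Lightning (explicit `num_replicas`/`rank` so no process group is needed):

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, sampler=sampler, batch_size=2)

for epoch in range(2):
    sampler.set_epoch(epoch)  # a different epoch yields a different shuffle order
    for batch in loader:
        pass
```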
@@ -307,7 +293,7 @@ class FitLoop(Loop[None]):
         # Lightning today does not increment the current epoch at the last epoch run in Trainer.fit
         # To simulate that current behavior, we decrement here.
         # TODO: must be fixed by https://github.com/PyTorchLightning/pytorch-lightning/issues/5007
-        self.current_epoch = max(self.current_epoch - 1, 0)
+        self.epoch_progress.current.completed = max(self.epoch_progress.current.completed - 1, 0)
 
         # hook
         self.trainer._call_callback_hooks("on_train_end")
@@ -218,7 +218,9 @@ class CheckpointConnector:
             return
 
         self.trainer.fit_loop.global_step = self._loaded_checkpoint["global_step"]
-        self.trainer.fit_loop.current_epoch = self._loaded_checkpoint["epoch"]
+        # set the `current_epoch` value for old checkpoints without the progress tracking state.
+        # it will be overwritten by the loop's state if it was also saved
+        self.trainer.fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
 
         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")
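The three replacement lines seed the progress tracking from the legacy top-level `"epoch"` key; if the checkpoint also carries a `"loops"` state, that restore happens later and wins. Roughly, the two checkpoint layouts involved (a hedged sketch; real checkpoints contain many more keys, and the nesting shown is illustrative):

```python
# Old-style checkpoint: only scalar counters at the top level.
legacy_ckpt = {"epoch": 3, "global_step": 1200}

# New-style checkpoint: also carries the loops' state, including the fit
# loop's epoch_progress, which takes precedence over the legacy key.
new_ckpt = {
    "epoch": 3,
    "global_step": 1200,
    "loops": {"fit_loop": {"epoch_progress": {"current": {"completed": 3}}}},
}

def restore_epoch(ckpt: dict) -> int:
    # seed from the legacy key, then let the loop state (if saved) overwrite it
    epoch = ckpt["epoch"]
    loops = ckpt.get("loops")
    if loops:
        epoch = loops["fit_loop"]["epoch_progress"]["current"]["completed"]
    return epoch

assert restore_epoch(legacy_ckpt) == 3
assert restore_epoch(new_ckpt) == 3
```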
@@ -336,7 +336,6 @@ class Trainer(
                 To enable infinite training, set ``max_epochs = -1``.
 
             min_epochs: Force training for at least these many epochs. Disabled by default (None).
-                If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.
 
             max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
                 and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
@@ -2349,7 +2348,8 @@ class Trainer(
 
     @property
     def current_epoch(self) -> int:
-        return self.fit_loop.current_epoch
+        """The current epoch, updated after the epoch end hooks are run."""
+        return self.fit_loop.epoch_progress.current.completed
 
     @property
     def max_epochs(self) -> int:
@@ -60,10 +60,10 @@ def scale_batch_size(
 
     # Save initial model, that is loaded after batch size is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
-    trainer.fit_loop.current_epoch -= 1
+    trainer.fit_loop.epoch_progress.current.completed -= 1
     trainer.fit_loop.global_step -= 1
     trainer.save_checkpoint(ckpt_path)
-    trainer.fit_loop.current_epoch += 1
+    trainer.fit_loop.epoch_progress.current.completed += 1
     trainer.fit_loop.global_step += 1
     params = __scale_batch_dump_params(trainer)
 
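The tuner rolls the counters back one step/epoch so the checkpoint it writes restores the trainer to its pre-tuning position, then rolls them forward again; the `lr_find` hunk below repeats the same dance. Wrapped up as a context manager (a hypothetical helper, not part of the Lightning API):

```python
from contextlib import contextmanager

@contextmanager
def rolled_back_counters(fit_loop):
    """Temporarily decrement the epoch/step counters, e.g. while saving a checkpoint."""
    fit_loop.epoch_progress.current.completed -= 1
    fit_loop.global_step -= 1
    try:
        yield
    finally:
        fit_loop.epoch_progress.current.completed += 1
        fit_loop.global_step += 1

# usage sketch:
#     with rolled_back_counters(trainer.fit_loop):
#         trainer.save_checkpoint(ckpt_path)
```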
@@ -110,7 +110,6 @@ def __scale_batch_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
 def __scale_batch_reset_params(trainer: "pl.Trainer", steps_per_trial: int) -> None:
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
-    trainer.fit_loop.current_epoch = 0
     trainer.fit_loop.max_steps = steps_per_trial  # take few steps
     trainer.logger = DummyLogger() if trainer.logger is not None else None
     trainer.callbacks = []  # not needed before full run
@@ -204,10 +204,10 @@ def lr_find(
 
     # Save initial model, that is loaded after learning rate is found
     ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
-    trainer.fit_loop.current_epoch -= 1
+    trainer.fit_loop.epoch_progress.current.completed -= 1
    trainer.fit_loop.global_step -= 1
     trainer.save_checkpoint(ckpt_path)
-    trainer.fit_loop.current_epoch += 1
+    trainer.fit_loop.epoch_progress.current.completed += 1
     trainer.fit_loop.global_step += 1
     params = __lr_finder_dump_params(trainer)
 
@@ -162,9 +162,9 @@ def test_model_checkpoint_score_and_ckpt(
     for epoch in range(max_epochs):
         score = model.scores[epoch]
         expected_score = getattr(model, f"{monitor}s")[epoch].mean().item()
-        expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
         assert math.isclose(score, expected_score, rel_tol=1e-4)
 
+        expected_filename = f"{monitor}={score:.4f}-epoch={epoch}.ckpt"
         chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
         assert chk["epoch"] == epoch + 1
         assert chk["global_step"] == limit_train_batches * (epoch + 1)
@@ -462,7 +462,6 @@ class ModelCheckpointExtensionTest(ModelCheckpoint):
 
 def test_model_checkpoint_file_extension(tmpdir):
     """Test ModelCheckpoint with different file extension."""
-
     model = LogInTwoMethods()
     model_checkpoint = ModelCheckpointExtensionTest(
         monitor="early_stop_on", dirpath=tmpdir, save_top_k=1, save_last=True
@@ -613,7 +612,7 @@ def test_model_checkpoint_every_n_epochs(tmpdir, every_n_epochs):
     )
     trainer.fit(model)
 
-    # check that the correct ckpts were created
+    # check that the correct ckpts were created, the modulo condition is checked in `ModelCheckpoint`
     expected = [f"epoch={e}.ckpt" for e in range(epochs) if not (e + 1) % every_n_epochs] if every_n_epochs > 0 else []
     assert set(os.listdir(tmpdir)) == set(expected)
 
@@ -967,15 +966,13 @@ def test_checkpoint_repeated_strategy_extended(tmpdir):
     assert_checkpoint_content(ckpt_dir)
 
     # load from checkpoint
-    trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)]
     trainer = pl.Trainer(**trainer_config)
     assert_trainer_init(trainer)
 
     model = ExtendedBoringModel()
 
     trainer.test(model)
-    assert trainer.global_step == 0
-    assert trainer.current_epoch == 0
+    assert_trainer_init(trainer)
 
     trainer.fit(model, ckpt_path=chk)
     assert trainer.global_step == epochs * limit_train_batches
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from copy import deepcopy
 
 import torch
 
@@ -53,8 +52,6 @@ def test_finetuning_with_ckpt_path(tmpdir):
     assert os.listdir(tmpdir) == ["epoch=00.ckpt"]
 
     best_model_paths = [checkpoint_callback.best_model_path]
-    results = []
-
     for idx in range(3, 6):
         # load from checkpoint
         trainer = pl.Trainer(
@@ -67,7 +64,6 @@ def test_finetuning_with_ckpt_path(tmpdir):
         )
         trainer.fit(model, ckpt_path=best_model_paths[-1])
         trainer.test()
-        results.append(deepcopy(trainer.callback_metrics))
         best_model_paths.append(trainer.checkpoint_callback.best_model_path)
 
     for idx, best_model_path in enumerate(best_model_paths):
@@ -33,7 +33,6 @@ from pytorch_lightning.trainer.connectors.logger_connector.result import (
     _ResultMetric,
     _Sync,
 )
-from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -373,13 +372,9 @@ class DummyMeanMetric(Metric):
 
 
 def result_collection_reload(accelerator="auto", devices=1, **kwargs):
-
     """This test is going to validate _ResultCollection is properly being reload and final accumulation with Fault
     Tolerant Training is correct."""
-
-    if not _fault_tolerant_training():
-        pytest.skip("Fault tolerant not available")
 
     class CustomException(Exception):
         pass
@@ -222,6 +222,75 @@ def test_trainer_properties_restore_ckpt_path(tmpdir):
     trainer_fn(model, datamodule=dm, ckpt_path=resume_ckpt)
 
 
+def test_correct_step_and_epoch(tmpdir):
+    model = BoringModel()
+    first_max_epochs = 2
+    train_batches = 2
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=first_max_epochs, limit_train_batches=train_batches, limit_val_batches=0
+    )
+    assert trainer.current_epoch == 0
+    assert trainer.global_step == 0
+
+    trainer.fit(model)
+    # TODO(@carmocca): should not need `-1`
+    assert trainer.current_epoch == first_max_epochs - 1
+    assert trainer.global_step == first_max_epochs * train_batches
+
+    # save checkpoint after loop ends, training end called, epoch count increased
+    ckpt_path = str(tmpdir / "model.ckpt")
+    trainer.save_checkpoint(ckpt_path)
+
+    ckpt = torch.load(ckpt_path)
+    assert ckpt["epoch"] == first_max_epochs
+    # TODO(@carmocca): should not need `+1`
+    assert ckpt["global_step"] == first_max_epochs * train_batches + 1
+
+    max_epochs = first_max_epochs + 2
+    trainer = Trainer(
+        default_root_dir=tmpdir, max_epochs=max_epochs, limit_train_batches=train_batches, limit_val_batches=0
+    )
+    # the ckpt state is not loaded at this point
+    assert trainer.current_epoch == 0
+    assert trainer.global_step == 0
+
+    class TestModel(BoringModel):
+        def on_pretrain_routine_end(self) -> None:
+            assert self.trainer.current_epoch == first_max_epochs
+            # TODO(@carmocca): should not need `+1`
+            assert self.trainer.global_step == first_max_epochs * train_batches + 1
+
+    trainer.fit(TestModel(), ckpt_path=ckpt_path)
+    # TODO(@carmocca): should not need `-1`
+    assert trainer.current_epoch == max_epochs - 1
+    # TODO(@carmocca): should not need `+1`
+    assert trainer.global_step == max_epochs * train_batches + 1
+
+
+def test_fit_twice(tmpdir):
+    epochs = []
+
+    class TestModel(BoringModel):
+        def on_train_epoch_end(self, *_):
+            epochs.append(self.current_epoch)
+
+    trainer = Trainer(
+        max_epochs=2,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        default_root_dir=tmpdir,
+        logger=False,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+    trainer.fit(TestModel())
+    trainer.fit_loop.max_epochs = 4
+    trainer.fit(TestModel())
+    # TODO(@carmocca): 1 should not be duplicated
+    assert epochs == [0, 1, 1, 2, 3]
+
+
 def test_try_resume_from_non_existing_checkpoint(tmpdir):
     """Test that trying to resume from non-existing `ckpt_path` fails with an error."""
     model = BoringModel()
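The TODOs in `test_correct_step_and_epoch` pin down the current off-by-one behaviour rather than hide it: after `fit`, `trainer.current_epoch` reports `max_epochs - 1` because the fit loop decrements on exit (see the `@@ -307,7 +293,7 @@` hunk above), while a checkpoint saved afterwards stores `epoch == max_epochs` and one extra global step. The arithmetic for the values used in the test above:

```python
first_max_epochs = 2
train_batches = 2

# in-memory state right after trainer.fit(model)
assert first_max_epochs - 1 == 1                  # trainer.current_epoch
assert first_max_epochs * train_batches == 4      # trainer.global_step

# values written by trainer.save_checkpoint(...)
assert first_max_epochs == 2                      # ckpt["epoch"]
assert first_max_epochs * train_batches + 1 == 5  # ckpt["global_step"]
```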
@@ -331,7 +331,7 @@ def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files)
 
     # emulate callback's calls during the training
     for i, loss in enumerate(losses):
-        trainer.fit_loop.current_epoch = i
+        trainer.fit_loop.epoch_progress.current.completed = i  # sets `trainer.current_epoch`
         trainer.fit_loop.global_step = i
         trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
         checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)