# lightning/tests/loops/optimization/test_optimizer_loop.py
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
from torch.optim import Adam, SGD

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel


def test_closure_result_deepcopy():
    closure_loss = torch.tensor(123.45)
    result = ClosureResult(closure_loss)

    assert closure_loss.data_ptr() == result.closure_loss.data_ptr()
    # the `loss` is cloned so the storage is different
    assert closure_loss.data_ptr() != result.loss.data_ptr()

    copy = result.asdict()
    assert result.loss == copy["loss"]
    assert copy.keys() == {"loss"}

    # no copy
    assert id(result.loss) == id(copy["loss"])
    assert result.loss.data_ptr() == copy["loss"].data_ptr()


def test_closure_result_apply_accumulation():
    closure_loss = torch.tensor(25.0)
    result = ClosureResult.from_training_step_output(closure_loss, 5)
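    # the loss is divided by the `normalize` factor (the gradient accumulation count): 25.0 / 5 == 5.0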
    assert result.loss == 5


@pytest.mark.parametrize(
    "case", [(5.0, "must return a Tensor, a dict, or None"), ({"a": 5}, "the 'loss' key needs to be present")]
)
def test_warning_invalid_trainstep_output(tmpdir, case):
    output, match = case

    class InvalidTrainStepModel(BoringModel):
        def training_step(self, batch, batch_idx):
            return output

    model = InvalidTrainStepModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)

    with pytest.raises(MisconfigurationException, match=match):
        trainer.fit(model)


@pytest.mark.parametrize(
    "frequencies,expected",
    [
        (
            (3, 1),
            [
                (0, "SGD"),
                (0, "SGD"),
                (0, "SGD"),
                (1, "Adam"),
                (0, "SGD"),
                (0, "SGD"),
                (0, "SGD"),
                (1, "Adam"),
                (0, "SGD"),
                (0, "SGD"),
            ],
        ),
        (
            (1, 2),
            [
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
            ],
        ),
    ],
)
def test_optimizer_frequencies(tmpdir, frequencies, expected):
    """Test that the optimizer loop runs optimization for the correct optimizer and optimizer idx when different
    frequencies are requested."""
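    # A "frequency" of f in an optimizer dict means: run that optimizer for f consecutive batches
    # before switching to the next one, e.g. frequencies (3, 1) yield SGD, SGD, SGD, Adam, SGD, ...
    # (see the parametrization above).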

    class CurrentModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            opt0 = SGD(self.parameters(), lr=0.1)
            opt1 = Adam(self.parameters(), lr=0.1)
            return {"optimizer": opt0, "frequency": frequencies[0]}, {"optimizer": opt1, "frequency": frequencies[1]}

    model = CurrentModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=10,
        enable_progress_bar=False,
    )
    trainer.fit(model)
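    # `optimizer_step` is called with (epoch, batch_idx, optimizer, optimizer_idx, ...), so positional
    # arguments 2 and 3 are the wrapped optimizer and its index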
    positional_args = [c[0] for c in model.optimizer_step.call_args_list]
    pl_optimizer_sequence = [args[2] for args in positional_args]
    opt_idx_sequence = [args[3] for args in positional_args]
    assert all(isinstance(opt, LightningOptimizer) for opt in pl_optimizer_sequence)
    optimizer_sequence = [opt._optimizer.__class__.__name__ for opt in pl_optimizer_sequence]
    assert list(zip(opt_idx_sequence, optimizer_sequence)) == expected


class CustomException(Exception):
    pass


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("stop_epoch", (0, 1))
@pytest.mark.parametrize("stop_batch", (0, 1, 2))
@pytest.mark.parametrize("n_optimizers,stop_optimizer", [(2, 0), (2, 1), (3, 2)])
def test_loop_restart_progress_multiple_optimizers(tmpdir, n_optimizers, stop_optimizer, stop_epoch, stop_batch):
    """Test that Lightning can resume from a point where a training_step failed while in the middle of processing
    several optimizer steps for one batch.

    The test asserts that we end up with the same trained weights as if no failure occurred.
    """
    n_batches = 3
    n_epochs = 2

    def _assert_optimizer_sequence(method_mock, expected):
        positional_args = [c[0] for c in method_mock.call_args_list]
        # the optimizer_idx is the 4th positional argument passed to `optimizer_step`
        sequence = [arg[3] for arg in positional_args]
        assert sequence == expected

    num_optimizers_incomplete = stop_epoch * n_batches * n_optimizers + stop_batch * n_optimizers + stop_optimizer

    opt_idx_sequence_complete = list(range(n_optimizers)) * n_epochs * n_batches  # [0, 1, 2, 0, 1, 2, 0, 1, ...]
    # +1 because we fail inside the closure inside optimizer_step()
    opt_idx_sequence_incomplete = opt_idx_sequence_complete[: (num_optimizers_incomplete + 1)]
    opt_idx_sequence_resumed = opt_idx_sequence_complete[num_optimizers_incomplete:]

    class MultipleOptimizerModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            if (
                fail
                and self.current_epoch == stop_epoch
                and batch_idx == stop_batch
                and optimizer_idx == stop_optimizer
            ):
                raise CustomException
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            return [torch.optim.SGD(self.parameters(), lr=0.1) for _ in range(n_optimizers)]

    # run without a failure, collect weights
    fail = False
    seed_everything(0)
    model = MultipleOptimizerModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=n_epochs,
        limit_train_batches=n_batches,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model)
    weights_complete = model.parameters()
    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_complete)

    # simulate a failure
    fail = True
    seed_everything(0)
    model = MultipleOptimizerModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=n_epochs,
        limit_train_batches=n_batches,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
    with pytest.raises(CustomException):
        trainer.fit(model)

    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_incomplete)

    # resume from failure and collect weights
    fail = False
    seed_everything(0)
    model = MultipleOptimizerModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=n_epochs,
        limit_train_batches=n_batches,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model, ckpt_path=str(tmpdir / ".pl_auto_save.ckpt"))
    weights_resumed = model.parameters()

    # check that the final weights of a resumed run match the weights of a run that never failed
    for w0, w1 in zip(weights_complete, weights_resumed):
        assert torch.allclose(w0, w1)

    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_resumed)
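

# To run only this module locally (assuming the repository's usual test setup, i.e. invoked from the
# repo root so that `tests.helpers` is importable), something like the following should work:
#   python -m pytest tests/loops/optimization/test_optimizer_loop.py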