# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
from torch.optim import Adam, SGD

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel


def test_closure_result_deepcopy():
    closure_loss = torch.tensor(123.45)
    result = ClosureResult(closure_loss)
    assert closure_loss.data_ptr() == result.closure_loss.data_ptr()
    # the `loss` is cloned so the storage is different
    assert closure_loss.data_ptr() != result.loss.data_ptr()

    copy = result.asdict()
    assert result.loss == copy["loss"]
    assert copy.keys() == {"loss"}

    # no copy
    assert id(result.loss) == id(copy["loss"])
    assert result.loss.data_ptr() == copy["loss"].data_ptr()


def test_closure_result_apply_accumulation():
    closure_loss = torch.tensor(25.0)
    result = ClosureResult.from_training_step_output(closure_loss, 5)
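    # the second argument is presumably the gradient-accumulation normalization:
    # the stored `loss` is the closure loss divided by it, so 25.0 / 5 == 5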
    assert result.loss == 5


@pytest.mark.parametrize(
    "case", [(5.0, "must return a Tensor, a dict, or None"), ({"a": 5}, "the 'loss' key needs to be present")]
)
def test_warning_invalid_trainstep_output(tmpdir, case):
    output, match = case

    class InvalidTrainStepModel(BoringModel):
        def training_step(self, batch, batch_idx):
            return output

    model = InvalidTrainStepModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
    with pytest.raises(MisconfigurationException, match=match):
        trainer.fit(model)
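

# With two optimizers and a `frequency` set, the optimizer loop steps them in a
# round-robin pattern: frequencies (3, 1) run SGD for three batches, then Adam for
# one, and repeat; (1, 2) alternates one SGD step with two Adam steps. The expected
# (optimizer_idx, optimizer_name) pairs below spell this pattern out over 10 batches.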
@pytest.mark.parametrize(
    "frequencies,expected",
    [
        (
            (3, 1),
            [
                (0, "SGD"),
                (0, "SGD"),
                (0, "SGD"),
                (1, "Adam"),
                (0, "SGD"),
                (0, "SGD"),
                (0, "SGD"),
                (1, "Adam"),
                (0, "SGD"),
                (0, "SGD"),
            ],
        ),
        (
            (1, 2),
            [
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
            ],
        ),
    ],
)
def test_optimizer_frequencies(tmpdir, frequencies, expected):
    """Test that the optimizer loop runs optimization for the correct optimizer and optimizer idx when different
    frequencies are requested."""

    class CurrentModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            opt0 = SGD(self.parameters(), lr=0.1)
            opt1 = Adam(self.parameters(), lr=0.1)
            return {"optimizer": opt0, "frequency": frequencies[0]}, {"optimizer": opt1, "frequency": frequencies[1]}

    model = CurrentModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=10,
        enable_progress_bar=False,
    )
    trainer.fit(model)
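
    # `optimizer_step` is called positionally; in this Lightning version, index 2 of
    # the positional args is the (wrapped) optimizer and index 3 is the optimizer_idx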
    positional_args = [c[0] for c in model.optimizer_step.call_args_list]
    pl_optimizer_sequence = [args[2] for args in positional_args]
    opt_idx_sequence = [args[3] for args in positional_args]
    assert all(isinstance(opt, LightningOptimizer) for opt in pl_optimizer_sequence)
    optimizer_sequence = [opt._optimizer.__class__.__name__ for opt in pl_optimizer_sequence]
    assert list(zip(opt_idx_sequence, optimizer_sequence)) == expected


class CustomException(Exception):
    pass


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("stop_epoch", (0, 1))
@pytest.mark.parametrize("stop_batch", (0, 1, 2))
@pytest.mark.parametrize("n_optimizers,stop_optimizer", [(2, 0), (2, 1), (3, 2)])
def test_loop_restart_progress_multiple_optimizers(tmpdir, n_optimizers, stop_optimizer, stop_epoch, stop_batch):
    """Test that Lightning can resume from a point where a training_step failed while in the middle of processing
    several optimizer steps for one batch.

    The test asserts that we end up with the same trained weights as if no failure occurred.
    """
    n_batches = 3
    n_epochs = 2

    def _assert_optimizer_sequence(method_mock, expected):
        positional_args = [c[0] for c in method_mock.call_args_list]
        sequence = [arg[3] for arg in positional_args]
        assert sequence == expected
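
    # flattened index of the optimizer step that fails; e.g. with n_optimizers=2,
    # stop_epoch=1, stop_batch=2, stop_optimizer=1:
    # 1 * 3 * 2 + 2 * 2 + 1 == 11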
    num_optimizers_incomplete = stop_epoch * n_batches * n_optimizers + stop_batch * n_optimizers + stop_optimizer

    opt_idx_sequence_complete = list(range(n_optimizers)) * n_epochs * n_batches  # [0, 1, 2, 0, 1, 2, 0, 1, ...]
    # +1 because we fail inside the closure inside optimizer_step()
    opt_idx_sequence_incomplete = opt_idx_sequence_complete[: (num_optimizers_incomplete + 1)]
    opt_idx_sequence_resumed = opt_idx_sequence_complete[num_optimizers_incomplete:]
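    # the failed step appears in both sequences: its `optimizer_step` call happened
    # but the closure raised, so resuming re-runs that same step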

    class MultipleOptimizerModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            if (
                fail
                and self.current_epoch == stop_epoch
                and batch_idx == stop_batch
                and optimizer_idx == stop_optimizer
            ):
                raise CustomException
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            return [torch.optim.SGD(self.parameters(), lr=0.1) for _ in range(n_optimizers)]

    # run without a failure, collect weights
    fail = False
    seed_everything(0)
    model = MultipleOptimizerModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=n_epochs,
        limit_train_batches=n_batches,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model)
    weights_complete = model.parameters()
    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_complete)

    # simulate a failure
    fail = True
    seed_everything(0)
    model = MultipleOptimizerModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=n_epochs,
        limit_train_batches=n_batches,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
    with pytest.raises(CustomException):
        trainer.fit(model)
    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_incomplete)

    # resume from failure and collect weights
    fail = False
    seed_everything(0)
    model = MultipleOptimizerModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=n_epochs,
        limit_train_batches=n_batches,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
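    # with PL_FAULT_TOLERANT_TRAINING=1, the failed run auto-saved its state to
    # `.pl_auto_save.ckpt`, which we resume from here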
    trainer.fit(model, ckpt_path=str(tmpdir / ".pl_auto_save.ckpt"))
    weights_resumed = model.parameters()

    # check that the final weights of a resumed run match the weights of a run that never failed
    for w0, w1 in zip(weights_complete, weights_resumed):
        assert torch.allclose(w0, w1)

    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_resumed)