# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
from torch.optim import Adam, SGD

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel


def test_closure_result_deepcopy():
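    # `ClosureResult` keeps a reference to the raw `closure_loss` while exposing a cloned tensor as
    # `loss`, so the two attributes must not share storage (compared via `data_ptr()` below).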
    closure_loss = torch.tensor(123.45)
    result = ClosureResult(closure_loss)

    assert closure_loss.data_ptr() == result.closure_loss.data_ptr()
    # the `loss` is cloned so the storage is different
    assert closure_loss.data_ptr() != result.loss.data_ptr()

    copy = result.asdict()
    assert result.loss == copy["loss"]
    assert copy.keys() == {"loss"}

    # no copy
    assert id(result.loss) == id(copy["loss"])
    assert result.loss.data_ptr() == copy["loss"].data_ptr()


def test_closure_result_apply_accumulation():
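    # the second argument of `from_training_step_output` is the gradient-accumulation normalization
    # factor; the stored `loss` is the raw loss divided by it, hence 25.0 / 5 == 5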
    closure_loss = torch.tensor(25.0)
    result = ClosureResult.from_training_step_output(closure_loss, 5)
    assert result.loss == 5


@pytest.mark.parametrize(
    "case", [(5.0, "must return a Tensor, a dict, or None"), ({"a": 5}, "the 'loss' key needs to be present")]
)
def test_warning_invalid_trainstep_output(tmpdir, case):
    output, match = case

    class InvalidTrainStepModel(BoringModel):
        def training_step(self, batch, batch_idx):
            return output

    model = InvalidTrainStepModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)

    with pytest.raises(MisconfigurationException, match=match):
        trainer.fit(model)


@pytest.mark.parametrize(
    "frequencies,expected",
    [
        (
            (3, 1),
            [
                (0, "SGD"),
                (0, "SGD"),
                (0, "SGD"),
                (1, "Adam"),
                (0, "SGD"),
                (0, "SGD"),
                (0, "SGD"),
                (1, "Adam"),
                (0, "SGD"),
                (0, "SGD"),
            ],
        ),
        (
            (1, 2),
            [
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
            ],
        ),
    ],
)
def test_optimizer_frequencies(tmpdir, frequencies, expected):
    """Test that the optimizer loop runs optimization for the correct optimizer and optimizer idx when different
    frequencies are requested."""
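    # a `frequency` of n means an optimizer is used for n consecutive batches before the loop switches to the
    # next one, e.g. (3, 1) runs SGD for three batches, then Adam for one, and repeats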

    class CurrentModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            opt0 = SGD(self.parameters(), lr=0.1)
            opt1 = Adam(self.parameters(), lr=0.1)
            return {"optimizer": opt0, "frequency": frequencies[0]}, {"optimizer": opt1, "frequency": frequencies[1]}

    model = CurrentModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=10,
        enable_progress_bar=False,
    )
    trainer.fit(model)
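
    # `optimizer_step` is called positionally as (epoch, batch_idx, optimizer, optimizer_idx, ...),
    # so index 2 is the wrapped optimizer instance and index 3 is the optimizer index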
    positional_args = [c[0] for c in model.optimizer_step.call_args_list]
    pl_optimizer_sequence = [args[2] for args in positional_args]
    opt_idx_sequence = [args[3] for args in positional_args]
    assert all(isinstance(opt, LightningOptimizer) for opt in pl_optimizer_sequence)
    optimizer_sequence = [opt._optimizer.__class__.__name__ for opt in pl_optimizer_sequence]
    assert list(zip(opt_idx_sequence, optimizer_sequence)) == expected


class CustomException(Exception):
    pass
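

# setting `PL_FAULT_TOLERANT_TRAINING=1` enables fault-tolerant training: loop progress is tracked and a
# ".pl_auto_save.ckpt" checkpoint is written when training is interrupted by an exception, which the test
# below resumes from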
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize("stop_epoch", (0, 1))
@pytest.mark.parametrize("stop_batch", (0, 1, 2))
@pytest.mark.parametrize("n_optimizers,stop_optimizer", [(2, 0), (2, 1), (3, 2)])
def test_loop_restart_progress_multiple_optimizers(tmpdir, n_optimizers, stop_optimizer, stop_epoch, stop_batch):
    """Test that Lightning can resume from a point where a training_step failed while in the middle of processing
    several optimizer steps for one batch.

    The test asserts that we end up with the same trained weights as if no failure occurred.
    """

    n_batches = 3
    n_epochs = 2

    def _assert_optimizer_sequence(method_mock, expected):
        positional_args = [c[0] for c in method_mock.call_args_list]
        sequence = [arg[3] for arg in positional_args]
        assert sequence == expected
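
    # number of optimizer_step calls that ran to completion before the simulated failure: all steps from
    # earlier epochs, plus all steps from earlier batches in the failing epoch, plus the optimizers that
    # finished within the failing batch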
    num_optimizers_incomplete = stop_epoch * n_batches * n_optimizers + stop_batch * n_optimizers + stop_optimizer

    opt_idx_sequence_complete = list(range(n_optimizers)) * n_epochs * n_batches  # [0, 1, 2, 0, 1, 2, 0, 1, ...]
    # +1 because we fail inside the closure inside optimizer_step()
    opt_idx_sequence_incomplete = opt_idx_sequence_complete[: (num_optimizers_incomplete + 1)]
    opt_idx_sequence_resumed = opt_idx_sequence_complete[num_optimizers_incomplete:]

    class MultipleOptimizerModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            if (
                fail
                and self.current_epoch == stop_epoch
                and batch_idx == stop_batch
                and optimizer_idx == stop_optimizer
            ):
                raise CustomException
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            return [torch.optim.SGD(self.parameters(), lr=0.1) for _ in range(n_optimizers)]

    # run without a failure, collect weights
    fail = False
    seed_everything(0)
    model = MultipleOptimizerModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=n_epochs,
        limit_train_batches=n_batches,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
    trainer.fit(model)
    weights_complete = model.parameters()
    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_complete)

    # simulate a failure
    fail = True
    seed_everything(0)
    model = MultipleOptimizerModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=n_epochs,
        limit_train_batches=n_batches,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
    with pytest.raises(CustomException):
        trainer.fit(model)

    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_incomplete)

    # resume from failure and collect weights
    fail = False
    seed_everything(0)
    model = MultipleOptimizerModel()
    model.training_epoch_end = None
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=n_epochs,
        limit_train_batches=n_batches,
        limit_val_batches=0,
        num_sanity_val_steps=0,
        logger=False,
        enable_checkpointing=False,
    )
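    # resume from the checkpoint that fault-tolerant training auto-saved when the previous run failed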
    trainer.fit(model, ckpt_path=str(tmpdir / ".pl_auto_save.ckpt"))
    weights_resumed = model.parameters()

    # check that the final weights of a resumed run match the weights of a run that never failed
    for w0, w1 in zip(weights_complete, weights_resumed):
        assert torch.allclose(w0, w1)

    _assert_optimizer_sequence(model.optimizer_step, opt_idx_sequence_resumed)