# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import Mock

import pytest
import torch
from torch.optim import Adam, SGD

from pytorch_lightning import Trainer
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel


def test_closure_result_deepcopy():
    closure_loss = torch.tensor(123.45)
    result = ClosureResult(closure_loss)

    assert closure_loss.data_ptr() == result.closure_loss.data_ptr()
    # the `loss` is cloned so the storage is different
    assert closure_loss.data_ptr() != result.loss.data_ptr()

    copy = result.asdict()
    assert result.loss == copy["loss"]
    assert copy.keys() == {"loss"}

    # no copy
    assert id(result.loss) == id(copy["loss"])
    assert result.loss.data_ptr() == copy["loss"].data_ptr()


def test_closure_result_apply_accumulation():
    closure_loss = torch.tensor(25.0)
    result = ClosureResult.from_training_step_output(closure_loss, 5)
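    # `from_training_step_output` normalizes the closure loss by the accumulation factor: 25.0 / 5 == 5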
    assert result.loss == 5


@pytest.mark.parametrize(
    "case", [(5.0, "must return a Tensor, a dict, or None"), ({"a": 5}, "the 'loss' key needs to be present")]
)
def test_warning_invalid_trainstep_output(tmpdir, case):
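    # each case pairs an invalid `training_step` return with the error it should trigger: a bare float is not a
    # valid return type, and a dict return must contain a "loss" key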
    output, match = case

    class InvalidTrainStepModel(BoringModel):
        def training_step(self, batch, batch_idx):
            return output

    model = InvalidTrainStepModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)

    with pytest.raises(MisconfigurationException, match=match):
        trainer.fit(model)


@pytest.mark.parametrize(
    "frequencies,expected",
    [
        (
            (3, 1),
            [
                (0, "SGD"),
                (0, "SGD"),
                (0, "SGD"),
                (1, "Adam"),
                (0, "SGD"),
                (0, "SGD"),
                (0, "SGD"),
                (1, "Adam"),
                (0, "SGD"),
                (0, "SGD"),
            ],
        ),
        (
            (1, 2),
            [
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
                (1, "Adam"),
                (1, "Adam"),
                (0, "SGD"),
            ],
        ),
    ],
)
def test_optimizer_frequencies(tmpdir, frequencies, expected):
    """Test that the optimizer loop runs optimization for the correct optimizer and optimizer idx when different
    frequencies are requested."""
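    # e.g. with frequencies (3, 1), the expected sequence is three SGD steps followed by one Adam step, repeating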
    class CurrentModel(BoringModel):
        def training_step(self, batch, batch_idx, optimizer_idx):
            return super().training_step(batch, batch_idx)

        def configure_optimizers(self):
            opt0 = SGD(self.parameters(), lr=0.1)
            opt1 = Adam(self.parameters(), lr=0.1)
            return {"optimizer": opt0, "frequency": frequencies[0]}, {"optimizer": opt1, "frequency": frequencies[1]}

    model = CurrentModel()
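    # wrap `optimizer_step` in a Mock so every call is recorded while the original method still runs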
    model.optimizer_step = Mock(wraps=model.optimizer_step)
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=10,
        progress_bar_refresh_rate=0,
    )
    trainer.fit(model)
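    # `optimizer_step` receives the optimizer and the optimizer index as positional arguments,
    # so args[2] / args[3] recover them from each recorded call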
    positional_args = [c[0] for c in model.optimizer_step.call_args_list]
    pl_optimizer_sequence = [args[2] for args in positional_args]
    opt_idx_sequence = [args[3] for args in positional_args]
    assert all(isinstance(opt, LightningOptimizer) for opt in pl_optimizer_sequence)
    optimizer_sequence = [opt._optimizer.__class__.__name__ for opt in pl_optimizer_sequence]
    assert list(zip(opt_idx_sequence, optimizer_sequence)) == expected