lightning/tests/loops/epoch/test_training_epoch_loop.py

144 lines
5.1 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from pytorch_lightning.loops import TrainingEpochLoop
_out00 = {"loss": 0.0}
_out01 = {"loss": 0.1}
_out02 = {"loss": 0.2}
_out03 = {"loss": 0.3}
_out10 = {"loss": 1.0}
_out11 = {"loss": 1.1}
_out12 = {"loss": 1.2}
_out13 = {"loss": 1.3}
@pytest.mark.parametrize(
"num_optimizers,batch_outputs,expected",
[
(1, [], []),
(1, [[]], []),
# 1 batch
(1, [[{0: _out00}]], [_out00]),
# 2 batches
(1, [[{0: _out00}], [{0: _out01}]], [_out00, _out01]),
# 1 batch, 2 optimizers
(2, [[{0: _out00, 1: _out01}]], [_out00, _out01]),
# 2 batches, 2 optimizers
(
2,
[[{0: _out00, 1: _out01}], [{0: _out10, 1: _out11}]],
[[_out00, _out10], [_out01, _out11]],
),
# 4 batches, 2 optimizers, different frequency
(
2,
[[{0: _out00}], [{1: _out10}], [{1: _out11}], [{0: _out01}]],
[[_out00, _out01], [_out10, _out11]],
),
# 1 batch, tbptt with 2 splits (uneven)
(1, [[{0: _out00}, {0: _out01}], [{0: _out03}]], [[_out00, _out01], [_out03]]),
# 3 batches, tbptt with 2 splits, 2 optimizers alternating
(
2,
[[{0: _out00}, {0: _out01}], [{1: _out10}, {1: _out11}], [{0: _out02}, {0: _out03}]],
[[[_out00, _out01], [], [_out02, _out03]], [[], [_out10, _out11], []]],
),
],
)
def test_prepare_outputs_training_epoch_end_automatic(num_optimizers, batch_outputs, expected):
"""Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
currently expects in the case of automatic optimization."""
prepared = TrainingEpochLoop._prepare_outputs_training_epoch_end(
batch_outputs,
automatic=True,
num_optimizers=num_optimizers,
)
assert prepared == expected
@pytest.mark.parametrize(
"batch_outputs,expected",
[
([], []),
([[]], []),
# 1 batch
([[_out00]], [_out00]),
# 2 batches
([[_out00], [_out01]], [_out00, _out01]),
# skipped outputs
([[_out00], [], [], [_out03]], [_out00, _out03]),
# tbptt with 2 splits, uneven, skipped output
([[_out00, _out01], [_out02, _out03], [], [_out10]], [[_out00, _out01], [_out02, _out03], [_out10]]),
],
)
def test_prepare_outputs_training_epoch_end_manual(batch_outputs, expected):
"""Test that the loop converts the nested lists of outputs to the format that the `training_epoch_end` hook
currently expects in the case of manual optimization."""
prepared = TrainingEpochLoop._prepare_outputs_training_epoch_end(
batch_outputs,
automatic=False,
num_optimizers=-1, # does not matter for manual optimization
)
assert prepared == expected
@pytest.mark.parametrize(
"num_optimizers,batch_end_outputs,expected",
[
(1, [], []),
(1, [[]], []),
# 1 optimizer
(1, [{0: _out00}], _out00),
# 2 optimizers
(2, [{0: _out00, 1: _out01}], [_out00, _out01]),
# tbptt with 2 splits
(1, [{0: _out00}, {0: _out01}], [_out00, _out01]),
# 2 optimizers, tbptt with 2 splits
(2, [{0: _out00, 1: _out01}, {0: _out10, 1: _out11}], [[_out00, _out10], [_out01, _out11]]),
],
)
def test_prepare_outputs_training_batch_end_automatic(num_optimizers, batch_end_outputs, expected):
"""Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
currently expects in the case of automatic optimization."""
prepared = TrainingEpochLoop._prepare_outputs_training_batch_end(
batch_end_outputs,
automatic=True,
num_optimizers=num_optimizers,
)
assert prepared == expected
@pytest.mark.parametrize(
"batch_end_outputs,expected",
[
([], []),
([[]], []),
# skipped outputs
([_out00, None, _out02], [_out00, _out02]),
# tbptt with 3 splits, skipped output
([_out00, _out01, None, _out03], [_out00, _out01, _out03]),
],
)
def test_prepare_outputs_training_batch_end_manual(batch_end_outputs, expected):
"""Test that the loop converts the nested lists of outputs to the format that the `on_train_batch_end` hook
currently expects in the case of manual optimization."""
prepared = TrainingEpochLoop._prepare_outputs_training_batch_end(
batch_end_outputs,
automatic=False,
num_optimizers=-1, # does not matter for manual optimization
)
assert prepared == expected