2021-05-19 21:02:20 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2021-07-05 10:55:41 +00:00
|
|
|
from copy import deepcopy
|
|
|
|
|
2021-05-22 01:09:08 +00:00
|
|
|
import pytest
|
2021-05-19 21:02:20 +00:00
|
|
|
|
2021-06-29 12:47:41 +00:00
|
|
|
from pytorch_lightning.trainer.progress import (
|
|
|
|
BatchProgress,
|
|
|
|
EpochLoopProgress,
|
|
|
|
EpochProgress,
|
|
|
|
FitLoopProgress,
|
|
|
|
OptimizerProgress,
|
|
|
|
Progress,
|
|
|
|
Tracker,
|
|
|
|
)
|
2021-05-19 21:02:20 +00:00
|
|
|
|
|
|
|
|
2021-05-22 01:09:08 +00:00
|
|
|
def test_progress_geattr_setattr():
    """Check attribute reads and writes on a ``Tracker`` that holds a ``None`` field."""
    tracker = Tracker(ready=10, completed=None)

    # a field initialised to `None` can still be read
    assert tracker.completed is None

    # reading an attribute that was never defined must fail
    with pytest.raises(AttributeError, match="object has no attribute 'non_existing_attr'"):
        tracker.non_existing_attr  # noqa

    # assigning a brand-new attribute is allowed
    tracker.non_existing_attr = 10

    # assigning to a field that was initialised to `None` is rejected
    with pytest.raises(AttributeError, match="'completed' attribute is meant to be unused"):
        tracker.completed = 10

    # in-place arithmetic hits the default python error; changing that would need
    # a `__getattribute__` override, but reading the `None` value must stay allowed
    with pytest.raises(TypeError, match="unsupported operand type"):
        tracker.completed += 10
|
2021-05-19 21:02:20 +00:00
|
|
|
|
|
|
|
|
2021-05-22 01:09:08 +00:00
|
|
|
def test_progress_reset():
    """After ``reset`` the tracker should compare equal to ``Tracker(completed=None)``."""
    tracker = Tracker(ready=1, started=2, completed=None)
    tracker.reset()
    expected = Tracker(completed=None)
    assert tracker == expected
|
2021-05-19 21:02:20 +00:00
|
|
|
|
|
|
|
|
2021-05-22 01:09:08 +00:00
|
|
|
def test_progress_repr():
    """Fields set to ``None`` are expected to be left out of the ``repr`` output."""
    tracker = Tracker(ready=None, started=None)
    assert repr(tracker) == "Tracker(processed=0, completed=0)"
|
2021-05-19 21:02:20 +00:00
|
|
|
|
|
|
|
|
2021-05-22 01:09:08 +00:00
|
|
|
@pytest.mark.parametrize("attr", ("ready", "started", "processed", "completed"))
def test_base_progress_increment(attr):
    """Calling ``increment_<attr>`` once should set ``attr`` to 1 on both ``total`` and ``current``."""
    progress = Progress()
    getattr(progress, "increment_" + attr)()
    expected = Tracker(**{attr: 1})
    assert progress.total == expected
    assert progress.current == expected
|
2021-05-19 21:02:20 +00:00
|
|
|
|
|
|
|
|
2021-05-22 01:09:08 +00:00
|
|
|
def test_base_progress_from_defaults():
    """``Progress.from_defaults`` should apply the given keyword values to both ``total`` and ``current``."""
    actual = Progress.from_defaults(completed=5, started=None)
    total = Tracker(started=None, completed=5)
    current = Tracker(started=None, completed=5)
    assert actual == Progress(total=total, current=current)
|
2021-05-19 21:02:20 +00:00
|
|
|
|
|
|
|
|
2021-06-29 12:47:41 +00:00
|
|
|
def test_epoch_loop_progress_increment_epoch():
    """Two completed epochs are counted in ``epoch.total`` while the ``current`` trackers stay at their defaults."""
    loop = EpochLoopProgress()
    for _ in range(2):
        loop.increment_epoch_completed()

    epoch = loop.epoch
    assert epoch.total == Tracker(completed=2)
    assert epoch.current == Tracker()
    assert epoch.batch.current == Tracker()
|
|
|
|
|
|
|
|
|
|
|
|
def test_epoch_loop_progress_increment_sequence():
    """Walk through a sequence of batch and epoch increments and check every tracker after each step."""
    batches = BatchProgress(total=Tracker(started=None))
    epochs = EpochProgress(batch=batches)
    loop_progress = EpochLoopProgress(epoch=epochs)

    # step 1: a batch becomes ready
    batches.increment_ready()
    assert batches.total == Tracker(ready=1, started=None)
    assert batches.current == Tracker(ready=1)

    # step 2: `started` is `None` on the total tracker; neither tracker is expected to change
    batches.increment_started()
    assert batches.total == Tracker(ready=1, started=None)
    assert batches.current == Tracker(ready=1)

    # step 3: the batch is processed
    batches.increment_processed()
    assert batches.total == Tracker(ready=1, started=None, processed=1)
    assert batches.current == Tracker(ready=1, processed=1)

    # step 4: the batch is completed
    batches.increment_completed()
    assert batches.total == Tracker(ready=1, started=None, processed=1, completed=1)
    assert batches.current == Tracker(ready=1, processed=1, completed=1)

    # step 5: epoch trackers are untouched until the loop completes an epoch;
    # completing it bumps the epoch total and clears the current batch tracker
    assert epochs.total == Tracker()
    assert epochs.current == Tracker()
    loop_progress.increment_epoch_completed()
    assert batches.total == Tracker(ready=1, started=None, processed=1, completed=1)
    assert batches.current == Tracker()
    assert epochs.total == Tracker(completed=1)
    assert epochs.current == Tracker()

    # step 6: a batch of the next epoch becomes ready
    batches.increment_ready()
    assert batches.total == Tracker(ready=2, started=None, processed=1, completed=1)
    assert batches.current == Tracker(ready=1)
    assert epochs.total == Tracker(completed=1)
    assert epochs.current == Tracker()

    # step 7: resetting on epoch clears the current batch tracker and keeps the totals
    loop_progress.reset_on_epoch()
    assert batches.total == Tracker(ready=2, started=None, processed=1, completed=1)
    assert batches.current == Tracker()
    assert epochs.total == Tracker(completed=1)
    assert epochs.current == Tracker()
|
|
|
|
|
|
|
|
|
|
|
|
def test_optimizer_progress_default_factory():
    """
    Two ``OptimizerProgress`` instances must each get their own default trackers. If `default_factory`
    was not used, the default would be shared between instances and an increment on one would show up
    on the other.
    """
    first = OptimizerProgress()
    second = OptimizerProgress()

    first.step.increment_completed()

    first_step = first.step
    assert first_step.total.completed == first_step.current.completed
    assert first_step.total.completed == 1
    # the untouched instance must still be at zero
    assert second.step.total.completed == 0
|
|
|
|
|
|
|
|
|
|
|
|
def test_fit_loop_progress_serialization():
    """Round-trip a ``FitLoopProgress`` through ``state_dict`` / ``from_state_dict`` and pin the dict layout."""
    fit_loop = FitLoopProgress()
    # smoke check: deep-copying the progress object must not raise
    deepcopy(fit_loop)
    fit_loop.epoch.increment_completed()  # check `TrainingEpochProgress.load_state_dict` calls `super`

    state_dict = fit_loop.state_dict()
    # yapf: disable
    expected = {
        'epoch': {
            # number of epochs across `fit` calls
            'total': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0},
            # number of epochs this `fit` call
            'current': {'completed': 1, 'processed': 0, 'ready': 0, 'started': 0},
            'batch': {
                # number of batches across `fit` calls
                'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
                # number of batches this epoch
                'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
            },
            # `fit` optimization progress
            'optim': {
                # optimizers progress
                'optimizer': {
                    'step': {
                        # `optimizer.step` calls across `fit` calls
                        'total': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0},
                        # `optimizer.step` calls this epoch
                        'current': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0},
                    },
                    'zero_grad': {
                        # `optimizer.zero_grad` calls across `fit` calls
                        'total': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0},
                        # `optimizer.zero_grad` calls this epoch
                        'current': {'completed': 0, 'processed': None, 'ready': 0, 'started': 0},
                    },
                },
                'scheduler': {
                    # `scheduler.step` calls across `fit` calls
                    'total': {'completed': 0, 'processed': None, 'ready': 0, 'started': None},
                    # `scheduler.step` calls this epoch
                    'current': {'completed': 0, 'processed': None, 'ready': 0, 'started': None},
                },
            },
            # `fit` validation progress
            'val': {
                'epoch': {
                    # number of `validation` calls across `fit` calls
                    'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
                    # number of `validation` calls this `fit` call
                    'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
                    'batch': {
                        # number of batches across `fit` `validation` calls
                        'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
                        # number of batches this `fit` `validation` call
                        'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
                    },
                }
            },
        }
    }
    # yapf: enable
    assert state_dict == expected

    # rebuilding from the serialized form must give back an equal object
    restored = FitLoopProgress.from_state_dict(state_dict)
    assert fit_loop == restored
|
|
|
|
|
|
|
|
|
|
|
|
def test_epoch_loop_progress_serialization():
    """Round-trip an ``EpochLoopProgress`` through ``state_dict`` / ``from_state_dict`` and pin the dict layout."""
    loop = EpochLoopProgress()
    # smoke check: deep-copying the progress object must not raise
    deepcopy(loop)

    state_dict = loop.state_dict()
    # yapf: disable
    expected = {
        'epoch': {
            # number of times `validate` has been called
            'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
            # either 0 or 1 as `max_epochs` does not apply to the `validate` loop
            'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
            'batch': {
                # number of batches across `validate` calls
                'total': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
                # number of batches this `validate` call
                'current': {'completed': 0, 'processed': 0, 'ready': 0, 'started': 0},
            },
        }
    }
    # yapf: enable
    assert state_dict == expected

    # rebuilding from the serialized form must give back an equal object
    restored = EpochLoopProgress.from_state_dict(state_dict)
    assert loop == restored
|