lightning/tests/trainer/test_progress.py

111 lines
4.0 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from pytorch_lightning.trainer.progress import LoopProgress, Progress, Tracker
def test_progress_geattr_setattr():
    """Check the attribute access rules of ``Tracker``.

    A field initialised to ``None`` is "unused": it can be read but not written.
    Regular fields stay readable, unknown attributes raise on read, and new
    attributes may still be set and read back.
    """
    p = Tracker(ready=10, completed=None)

    # can read, both a regular field and the `None` value of an unused one
    assert p.ready == 10
    assert p.completed is None

    # can't read non-existing attr
    with pytest.raises(AttributeError, match="object has no attribute 'non_existing_attr'"):
        p.non_existing_attr  # noqa

    # can set new attr, and the assigned value is actually stored
    p.non_existing_attr = 10
    assert p.non_existing_attr == 10

    # can't write unused attr
    with pytest.raises(AttributeError, match="'completed' attribute is meant to be unused"):
        p.completed = 10

    with pytest.raises(TypeError, match="unsupported operand type"):
        # default python error, would need to override `__getattribute__`
        # but we want to allow reading the `None` value
        p.completed += 10
def test_progress_reset():
    """``Tracker.reset`` restores defaults while keeping ``None`` (unused) fields unused."""
    tracker = Tracker(ready=1, started=2, completed=None)
    tracker.reset()
    expected = Tracker(completed=None)
    assert tracker == expected
def test_progress_repr():
    """Fields set to ``None`` do not appear in ``Tracker``'s repr."""
    tracker = Tracker(ready=None, started=None)
    rendered = repr(tracker)
    assert rendered == "Tracker(processed=0, completed=0)"
@pytest.mark.parametrize("attr", ("ready", "started", "processed", "completed"))
def test_base_progress_increment(attr):
    """``Progress.increment_<attr>`` bumps ``attr`` by one on both ``total`` and ``current``."""
    progress = Progress()
    increment = getattr(progress, f"increment_{attr}")
    increment()

    want = Tracker(**{attr: 1})
    assert progress.total == want
    assert progress.current == want
def test_base_progress_from_defaults():
    """``Progress.from_defaults`` builds ``total`` and ``current`` from the same kwargs."""
    result = Progress.from_defaults(completed=5, started=None)
    tracker_kwargs = dict(started=None, completed=5)
    assert result == Progress(total=Tracker(**tracker_kwargs), current=Tracker(**tracker_kwargs))
def test_loop_progress_increment_epoch():
    """``increment_epoch_completed`` only advances ``epoch.total.completed``.

    After two calls the epoch total records both completions, while the
    current epoch tracker and both batch trackers remain at their defaults.
    """
    p = LoopProgress()
    p.increment_epoch_completed()
    p.increment_epoch_completed()
    assert p.epoch.total == Tracker(completed=2)
    assert p.epoch.current == Tracker()
    # batch counters must not be touched by epoch increments
    assert p.batch.total == Tracker()
    assert p.batch.current == Tracker()
def test_loop_progress_increment_sequence():
    """Test a sequence of batch increments (ready/started/processed/completed) and epoch increments."""
    # the batch total marks `started` as unused by initialising it to `None`
    p = LoopProgress(batch=Progress(total=Tracker(started=None)))

    # a batch becomes ready: both batch trackers count it
    p.batch.increment_ready()
    assert p.batch.total == Tracker(ready=1, started=None)
    assert p.batch.current == Tracker(ready=1)

    # neither batch tracker changes after `increment_started` in this setup
    p.batch.increment_started()
    assert p.batch.total == Tracker(ready=1, started=None)
    assert p.batch.current == Tracker(ready=1)

    p.batch.increment_processed()
    assert p.batch.total == Tracker(ready=1, started=None, processed=1)
    assert p.batch.current == Tracker(ready=1, processed=1)

    p.batch.increment_completed()
    assert p.batch.total == Tracker(ready=1, started=None, processed=1, completed=1)
    assert p.batch.current == Tracker(ready=1, processed=1, completed=1)
    # batch increments leave the epoch trackers untouched
    assert p.epoch.total == Tracker()
    assert p.epoch.current == Tracker()

    # completing an epoch resets the current batch tracker but keeps the batch total
    p.increment_epoch_completed()
    assert p.batch.total == Tracker(ready=1, started=None, processed=1, completed=1)
    assert p.batch.current == Tracker()
    assert p.epoch.total == Tracker(completed=1)
    assert p.epoch.current == Tracker()

    # a batch in the next epoch adds to the batch total and starts a fresh current count
    p.batch.increment_ready()
    assert p.batch.total == Tracker(ready=2, started=None, processed=1, completed=1)
    assert p.batch.current == Tracker(ready=1)
    assert p.epoch.total == Tracker(completed=1)
    assert p.epoch.current == Tracker()

    # `reset_on_epoch` clears the current batch tracker; all totals are preserved
    p.reset_on_epoch()
    assert p.batch.total == Tracker(ready=2, started=None, processed=1, completed=1)
    assert p.batch.current == Tracker()
    assert p.epoch.total == Tracker(completed=1)
    assert p.epoch.current == Tracker()