lightning/tests/loops/test_loop_progress_integrat...

23 lines
1.3 KiB
Python

import itertools

from pytorch_lightning import Trainer

def test_loop_progress_integration():
    """Check how progress-tracking objects are wired between the trainer's loops.

    Inside each loop, the parent's progress sub-attribute must be the very same
    object (compared with ``is``) as the progress owned by the corresponding
    child loop. Across the top-level fit/validate/test/predict loops, no
    progress object may be shared. The validation loop nested inside the fit
    loop must also not share its progress with ``trainer.validate_loop``.
    """
    trainer = Trainer()
    fit_loop = trainer.fit_loop
    epoch_loop = fit_loop.epoch_loop
    val_loop = epoch_loop.val_loop

    # check identities inside the fit loop
    assert fit_loop.progress.epoch is epoch_loop.progress
    assert epoch_loop.progress.batch is epoch_loop.batch_loop.progress
    assert epoch_loop.progress.optim is epoch_loop.batch_loop.optim_progress
    assert epoch_loop.progress.val is val_loop.progress
    assert val_loop.progress.epoch is val_loop.epoch_loop.progress

    # check identities inside the evaluation and predict loops
    for name, loop in (
        ("validate", trainer.validate_loop),
        ("test", trainer.test_loop),
        ("predict", trainer.predict_loop),
    ):
        assert loop.progress.epoch is loop.epoch_loop.progress, (
            f"{name} loop: progress.epoch is not the epoch loop's progress"
        )

    # check no progresses are shared between ANY pair of top-level loops
    # (not just adjacent ones)
    top_level_progress = {
        "fit": trainer.fit_loop.progress,
        "validate": trainer.validate_loop.progress,
        "test": trainer.test_loop.progress,
        "predict": trainer.predict_loop.progress,
    }
    for (name_a, progress_a), (name_b, progress_b) in itertools.combinations(
        top_level_progress.items(), 2
    ):
        assert progress_a is not progress_b, (
            f"{name_a} and {name_b} loops share a progress object"
        )

    # check the validation progresses are not shared: the val loop run during
    # fitting must track progress separately from `trainer.validate_loop`
    assert val_loop.progress is not trainer.validate_loop.progress