23 lines
1.3 KiB
Python
23 lines
1.3 KiB
Python
from pytorch_lightning import Trainer
|
|
|
|
|
|
def test_loop_progress_integration():
    """Check how progress trackers are linked across the ``Trainer`` loops.

    A default ``Trainer`` is created and its loops are inspected using
    identity (``is`` / ``is not``) comparisons only; nothing is actually run.
    The assertions verify that:

    * inside the fit loop, each nested progress attribute is the very same
      object as the ``progress`` exposed by the corresponding child loop;
    * the validate, test and predict loops link their ``progress.epoch`` to
      their own epoch loop's ``progress`` in the same way;
    * the top-level loops do not share a progress object with one another;
    * the validation loop nested in the fit loop does not share its progress
      with the standalone validate loop.
    """
    trainer = Trainer()
    fit_loop = trainer.fit_loop

    # check identities inside the fit loop: parent progress fields must alias
    # the child loops' own progress objects (same object, not merely equal)
    assert fit_loop.progress.epoch is fit_loop.epoch_loop.progress
    assert fit_loop.epoch_loop.progress.batch is fit_loop.epoch_loop.batch_loop.progress
    assert fit_loop.epoch_loop.progress.optim is fit_loop.epoch_loop.batch_loop.optim_progress
    assert fit_loop.epoch_loop.progress.val is fit_loop.epoch_loop.val_loop.progress
    assert fit_loop.epoch_loop.val_loop.progress.epoch is fit_loop.epoch_loop.val_loop.epoch_loop.progress

    # check identities inside the evaluation and predict loops
    assert trainer.validate_loop.progress.epoch is trainer.validate_loop.epoch_loop.progress
    assert trainer.test_loop.progress.epoch is trainer.test_loop.epoch_loop.progress
    assert trainer.predict_loop.progress.epoch is trainer.predict_loop.epoch_loop.progress

    # check no progresses are shared between the top-level loops
    assert trainer.fit_loop.progress is not trainer.validate_loop.progress
    assert trainer.validate_loop.progress is not trainer.test_loop.progress
    assert trainer.test_loop.progress is not trainer.predict_loop.progress

    # check the validation progresses are not shared: the val loop nested in
    # fit must be distinct from the standalone validate loop
    assert trainer.fit_loop.epoch_loop.val_loop.progress is not trainer.validate_loop.progress