lightning/tests/callbacks/test_callbacks.py

126 lines
4.9 KiB
Python
Raw Normal View History

2020-10-13 11:18:07 +00:00
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import ANY, call, MagicMock
from pytorch_lightning import Trainer
from tests.base import BoringModel
@mock.patch("torch.save")  # need to mock torch.save or we get pickle error
def test_trainer_callback_system(torch_save):
    """Test that the Trainer calls the callback hooks in the expected order.

    A ``MagicMock`` is registered as the only callback, so every hook the
    Trainer invokes is recorded in ``callback_mock.method_calls``. The recorded
    calls are compared against an exact, ordered list after each of three
    phases: Trainer construction, ``trainer.fit`` and ``trainer.test``.

    Args:
        torch_save: the mock injected by ``mock.patch("torch.save")``; it is
            not used directly in the test body.
    """
    model = BoringModel()
    callback_mock = MagicMock()
    trainer_options = dict(
        callbacks=[callback_mock],
        max_epochs=1,
        limit_val_batches=1,
        limit_train_batches=3,
        limit_test_batches=2,
        progress_bar_refresh_rate=0,
    )

    # no call yet
    callback_mock.assert_not_called()

    # constructing the Trainer must trigger only the two init hooks
    trainer = Trainer(**trainer_options)
    assert trainer.callbacks[0] == callback_mock
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
    ]

    # fit model; method_calls still contains the init hooks recorded above
    trainer.fit(model)
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.on_before_accelerator_backend_setup(trainer, model),
        call.setup(trainer, model, 'fit'),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        # sanity check wraps one validation pass
        call.on_sanity_check_start(trainer, model),
        call.on_validation_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
        call.on_sanity_check_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        # training batch 0
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 0, 0),
        call.on_after_backward(trainer, model),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_train_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_batch_end(trainer, model),
        # training batch 1
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 1, 0),
        call.on_after_backward(trainer, model),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_train_batch_end(trainer, model, ANY, ANY, 1, 0),
        call.on_batch_end(trainer, model),
        # training batch 2
        call.on_batch_start(trainer, model),
        call.on_train_batch_start(trainer, model, ANY, 2, 0),
        call.on_after_backward(trainer, model),
        call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
        call.on_train_batch_end(trainer, model, ANY, ANY, 2, 0),
        call.on_batch_end(trainer, model),
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_epoch_end(trainer, model),
        # validation pass after the training epoch
        call.on_validation_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        call.on_validation_batch_start(trainer, model, ANY, 0, 0),
        call.on_validation_batch_end(trainer, model, ANY, ANY, 0, 0),
        call.on_validation_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
        call.on_save_checkpoint(trainer, model),
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
    ]

    # start the test phase from a clean record; the new Trainer re-registers
    # the same mock, so its init hooks are recorded again
    callback_mock.reset_mock()
    trainer = Trainer(**trainer_options)
    trainer.test(model)
    # NOTE(review): this list pins on_fit_start/on_fit_end and a 'fit' teardown
    # being reported during `trainer.test` as well - confirm that is intended
    # Trainer behaviour before changing the expectations.
    assert callback_mock.method_calls == [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.on_before_accelerator_backend_setup(trainer, model),
        call.setup(trainer, model, 'test'),
        call.on_fit_start(trainer, model),
        call.on_test_start(trainer, model),
        call.on_test_epoch_start(trainer, model),
        # test batch 0
        call.on_test_batch_start(trainer, model, ANY, 0, 0),
        call.on_test_batch_end(trainer, model, ANY, ANY, 0, 0),
        # test batch 1
        call.on_test_batch_start(trainer, model, ANY, 1, 0),
        call.on_test_batch_end(trainer, model, ANY, ANY, 1, 0),
        call.on_test_epoch_end(trainer, model),
        call.on_test_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
        call.teardown(trainer, model, 'test'),
    ]