add test for model hooks (#4010)
This commit is contained in:
parent
9edef4023c
commit
3777988502
|
@ -23,9 +23,7 @@ def is_overridden(method_name: str, model: Union[LightningModule, LightningDataM
|
|||
# TODO - refector this function to accept model_name, instance, parent so it makes more sense
|
||||
super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule
|
||||
|
||||
# assert model, 'no model passes'
|
||||
|
||||
if not hasattr(model, method_name):
|
||||
if not hasattr(model, method_name) or not hasattr(super_object, method_name):
|
||||
# in case of calling deprecated method
|
||||
return False
|
||||
|
||||
|
|
|
@ -11,14 +11,15 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest.mock import MagicMock
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.accelerators.gpu_accelerator import GPUAccelerator
|
||||
from tests.base import EvalModelTemplate
|
||||
from tests.base import EvalModelTemplate, BoringModel
|
||||
|
||||
|
||||
@pytest.mark.parametrize('max_steps', [1, 2, 3])
|
||||
|
@ -142,3 +143,216 @@ def test_on_train_batch_start_hook(max_epochs, batch_idx_):
|
|||
else:
|
||||
assert trainer.batch_idx == batch_idx_
|
||||
assert trainer.global_step == (batch_idx_ + 1) * max_epochs
|
||||
|
||||
|
||||
def test_trainer_model_hook_system(tmpdir):
|
||||
"""Test the hooks system."""
|
||||
|
||||
class HookedModel(BoringModel):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.called = []
|
||||
|
||||
def on_after_backward(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_after_backward()
|
||||
|
||||
def on_before_zero_grad(self, optimizer):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_before_zero_grad(optimizer)
|
||||
|
||||
def on_epoch_start(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_epoch_start()
|
||||
|
||||
def on_epoch_end(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_epoch_end()
|
||||
|
||||
def on_fit_start(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_fit_start()
|
||||
|
||||
def on_fit_end(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_fit_end()
|
||||
|
||||
def on_hpc_load(self, checkpoint):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_hpc_load(checkpoint)
|
||||
|
||||
def on_hpc_save(self, checkpoint):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_hpc_save(checkpoint)
|
||||
|
||||
def on_load_checkpoint(self, checkpoint):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_load_checkpoint(checkpoint)
|
||||
|
||||
def on_save_checkpoint(self, checkpoint):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_save_checkpoint(checkpoint)
|
||||
|
||||
def on_pretrain_routine_start(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_pretrain_routine_start()
|
||||
|
||||
def on_pretrain_routine_end(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_pretrain_routine_end()
|
||||
|
||||
def on_train_start(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_train_start()
|
||||
|
||||
def on_train_end(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_train_end()
|
||||
|
||||
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_train_batch_start(batch, batch_idx, dataloader_idx)
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_train_epoch_start()
|
||||
|
||||
def on_train_epoch_end(self, outputs):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_train_epoch_end(outputs)
|
||||
|
||||
def on_validation_start(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_validation_start()
|
||||
|
||||
def on_validation_end(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_validation_end()
|
||||
|
||||
def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_validation_batch_start(batch, batch_idx, dataloader_idx)
|
||||
|
||||
def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)
|
||||
|
||||
def on_validation_epoch_start(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_validation_epoch_start()
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_validation_epoch_end()
|
||||
|
||||
def on_test_start(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_test_start()
|
||||
|
||||
def on_test_end(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_test_end()
|
||||
|
||||
def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_test_batch_start(batch, batch_idx, dataloader_idx)
|
||||
|
||||
def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)
|
||||
|
||||
def on_test_epoch_start(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_test_epoch_start()
|
||||
|
||||
def on_test_epoch_end(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_test_epoch_end()
|
||||
|
||||
def on_validation_model_eval(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_validation_model_eval()
|
||||
|
||||
def on_validation_model_train(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_validation_model_train()
|
||||
|
||||
def on_test_model_eval(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_test_model_eval()
|
||||
|
||||
def on_test_model_train(self):
|
||||
self.called.append(inspect.currentframe().f_code.co_name)
|
||||
super().on_test_model_train()
|
||||
|
||||
model = HookedModel()
|
||||
|
||||
assert model.called == []
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_val_batches=1,
|
||||
limit_train_batches=2,
|
||||
limit_test_batches=1,
|
||||
progress_bar_refresh_rate=0,
|
||||
)
|
||||
|
||||
assert model.called == []
|
||||
|
||||
trainer.fit(model)
|
||||
|
||||
assert model.called == [
|
||||
'on_fit_start',
|
||||
'on_pretrain_routine_start',
|
||||
'on_pretrain_routine_end',
|
||||
'on_validation_model_eval',
|
||||
'on_validation_epoch_start',
|
||||
'on_validation_batch_start',
|
||||
'on_validation_batch_end',
|
||||
'on_validation_epoch_end',
|
||||
'on_validation_model_train',
|
||||
'on_train_start',
|
||||
'on_epoch_start',
|
||||
'on_train_epoch_start',
|
||||
'on_train_batch_start',
|
||||
'on_after_backward',
|
||||
'on_before_zero_grad',
|
||||
'on_train_batch_end',
|
||||
'on_train_batch_start',
|
||||
'on_after_backward',
|
||||
'on_before_zero_grad',
|
||||
'on_train_batch_end',
|
||||
'on_validation_model_eval',
|
||||
'on_validation_epoch_start',
|
||||
'on_validation_batch_start',
|
||||
'on_validation_batch_end',
|
||||
'on_validation_epoch_end',
|
||||
'on_validation_model_train',
|
||||
'on_save_checkpoint',
|
||||
'on_epoch_end',
|
||||
'on_train_epoch_end',
|
||||
'on_train_end',
|
||||
'on_fit_end',
|
||||
]
|
||||
|
||||
model2 = HookedModel()
|
||||
trainer.test(model2)
|
||||
|
||||
assert model2.called == [
|
||||
'on_fit_start',
|
||||
'on_pretrain_routine_start',
|
||||
'on_pretrain_routine_end',
|
||||
'on_test_model_eval',
|
||||
'on_test_epoch_start',
|
||||
'on_test_batch_start',
|
||||
'on_test_batch_end',
|
||||
'on_test_epoch_end',
|
||||
'on_test_model_train',
|
||||
'on_fit_end',
|
||||
]
|
||||
|
|
Loading…
Reference in New Issue