Refactor some loops code and hook tests (#7682)
This commit is contained in:
parent
8ba6304c73
commit
e2ead9abd7
|
@ -11,30 +11,30 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
from weakref import proxy
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
class OptimizerConnector:
|
||||
|
||||
def __init__(self, trainer):
|
||||
self.trainer = trainer
|
||||
def __init__(self, trainer: 'pl.Trainer') -> None:
|
||||
self.trainer = proxy(trainer)
|
||||
|
||||
def on_trainer_init(self):
|
||||
def on_trainer_init(self) -> None:
|
||||
self.trainer.lr_schedulers = []
|
||||
self.trainer.optimizers = []
|
||||
self.trainer.optimizer_frequencies = []
|
||||
|
||||
def update_learning_rates(
|
||||
self, interval: str, monitor_metrics: Optional[Dict[str, Any]] = None, opt_indices: Optional[List[int]] = None
|
||||
):
|
||||
def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] = None) -> None:
|
||||
"""Update learning rates.
|
||||
|
||||
Args:
|
||||
interval: either 'epoch' or 'step'.
|
||||
monitor_metrics: dict of possible values to monitor
|
||||
opt_indices: indices of the optimizers to update.
|
||||
"""
|
||||
if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
|
||||
return
|
||||
|
@ -55,10 +55,7 @@ class OptimizerConnector:
|
|||
monitor_key, monitor_val = None, None
|
||||
if lr_scheduler['reduce_on_plateau']:
|
||||
monitor_key = lr_scheduler['monitor']
|
||||
monitor_val = (
|
||||
monitor_metrics.get(monitor_key) if monitor_metrics is not None else
|
||||
self.trainer.logger_connector.callback_metrics.get(monitor_key)
|
||||
)
|
||||
monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key)
|
||||
if monitor_val is None:
|
||||
if lr_scheduler.get('strict', True):
|
||||
avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys())
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager, suppress
|
||||
from copy import copy, deepcopy
|
||||
from copy import copy
|
||||
from functools import partial, update_wrapper
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
@ -478,7 +478,6 @@ class TrainLoop:
|
|||
|
||||
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
|
||||
dataloader_idx = 0
|
||||
|
||||
batch_idx = None
|
||||
is_last_batch = None
|
||||
|
||||
|
@ -525,8 +524,7 @@ class TrainLoop:
|
|||
self.save_loggers_on_train_batch_end()
|
||||
|
||||
# update LR schedulers
|
||||
monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
|
||||
self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
|
||||
self.update_lr_schedulers('step')
|
||||
self.trainer.checkpoint_connector.has_trained = True
|
||||
|
||||
self.total_batch_idx += 1
|
||||
|
@ -567,7 +565,7 @@ class TrainLoop:
|
|||
|
||||
# update epoch level lr_schedulers if no val loop outside train loop is triggered
|
||||
if not should_check_val or should_train_only:
|
||||
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
|
||||
self.update_lr_schedulers('epoch')
|
||||
|
||||
if should_train_only:
|
||||
self.check_checkpoint_callback(True)
|
||||
|
@ -863,15 +861,14 @@ class TrainLoop:
|
|||
# track gradients
|
||||
result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer)
|
||||
|
||||
def update_train_loop_lr_schedulers(self, monitor_metrics=None):
|
||||
num_accumulated_batches_reached = self._accumulated_batches_reached()
|
||||
num_training_batches_reached = self._num_training_batches_reached()
|
||||
|
||||
if num_accumulated_batches_reached or num_training_batches_reached:
|
||||
# update lr
|
||||
def update_lr_schedulers(self, interval: str) -> None:
|
||||
if interval == "step":
|
||||
finished_accumulation = self._accumulated_batches_reached()
|
||||
finished_epoch = self._num_training_batches_reached()
|
||||
if not finished_accumulation and not finished_epoch:
|
||||
return
|
||||
self.trainer.optimizer_connector.update_learning_rates(
|
||||
interval="step",
|
||||
monitor_metrics=monitor_metrics,
|
||||
interval=interval,
|
||||
opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
|
||||
)
|
||||
|
||||
|
@ -897,15 +894,21 @@ class TrainLoop:
|
|||
|
||||
def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
|
||||
""" Decide if we should run validation. """
|
||||
|
||||
if not self.trainer.enable_validation:
|
||||
return False
|
||||
|
||||
# check if this epoch is eligible to run validation
|
||||
if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
|
||||
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
|
||||
if not is_val_check_epoch:
|
||||
return False
|
||||
|
||||
# val_check_batch is inf for iterable datasets with no length defined
|
||||
is_infinite_dataset = self.trainer.val_check_batch == float('inf')
|
||||
if on_epoch and is_last_batch and is_infinite_dataset:
|
||||
return True
|
||||
|
||||
if on_epoch and self.trainer.should_stop:
|
||||
return True
|
||||
|
||||
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
|
||||
is_val_check_batch = False
|
||||
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
|
||||
|
@ -915,12 +918,9 @@ class TrainLoop:
|
|||
|
||||
# Note: num_training_batches is also inf for iterable datasets with no length defined
|
||||
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
|
||||
is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
|
||||
|
||||
if on_epoch:
|
||||
return (
|
||||
is_val_check_batch and epoch_end_val_check
|
||||
) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
|
||||
return is_val_check_batch and epoch_end_val_check
|
||||
else:
|
||||
return is_val_check_batch and not epoch_end_val_check
|
||||
|
||||
|
|
|
@ -157,7 +157,7 @@ def test_early_stopping_patience_train(
|
|||
"""Test to ensure that early stopping is not triggered before patience is exhausted."""
|
||||
|
||||
class ModelOverrideTrainReturn(BoringModel):
|
||||
train_return_values = torch.Tensor(loss_values)
|
||||
train_return_values = torch.tensor(loss_values)
|
||||
|
||||
def training_epoch_end(self, outputs):
|
||||
loss = self.train_return_values[self.current_epoch]
|
||||
|
|
|
@ -264,67 +264,42 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
|
|||
|
||||
|
||||
@mock.patch('pytorch_lightning.loggers.TensorBoardLogger.log_metrics')
|
||||
@pytest.mark.parametrize('expected', [
|
||||
([5, 11, 17]),
|
||||
])
|
||||
def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir):
|
||||
"""
|
||||
Tests to ensure that tensorboard log properly when accumulated_gradients > 1
|
||||
"""
|
||||
def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
|
||||
"""Tests to ensure that tensorboard log properly when accumulated_gradients > 1"""
|
||||
|
||||
class TestModel(BoringModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._count = 0
|
||||
self._indexes = []
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
output = self.layer(batch)
|
||||
loss = self.loss(batch, output)
|
||||
self.log('count', self._count, on_step=True, on_epoch=True)
|
||||
self.log('loss', loss, on_step=True, on_epoch=True)
|
||||
self.indexes = []
|
||||
|
||||
def training_step(self, *args):
|
||||
self.log('foo', 1, on_step=True, on_epoch=True)
|
||||
if not self.trainer.train_loop.should_accumulate():
|
||||
if self.trainer.logger_connector.should_update_logs:
|
||||
self._indexes.append(self.trainer.global_step)
|
||||
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
output = self.layer(batch)
|
||||
loss = self.loss(batch, output)
|
||||
self.log('val_loss', loss, on_step=True, on_epoch=True)
|
||||
return loss
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.layer.parameters(), lr=.001)
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
|
||||
return [optimizer], [lr_scheduler]
|
||||
self.indexes.append(self.trainer.global_step)
|
||||
return super().training_step(*args)
|
||||
|
||||
model = TestModel()
|
||||
model.training_epoch_end = None
|
||||
model.validation_epoch_end = None
|
||||
|
||||
logger_0 = TensorBoardLogger(tmpdir, default_hp_metric=False)
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=12,
|
||||
limit_val_batches=0,
|
||||
max_epochs=3,
|
||||
gpus=0,
|
||||
accumulate_grad_batches=2,
|
||||
logger=[logger_0],
|
||||
log_every_n_steps=3,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
mock_count_epochs = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_epoch" in m[2]["metrics"]]
|
||||
assert mock_count_epochs == expected
|
||||
calls = [m[2] for m in mock_log_metrics.mock_calls]
|
||||
count_epochs = [c["step"] for c in calls if "foo_epoch" in c["metrics"]]
|
||||
assert count_epochs == [5, 11, 17]
|
||||
|
||||
mock_count_steps = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_step" in m[2]["metrics"]]
|
||||
assert model._indexes == mock_count_steps
|
||||
count_steps = [c["step"] for c in calls if "foo_step" in c["metrics"]]
|
||||
assert count_steps == model.indexes
|
||||
|
||||
|
||||
@mock.patch('pytorch_lightning.loggers.tensorboard.SummaryWriter')
|
||||
|
|
|
@ -234,11 +234,55 @@ class HookedModel(BoringModel):
|
|||
def __init__(self):
|
||||
super().__init__()
|
||||
self.called = []
|
||||
self.train_batch = [
|
||||
'on_train_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'training_step',
|
||||
'on_before_zero_grad',
|
||||
'optimizer_zero_grad',
|
||||
'backward',
|
||||
'on_after_backward',
|
||||
'optimizer_step',
|
||||
'on_train_batch_end',
|
||||
]
|
||||
self.val_batch = [
|
||||
'on_validation_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'on_validation_batch_end',
|
||||
]
|
||||
|
||||
def training_step(self, *args, **kwargs):
|
||||
self.called.append("training_step")
|
||||
return super().training_step(*args, **kwargs)
|
||||
|
||||
def optimizer_zero_grad(self, *args, **kwargs):
|
||||
self.called.append("optimizer_zero_grad")
|
||||
super().optimizer_zero_grad(*args, **kwargs)
|
||||
|
||||
def training_epoch_end(self, *args, **kwargs):
|
||||
self.called.append("training_epoch_end")
|
||||
super().training_epoch_end(*args, **kwargs)
|
||||
|
||||
def backward(self, *args, **kwargs):
|
||||
self.called.append("backward")
|
||||
super().backward(*args, **kwargs)
|
||||
|
||||
def on_after_backward(self):
|
||||
self.called.append("on_after_backward")
|
||||
super().on_after_backward()
|
||||
|
||||
def optimizer_step(self, *args, **kwargs):
|
||||
super().optimizer_step(*args, **kwargs)
|
||||
self.called.append("optimizer_step") # append after as closure calls other methods
|
||||
|
||||
def validation_epoch_end(self, *args, **kwargs):
|
||||
self.called.append("validation_epoch_end")
|
||||
super().validation_epoch_end(*args, **kwargs)
|
||||
|
||||
def on_before_zero_grad(self, *args, **kwargs):
|
||||
self.called.append("on_before_zero_grad")
|
||||
super().on_before_zero_grad(*args, **kwargs)
|
||||
|
@ -394,12 +438,13 @@ class HookedModel(BoringModel):
|
|||
|
||||
def test_trainer_model_hook_system_fit(tmpdir):
|
||||
model = HookedModel()
|
||||
train_batches = 2
|
||||
val_batches = 2
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_val_batches=1,
|
||||
limit_train_batches=2,
|
||||
limit_test_batches=1,
|
||||
limit_train_batches=train_batches,
|
||||
limit_val_batches=val_batches,
|
||||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
@ -414,11 +459,8 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
|||
'on_validation_start',
|
||||
'on_epoch_start',
|
||||
'on_validation_epoch_start',
|
||||
'on_validation_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'on_validation_batch_end',
|
||||
*(model.val_batch * val_batches),
|
||||
'validation_epoch_end',
|
||||
'on_validation_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_validation_end',
|
||||
|
@ -426,31 +468,16 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
|||
'on_train_start',
|
||||
'on_epoch_start',
|
||||
'on_train_epoch_start',
|
||||
'on_train_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'on_before_zero_grad',
|
||||
'on_after_backward',
|
||||
'on_train_batch_end',
|
||||
'on_train_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'on_before_zero_grad',
|
||||
'on_after_backward',
|
||||
'on_train_batch_end',
|
||||
*(model.train_batch * train_batches),
|
||||
'training_epoch_end',
|
||||
'on_train_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_validation_model_eval',
|
||||
'on_validation_start',
|
||||
'on_epoch_start',
|
||||
'on_validation_epoch_start',
|
||||
'on_validation_batch_start',
|
||||
'on_before_batch_transfer',
|
||||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'on_validation_batch_end',
|
||||
*(model.val_batch * val_batches),
|
||||
'validation_epoch_end',
|
||||
'on_validation_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_save_checkpoint',
|
||||
|
@ -463,14 +490,45 @@ def test_trainer_model_hook_system_fit(tmpdir):
|
|||
assert model.called == expected
|
||||
|
||||
|
||||
def test_trainer_model_hook_system_fit_no_val(tmpdir):
|
||||
model = HookedModel()
|
||||
train_batches = 2
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_val_batches=0,
|
||||
limit_train_batches=train_batches,
|
||||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
)
|
||||
assert model.called == []
|
||||
trainer.fit(model)
|
||||
expected = [
|
||||
'setup_fit',
|
||||
'on_fit_start',
|
||||
'on_pretrain_routine_start',
|
||||
'on_pretrain_routine_end',
|
||||
'on_train_start',
|
||||
'on_epoch_start',
|
||||
'on_train_epoch_start',
|
||||
*(model.train_batch * train_batches),
|
||||
'training_epoch_end',
|
||||
'on_train_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_save_checkpoint', # from train epoch end
|
||||
'on_train_end',
|
||||
'on_fit_end',
|
||||
'teardown_fit',
|
||||
]
|
||||
assert model.called == expected
|
||||
|
||||
|
||||
def test_trainer_model_hook_system_validate(tmpdir):
|
||||
model = HookedModel()
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_val_batches=1,
|
||||
limit_train_batches=2,
|
||||
limit_test_batches=1,
|
||||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
@ -487,6 +545,7 @@ def test_trainer_model_hook_system_validate(tmpdir):
|
|||
'transfer_batch_to_device',
|
||||
'on_after_batch_transfer',
|
||||
'on_validation_batch_end',
|
||||
'validation_epoch_end',
|
||||
'on_validation_epoch_end',
|
||||
'on_epoch_end',
|
||||
'on_validation_end',
|
||||
|
@ -501,8 +560,6 @@ def test_trainer_model_hook_system_test(tmpdir):
|
|||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_val_batches=1,
|
||||
limit_train_batches=2,
|
||||
limit_test_batches=1,
|
||||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
|
@ -644,7 +701,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
|
|||
reload_dataloaders_every_epoch=True,
|
||||
)
|
||||
trainer.fit(model, datamodule=dm)
|
||||
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'setup_fit',
|
||||
|
@ -669,7 +725,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
|
|||
|
||||
dm = HookedDataModule()
|
||||
trainer.validate(model, datamodule=dm, verbose=False)
|
||||
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'setup_validate',
|
||||
|
@ -683,7 +738,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
|
|||
|
||||
dm = HookedDataModule()
|
||||
trainer.test(model, datamodule=dm, verbose=False)
|
||||
|
||||
expected = [
|
||||
'prepare_data',
|
||||
'setup_test',
|
||||
|
|
|
@ -18,118 +18,6 @@ from pytorch_lightning import seed_everything, Trainer
|
|||
from tests.helpers import BoringModel
|
||||
|
||||
|
||||
def test_training_loop_hook_call_order(tmpdir):
|
||||
"""Tests that hooks / methods called in the training loop are in the correct order as detailed in the docs:
|
||||
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks"""
|
||||
|
||||
class HookedModel(BoringModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.called = []
|
||||
|
||||
def on_epoch_start(self):
|
||||
self.called.append("on_epoch_start")
|
||||
super().on_epoch_start()
|
||||
|
||||
def on_train_epoch_start(self):
|
||||
self.called.append("on_train_epoch_start")
|
||||
super().on_train_epoch_start()
|
||||
|
||||
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
|
||||
self.called.append("on_train_batch_start")
|
||||
super().on_train_batch_start(batch, batch_idx, dataloader_idx)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
self.called.append("training_step")
|
||||
return super().training_step(batch, batch_idx)
|
||||
|
||||
def on_before_zero_grad(self, optimizer):
|
||||
self.called.append("on_before_zero_grad")
|
||||
super().on_before_zero_grad(optimizer)
|
||||
|
||||
def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
|
||||
self.called.append("optimizer_zero_grad")
|
||||
super().optimizer_zero_grad(epoch, batch_idx, optimizer, optimizer_idx)
|
||||
|
||||
def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs):
|
||||
self.called.append("backward")
|
||||
super().backward(loss, optimizer, optimizer_idx, *args, **kwargs)
|
||||
|
||||
def on_after_backward(self):
|
||||
self.called.append("on_after_backward")
|
||||
super().on_after_backward()
|
||||
|
||||
def optimizer_step(
|
||||
self,
|
||||
epoch,
|
||||
batch_idx,
|
||||
optimizer,
|
||||
optimizer_idx,
|
||||
optimizer_closure,
|
||||
on_tpu,
|
||||
using_native_amp,
|
||||
using_lbfgs,
|
||||
):
|
||||
super().optimizer_step(
|
||||
epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
|
||||
)
|
||||
self.called.append("optimizer_step") # append after as closure calls other methods
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
|
||||
self.called.append("on_train_batch_end")
|
||||
super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
|
||||
|
||||
def training_epoch_end(self, outputs):
|
||||
self.called.append("training_epoch_end")
|
||||
super().training_epoch_end(outputs)
|
||||
|
||||
def on_train_epoch_end(self, outputs):
|
||||
self.called.append("on_train_epoch_end")
|
||||
super().on_train_epoch_end(outputs)
|
||||
|
||||
def on_epoch_end(self):
|
||||
self.called.append("on_epoch_end")
|
||||
super().on_epoch_end()
|
||||
|
||||
model = HookedModel()
|
||||
|
||||
# fit model
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=1,
|
||||
limit_val_batches=1,
|
||||
limit_train_batches=1,
|
||||
limit_test_batches=1,
|
||||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
)
|
||||
|
||||
assert model.called == []
|
||||
|
||||
trainer.fit(model)
|
||||
expected = [
|
||||
"on_epoch_start", # validation
|
||||
"on_epoch_end",
|
||||
"on_epoch_start", # training
|
||||
"on_train_epoch_start",
|
||||
"on_train_batch_start",
|
||||
"training_step",
|
||||
"on_before_zero_grad",
|
||||
"optimizer_zero_grad",
|
||||
"backward",
|
||||
"on_after_backward",
|
||||
"optimizer_step",
|
||||
"on_train_batch_end",
|
||||
"training_epoch_end",
|
||||
"on_train_epoch_end",
|
||||
"on_epoch_end",
|
||||
"on_epoch_start", # validation
|
||||
"on_epoch_end",
|
||||
]
|
||||
assert model.called == expected
|
||||
|
||||
|
||||
def test_outputs_format(tmpdir):
|
||||
"""Tests that outputs objects passed to model hooks and methods are consistent and in the correct format."""
|
||||
|
||||
|
|
Loading…
Reference in New Issue