Refactor some loops code and hook tests (#7682)

Carlos Mocholí 2021-05-25 13:27:54 +02:00 committed by GitHub
parent 8ba6304c73
commit e2ead9abd7
6 changed files with 133 additions and 219 deletions

View File

@@ -11,30 +11,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional
from typing import List, Optional
from weakref import proxy
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class OptimizerConnector:
def __init__(self, trainer):
self.trainer = trainer
def __init__(self, trainer: 'pl.Trainer') -> None:
self.trainer = proxy(trainer)
def on_trainer_init(self):
def on_trainer_init(self) -> None:
self.trainer.lr_schedulers = []
self.trainer.optimizers = []
self.trainer.optimizer_frequencies = []
def update_learning_rates(
self, interval: str, monitor_metrics: Optional[Dict[str, Any]] = None, opt_indices: Optional[List[int]] = None
):
def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] = None) -> None:
"""Update learning rates.
Args:
interval: either 'epoch' or 'step'.
monitor_metrics: dict of possible values to monitor
opt_indices: indices of the optimizers to update.
"""
if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
return
@@ -55,10 +55,7 @@ class OptimizerConnector:
monitor_key, monitor_val = None, None
if lr_scheduler['reduce_on_plateau']:
monitor_key = lr_scheduler['monitor']
monitor_val = (
monitor_metrics.get(monitor_key) if monitor_metrics is not None else
self.trainer.logger_connector.callback_metrics.get(monitor_key)
)
monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key)
if monitor_val is None:
if lr_scheduler.get('strict', True):
avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys())
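A note on the `proxy(trainer)` change above: `weakref.proxy` lets the connector reach back to the trainer without keeping it alive through a strong reference cycle, and the quoted `'pl.Trainer'` annotation avoids a circular import at module load time. A minimal, self-contained sketch of the pattern (the `_Owner`/`_Connector` names are illustrative stand-ins, not Lightning classes):

```python
from weakref import proxy


class _Connector:
    """Illustrative stand-in for OptimizerConnector."""

    def __init__(self, owner) -> None:
        # Keep only a weak proxy to the owner: attribute access still works,
        # but the connector does not keep the owner alive on its own.
        self.owner = proxy(owner)


class _Owner:
    """Illustrative stand-in for pl.Trainer, which holds its connectors."""

    def __init__(self) -> None:
        self.lr_schedulers = []
        self.connector = _Connector(self)


owner = _Owner()
print(owner.connector.owner.lr_schedulers)  # [] -- attribute access is forwarded through the proxy
del owner  # with no strong cycle, both objects can be reclaimed promptly
```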

View File

@@ -14,7 +14,7 @@
from collections import OrderedDict
from contextlib import contextmanager, suppress
from copy import copy, deepcopy
from copy import copy
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -478,7 +478,6 @@ class TrainLoop:
train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
batch_idx = None
is_last_batch = None
@@ -525,8 +524,7 @@ class TrainLoop:
self.save_loggers_on_train_batch_end()
# update LR schedulers
monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics)
self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics)
self.update_lr_schedulers('step')
self.trainer.checkpoint_connector.has_trained = True
self.total_batch_idx += 1
@@ -567,7 +565,7 @@ class TrainLoop:
# update epoch level lr_schedulers if no val loop outside train loop is triggered
if not should_check_val or should_train_only:
self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
self.update_lr_schedulers('epoch')
if should_train_only:
self.check_checkpoint_callback(True)
@@ -863,15 +861,14 @@ class TrainLoop:
# track gradients
result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer)
def update_train_loop_lr_schedulers(self, monitor_metrics=None):
num_accumulated_batches_reached = self._accumulated_batches_reached()
num_training_batches_reached = self._num_training_batches_reached()
if num_accumulated_batches_reached or num_training_batches_reached:
# update lr
def update_lr_schedulers(self, interval: str) -> None:
if interval == "step":
finished_accumulation = self._accumulated_batches_reached()
finished_epoch = self._num_training_batches_reached()
if not finished_accumulation and not finished_epoch:
return
self.trainer.optimizer_connector.update_learning_rates(
interval="step",
monitor_metrics=monitor_metrics,
interval=interval,
opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()],
)
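The refactor above folds the old `update_train_loop_lr_schedulers` into `update_lr_schedulers(interval)`: for `interval='step'`, scheduler updates are skipped while gradients are still being accumulated, unless the epoch's batches are exhausted. A simplified, self-contained sketch of that gate (the function and argument names here are illustrative; the real checks live in `_accumulated_batches_reached` and `_num_training_batches_reached`):

```python
def should_update_step_schedulers(batch_idx: int, accumulate_grad_batches: int, num_training_batches: int) -> bool:
    """Illustrative gate for 'step'-interval scheduler updates."""
    finished_accumulation = (batch_idx + 1) % accumulate_grad_batches == 0
    finished_epoch = (batch_idx + 1) == num_training_batches
    # Skip while gradients are still accumulating, unless the epoch just ended.
    return finished_accumulation or finished_epoch


assert not should_update_step_schedulers(batch_idx=0, accumulate_grad_batches=2, num_training_batches=5)
assert should_update_step_schedulers(batch_idx=1, accumulate_grad_batches=2, num_training_batches=5)
assert should_update_step_schedulers(batch_idx=4, accumulate_grad_batches=2, num_training_batches=5)  # last batch
```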
@@ -897,15 +894,21 @@ class TrainLoop:
def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool:
""" Decide if we should run validation. """
if not self.trainer.enable_validation:
return False
# check if this epoch is eligible to run validation
if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0:
is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
if not is_val_check_epoch:
return False
# val_check_batch is inf for iterable datasets with no length defined
is_infinite_dataset = self.trainer.val_check_batch == float('inf')
if on_epoch and is_last_batch and is_infinite_dataset:
return True
if on_epoch and self.trainer.should_stop:
return True
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
is_val_check_batch = False
if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'):
@@ -915,12 +918,9 @@ class TrainLoop:
# Note: num_training_batches is also inf for iterable datasets with no length defined
epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0
is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf")
if on_epoch:
return (
is_val_check_batch and epoch_end_val_check
) or self.trainer.should_stop or is_last_batch_for_infinite_dataset
return is_val_check_batch and epoch_end_val_check
else:
return is_val_check_batch and not epoch_end_val_check
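For readability, the branching above can be restated as a single pure function. This is an illustrative simplification (it glosses over the `limit_train_batches`/`val_check_batch` interplay handled by the TODO block), not the Trainer's actual method:

```python
def should_check_val(
    current_epoch: int,
    check_val_every_n_epoch: int,
    on_epoch: bool,
    is_last_batch: bool,
    is_infinite_dataset: bool,
    should_stop: bool,
    is_val_check_batch: bool,
    epoch_end_val_check: bool,
) -> bool:
    # Epochs that are not multiples of check_val_every_n_epoch never validate.
    if (current_epoch + 1) % check_val_every_n_epoch != 0:
        return False
    # Iterable datasets with no length validate once the last batch is seen.
    if on_epoch and is_last_batch and is_infinite_dataset:
        return True
    # Run one final validation pass before stopping early.
    if on_epoch and should_stop:
        return True
    if on_epoch:
        return is_val_check_batch and epoch_end_val_check
    return is_val_check_batch and not epoch_end_val_check


# Example: a mid-epoch check on a regular dataset, triggered by val_check_batch.
assert should_check_val(0, 1, on_epoch=False, is_last_batch=False, is_infinite_dataset=False,
                        should_stop=False, is_val_check_batch=True, epoch_end_val_check=False)
```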

View File

@@ -157,7 +157,7 @@ def test_early_stopping_patience_train(
"""Test to ensure that early stopping is not triggered before patience is exhausted."""
class ModelOverrideTrainReturn(BoringModel):
train_return_values = torch.Tensor(loss_values)
train_return_values = torch.tensor(loss_values)
def training_epoch_end(self, outputs):
loss = self.train_return_values[self.current_epoch]
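The `torch.Tensor(...)` → `torch.tensor(...)` change above is the usual fix: `torch.tensor` infers the dtype from its data, while `torch.Tensor` always builds a float32 tensor (and with a bare integer would even allocate uninitialized storage). A quick illustration:

```python
import torch

print(torch.tensor([0.9, 0.8, 0.7]).dtype)  # torch.float32, inferred from the data
print(torch.tensor([3, 2, 1]).dtype)        # torch.int64, dtype follows the inputs

print(torch.Tensor([3, 2, 1]).dtype)        # torch.float32, integers silently cast
print(torch.Tensor(3).shape)                # torch.Size([3]), uninitialized values!
```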

View File

@@ -264,67 +264,42 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir):
@mock.patch('pytorch_lightning.loggers.TensorBoardLogger.log_metrics')
@pytest.mark.parametrize('expected', [
([5, 11, 17]),
])
def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir):
"""
Tests to ensure that tensorboard log properly when accumulated_gradients > 1
"""
def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir):
"""Tests to ensure that tensorboard log properly when accumulated_gradients > 1"""
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self._count = 0
self._indexes = []
def training_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.log('count', self._count, on_step=True, on_epoch=True)
self.log('loss', loss, on_step=True, on_epoch=True)
self.indexes = []
def training_step(self, *args):
self.log('foo', 1, on_step=True, on_epoch=True)
if not self.trainer.train_loop.should_accumulate():
if self.trainer.logger_connector.should_update_logs:
self._indexes.append(self.trainer.global_step)
return loss
def validation_step(self, batch, batch_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.log('val_loss', loss, on_step=True, on_epoch=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=.001)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
self.indexes.append(self.trainer.global_step)
return super().training_step(*args)
model = TestModel()
model.training_epoch_end = None
model.validation_epoch_end = None
logger_0 = TensorBoardLogger(tmpdir, default_hp_metric=False)
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=12,
limit_val_batches=0,
max_epochs=3,
gpus=0,
accumulate_grad_batches=2,
logger=[logger_0],
log_every_n_steps=3,
)
trainer.fit(model)
mock_count_epochs = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_epoch" in m[2]["metrics"]]
assert mock_count_epochs == expected
calls = [m[2] for m in mock_log_metrics.mock_calls]
count_epochs = [c["step"] for c in calls if "foo_epoch" in c["metrics"]]
assert count_epochs == [5, 11, 17]
mock_count_steps = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_step" in m[2]["metrics"]]
assert model._indexes == mock_count_steps
count_steps = [c["step"] for c in calls if "foo_step" in c["metrics"]]
assert count_steps == model.indexes
@mock.patch('pytorch_lightning.loggers.tensorboard.SummaryWriter')
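A quick sanity check of the expected `[5, 11, 17]`: with `limit_train_batches=12` and `accumulate_grad_batches=2` there are 6 optimizer steps per epoch, and the epoch-level aggregate is written against the index of the last completed optimizer step. The arithmetic below is only a sketch (the variable names mirror the Trainer arguments used in the test):

```python
limit_train_batches = 12
accumulate_grad_batches = 2
max_epochs = 3

steps_per_epoch = limit_train_batches // accumulate_grad_batches  # 6 optimizer steps per epoch
# global_step has already advanced past the last optimizer step by the time the
# epoch aggregates are logged, so they land on global_step - 1.
expected = [steps_per_epoch * (epoch + 1) - 1 for epoch in range(max_epochs)]
assert expected == [5, 11, 17]
```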

View File

@@ -234,11 +234,55 @@ class HookedModel(BoringModel):
def __init__(self):
super().__init__()
self.called = []
self.train_batch = [
'on_train_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'training_step',
'on_before_zero_grad',
'optimizer_zero_grad',
'backward',
'on_after_backward',
'optimizer_step',
'on_train_batch_end',
]
self.val_batch = [
'on_validation_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'on_validation_batch_end',
]
def training_step(self, *args, **kwargs):
self.called.append("training_step")
return super().training_step(*args, **kwargs)
def optimizer_zero_grad(self, *args, **kwargs):
self.called.append("optimizer_zero_grad")
super().optimizer_zero_grad(*args, **kwargs)
def training_epoch_end(self, *args, **kwargs):
self.called.append("training_epoch_end")
super().training_epoch_end(*args, **kwargs)
def backward(self, *args, **kwargs):
self.called.append("backward")
super().backward(*args, **kwargs)
def on_after_backward(self):
self.called.append("on_after_backward")
super().on_after_backward()
def optimizer_step(self, *args, **kwargs):
super().optimizer_step(*args, **kwargs)
self.called.append("optimizer_step") # append after as closure calls other methods
def validation_epoch_end(self, *args, **kwargs):
self.called.append("validation_epoch_end")
super().validation_epoch_end(*args, **kwargs)
def on_before_zero_grad(self, *args, **kwargs):
self.called.append("on_before_zero_grad")
super().on_before_zero_grad(*args, **kwargs)
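The `train_batch` and `val_batch` lists above capture the hook sequence of a single batch; the expected call lists further down simply repeat them with list multiplication and unpacking. A toy example of the pattern (with made-up hook names):

```python
train_batch = ['on_train_batch_start', 'training_step', 'on_train_batch_end']
train_batches = 2

expected = [
    'on_train_epoch_start',
    *(train_batch * train_batches),  # per-batch hooks, repeated once per batch
    'training_epoch_end',
]
assert expected.count('training_step') == train_batches
assert len(expected) == 2 + len(train_batch) * train_batches
```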
@@ -394,12 +438,13 @@ class HookedModel(BoringModel):
def test_trainer_model_hook_system_fit(tmpdir):
model = HookedModel()
train_batches = 2
val_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=1,
limit_train_batches=2,
limit_test_batches=1,
limit_train_batches=train_batches,
limit_val_batches=val_batches,
progress_bar_refresh_rate=0,
weights_summary=None,
)
@@ -414,11 +459,8 @@ def test_trainer_model_hook_system_fit(tmpdir):
'on_validation_start',
'on_epoch_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'on_validation_batch_end',
*(model.val_batch * val_batches),
'validation_epoch_end',
'on_validation_epoch_end',
'on_epoch_end',
'on_validation_end',
@@ -426,31 +468,16 @@ def test_trainer_model_hook_system_fit(tmpdir):
'on_train_start',
'on_epoch_start',
'on_train_epoch_start',
'on_train_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'on_before_zero_grad',
'on_after_backward',
'on_train_batch_end',
'on_train_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'on_before_zero_grad',
'on_after_backward',
'on_train_batch_end',
*(model.train_batch * train_batches),
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_validation_model_eval',
'on_validation_start',
'on_epoch_start',
'on_validation_epoch_start',
'on_validation_batch_start',
'on_before_batch_transfer',
'transfer_batch_to_device',
'on_after_batch_transfer',
'on_validation_batch_end',
*(model.val_batch * val_batches),
'validation_epoch_end',
'on_validation_epoch_end',
'on_epoch_end',
'on_save_checkpoint',
@@ -463,14 +490,45 @@ def test_trainer_model_hook_system_fit(tmpdir):
assert model.called == expected
def test_trainer_model_hook_system_fit_no_val(tmpdir):
model = HookedModel()
train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=0,
limit_train_batches=train_batches,
progress_bar_refresh_rate=0,
weights_summary=None,
)
assert model.called == []
trainer.fit(model)
expected = [
'setup_fit',
'on_fit_start',
'on_pretrain_routine_start',
'on_pretrain_routine_end',
'on_train_start',
'on_epoch_start',
'on_train_epoch_start',
*(model.train_batch * train_batches),
'training_epoch_end',
'on_train_epoch_end',
'on_epoch_end',
'on_save_checkpoint', # from train epoch end
'on_train_end',
'on_fit_end',
'teardown_fit',
]
assert model.called == expected
def test_trainer_model_hook_system_validate(tmpdir):
model = HookedModel()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=1,
limit_train_batches=2,
limit_test_batches=1,
progress_bar_refresh_rate=0,
weights_summary=None,
)
@@ -487,6 +545,7 @@ def test_trainer_model_hook_system_validate(tmpdir):
'transfer_batch_to_device',
'on_after_batch_transfer',
'on_validation_batch_end',
'validation_epoch_end',
'on_validation_epoch_end',
'on_epoch_end',
'on_validation_end',
@@ -501,8 +560,6 @@ def test_trainer_model_hook_system_test(tmpdir):
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=1,
limit_train_batches=2,
limit_test_batches=1,
progress_bar_refresh_rate=0,
weights_summary=None,
@@ -644,7 +701,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
reload_dataloaders_every_epoch=True,
)
trainer.fit(model, datamodule=dm)
expected = [
'prepare_data',
'setup_fit',
@@ -669,7 +725,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
dm = HookedDataModule()
trainer.validate(model, datamodule=dm, verbose=False)
expected = [
'prepare_data',
'setup_validate',
@@ -683,7 +738,6 @@ def test_trainer_datamodule_hook_system(tmpdir):
dm = HookedDataModule()
trainer.test(model, datamodule=dm, verbose=False)
expected = [
'prepare_data',
'setup_test',

View File

@@ -18,118 +18,6 @@ from pytorch_lightning import seed_everything, Trainer
from tests.helpers import BoringModel
def test_training_loop_hook_call_order(tmpdir):
"""Tests that hooks / methods called in the training loop are in the correct order as detailed in the docs:
https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks"""
class HookedModel(BoringModel):
def __init__(self):
super().__init__()
self.called = []
def on_epoch_start(self):
self.called.append("on_epoch_start")
super().on_epoch_start()
def on_train_epoch_start(self):
self.called.append("on_train_epoch_start")
super().on_train_epoch_start()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
self.called.append("on_train_batch_start")
super().on_train_batch_start(batch, batch_idx, dataloader_idx)
def training_step(self, batch, batch_idx):
self.called.append("training_step")
return super().training_step(batch, batch_idx)
def on_before_zero_grad(self, optimizer):
self.called.append("on_before_zero_grad")
super().on_before_zero_grad(optimizer)
def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
self.called.append("optimizer_zero_grad")
super().optimizer_zero_grad(epoch, batch_idx, optimizer, optimizer_idx)
def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs):
self.called.append("backward")
super().backward(loss, optimizer, optimizer_idx, *args, **kwargs)
def on_after_backward(self):
self.called.append("on_after_backward")
super().on_after_backward()
def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_idx,
optimizer_closure,
on_tpu,
using_native_amp,
using_lbfgs,
):
super().optimizer_step(
epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs
)
self.called.append("optimizer_step") # append after as closure calls other methods
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
self.called.append("on_train_batch_end")
super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
def training_epoch_end(self, outputs):
self.called.append("training_epoch_end")
super().training_epoch_end(outputs)
def on_train_epoch_end(self, outputs):
self.called.append("on_train_epoch_end")
super().on_train_epoch_end(outputs)
def on_epoch_end(self):
self.called.append("on_epoch_end")
super().on_epoch_end()
model = HookedModel()
# fit model
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
limit_val_batches=1,
limit_train_batches=1,
limit_test_batches=1,
progress_bar_refresh_rate=0,
weights_summary=None,
)
assert model.called == []
trainer.fit(model)
expected = [
"on_epoch_start", # validation
"on_epoch_end",
"on_epoch_start", # training
"on_train_epoch_start",
"on_train_batch_start",
"training_step",
"on_before_zero_grad",
"optimizer_zero_grad",
"backward",
"on_after_backward",
"optimizer_step",
"on_train_batch_end",
"training_epoch_end",
"on_train_epoch_end",
"on_epoch_end",
"on_epoch_start", # validation
"on_epoch_end",
]
assert model.called == expected
def test_outputs_format(tmpdir):
"""Tests that outputs objects passed to model hooks and methods are consistent and in the correct format."""