diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index e7fbdf9b18..2797504288 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -11,30 +11,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import List, Optional +from weakref import proxy +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException class OptimizerConnector: - def __init__(self, trainer): - self.trainer = trainer + def __init__(self, trainer: 'pl.Trainer') -> None: + self.trainer = proxy(trainer) - def on_trainer_init(self): + def on_trainer_init(self) -> None: self.trainer.lr_schedulers = [] self.trainer.optimizers = [] self.trainer.optimizer_frequencies = [] - def update_learning_rates( - self, interval: str, monitor_metrics: Optional[Dict[str, Any]] = None, opt_indices: Optional[List[int]] = None - ): + def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] = None) -> None: """Update learning rates. Args: interval: either 'epoch' or 'step'. - monitor_metrics: dict of possible values to monitor + opt_indices: indices of the optimizers to update. """ if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization: return @@ -55,10 +55,7 @@ class OptimizerConnector: monitor_key, monitor_val = None, None if lr_scheduler['reduce_on_plateau']: monitor_key = lr_scheduler['monitor'] - monitor_val = ( - monitor_metrics.get(monitor_key) if monitor_metrics is not None else - self.trainer.logger_connector.callback_metrics.get(monitor_key) - ) + monitor_val = self.trainer.logger_connector.callback_metrics.get(monitor_key) if monitor_val is None: if lr_scheduler.get('strict', True): avail_metrics = list(self.trainer.logger_connector.callback_metrics.keys()) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 10b727edbc..1906679a2b 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -14,7 +14,7 @@ from collections import OrderedDict from contextlib import contextmanager, suppress -from copy import copy, deepcopy +from copy import copy from functools import partial, update_wrapper from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -478,7 +478,6 @@ class TrainLoop: train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader) dataloader_idx = 0 - batch_idx = None is_last_batch = None @@ -525,8 +524,7 @@ class TrainLoop: self.save_loggers_on_train_batch_end() # update LR schedulers - monitor_metrics = deepcopy(self.trainer.logger_connector.callback_metrics) - self.update_train_loop_lr_schedulers(monitor_metrics=monitor_metrics) + self.update_lr_schedulers('step') self.trainer.checkpoint_connector.has_trained = True self.total_batch_idx += 1 @@ -567,7 +565,7 @@ class TrainLoop: # update epoch level lr_schedulers if no val loop outside train loop is triggered if not should_check_val or should_train_only: - self.trainer.optimizer_connector.update_learning_rates(interval='epoch') + self.update_lr_schedulers('epoch') if should_train_only: self.check_checkpoint_callback(True) @@ -863,17 +861,16 @@ class TrainLoop: # track gradients result.grad_norm_dict = self.track_and_norm_grad(optimizer=optimizer) - def update_train_loop_lr_schedulers(self, monitor_metrics=None): - num_accumulated_batches_reached = self._accumulated_batches_reached() - num_training_batches_reached = self._num_training_batches_reached() - - if num_accumulated_batches_reached or num_training_batches_reached: - # update lr - self.trainer.optimizer_connector.update_learning_rates( - interval="step", - monitor_metrics=monitor_metrics, - opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()], - ) + def update_lr_schedulers(self, interval: str) -> None: + if interval == "step": + finished_accumulation = self._accumulated_batches_reached() + finished_epoch = self._num_training_batches_reached() + if not finished_accumulation and not finished_epoch: + return + self.trainer.optimizer_connector.update_learning_rates( + interval=interval, + opt_indices=[opt_idx for opt_idx, _ in self.get_active_optimizers()], + ) def increment_accumulated_grad_global_step(self): num_accumulated_batches_reached = self._accumulated_batches_reached() @@ -897,15 +894,21 @@ class TrainLoop: def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool, on_epoch: bool = False) -> bool: """ Decide if we should run validation. """ - if not self.trainer.enable_validation: return False - # check if this epoch is eligible to run validation - if (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch != 0: + is_val_check_epoch = (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0 + if not is_val_check_epoch: return False # val_check_batch is inf for iterable datasets with no length defined + is_infinite_dataset = self.trainer.val_check_batch == float('inf') + if on_epoch and is_last_batch and is_infinite_dataset: + return True + + if on_epoch and self.trainer.should_stop: + return True + # TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch is_val_check_batch = False if isinstance(self.trainer.limit_train_batches, int) and self.trainer.val_check_batch == float('inf'): @@ -915,12 +918,9 @@ class TrainLoop: # Note: num_training_batches is also inf for iterable datasets with no length defined epoch_end_val_check = (batch_idx + 1) % self.trainer.num_training_batches == 0 - is_last_batch_for_infinite_dataset = is_last_batch and self.trainer.val_check_batch == float("inf") if on_epoch: - return ( - is_val_check_batch and epoch_end_val_check - ) or self.trainer.should_stop or is_last_batch_for_infinite_dataset + return is_val_check_batch and epoch_end_val_check else: return is_val_check_batch and not epoch_end_val_check diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 8f89bedeb4..b1242de725 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -157,7 +157,7 @@ def test_early_stopping_patience_train( """Test to ensure that early stopping is not triggered before patience is exhausted.""" class ModelOverrideTrainReturn(BoringModel): - train_return_values = torch.Tensor(loss_values) + train_return_values = torch.tensor(loss_values) def training_epoch_end(self, outputs): loss = self.train_return_values[self.current_epoch] diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index f22cdcfe2b..f7fe1c3bfd 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -264,67 +264,42 @@ def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir): @mock.patch('pytorch_lightning.loggers.TensorBoardLogger.log_metrics') -@pytest.mark.parametrize('expected', [ - ([5, 11, 17]), -]) -def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir): - """ - Tests to ensure that tensorboard log properly when accumulated_gradients > 1 - """ +def test_tensorboard_with_accummulated_gradients(mock_log_metrics, tmpdir): + """Tests to ensure that tensorboard log properly when accumulated_gradients > 1""" class TestModel(BoringModel): def __init__(self): super().__init__() - self._count = 0 - self._indexes = [] - - def training_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log('count', self._count, on_step=True, on_epoch=True) - self.log('loss', loss, on_step=True, on_epoch=True) + self.indexes = [] + def training_step(self, *args): + self.log('foo', 1, on_step=True, on_epoch=True) if not self.trainer.train_loop.should_accumulate(): if self.trainer.logger_connector.should_update_logs: - self._indexes.append(self.trainer.global_step) - - return loss - - def validation_step(self, batch, batch_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - self.log('val_loss', loss, on_step=True, on_epoch=True) - return loss - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=.001) - lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) - return [optimizer], [lr_scheduler] + self.indexes.append(self.trainer.global_step) + return super().training_step(*args) model = TestModel() model.training_epoch_end = None - model.validation_epoch_end = None - logger_0 = TensorBoardLogger(tmpdir, default_hp_metric=False) - trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=12, limit_val_batches=0, max_epochs=3, - gpus=0, accumulate_grad_batches=2, logger=[logger_0], log_every_n_steps=3, ) trainer.fit(model) - mock_count_epochs = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_epoch" in m[2]["metrics"]] - assert mock_count_epochs == expected + calls = [m[2] for m in mock_log_metrics.mock_calls] + count_epochs = [c["step"] for c in calls if "foo_epoch" in c["metrics"]] + assert count_epochs == [5, 11, 17] - mock_count_steps = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_step" in m[2]["metrics"]] - assert model._indexes == mock_count_steps + count_steps = [c["step"] for c in calls if "foo_step" in c["metrics"]] + assert count_steps == model.indexes @mock.patch('pytorch_lightning.loggers.tensorboard.SummaryWriter') diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index 5fb9f7243b..678f34d298 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -234,11 +234,55 @@ class HookedModel(BoringModel): def __init__(self): super().__init__() self.called = [] + self.train_batch = [ + 'on_train_batch_start', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'training_step', + 'on_before_zero_grad', + 'optimizer_zero_grad', + 'backward', + 'on_after_backward', + 'optimizer_step', + 'on_train_batch_end', + ] + self.val_batch = [ + 'on_validation_batch_start', + 'on_before_batch_transfer', + 'transfer_batch_to_device', + 'on_after_batch_transfer', + 'on_validation_batch_end', + ] + + def training_step(self, *args, **kwargs): + self.called.append("training_step") + return super().training_step(*args, **kwargs) + + def optimizer_zero_grad(self, *args, **kwargs): + self.called.append("optimizer_zero_grad") + super().optimizer_zero_grad(*args, **kwargs) + + def training_epoch_end(self, *args, **kwargs): + self.called.append("training_epoch_end") + super().training_epoch_end(*args, **kwargs) + + def backward(self, *args, **kwargs): + self.called.append("backward") + super().backward(*args, **kwargs) def on_after_backward(self): self.called.append("on_after_backward") super().on_after_backward() + def optimizer_step(self, *args, **kwargs): + super().optimizer_step(*args, **kwargs) + self.called.append("optimizer_step") # append after as closure calls other methods + + def validation_epoch_end(self, *args, **kwargs): + self.called.append("validation_epoch_end") + super().validation_epoch_end(*args, **kwargs) + def on_before_zero_grad(self, *args, **kwargs): self.called.append("on_before_zero_grad") super().on_before_zero_grad(*args, **kwargs) @@ -394,12 +438,13 @@ class HookedModel(BoringModel): def test_trainer_model_hook_system_fit(tmpdir): model = HookedModel() + train_batches = 2 + val_batches = 2 trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - limit_val_batches=1, - limit_train_batches=2, - limit_test_batches=1, + limit_train_batches=train_batches, + limit_val_batches=val_batches, progress_bar_refresh_rate=0, weights_summary=None, ) @@ -414,11 +459,8 @@ def test_trainer_model_hook_system_fit(tmpdir): 'on_validation_start', 'on_epoch_start', 'on_validation_epoch_start', - 'on_validation_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'on_validation_batch_end', + *(model.val_batch * val_batches), + 'validation_epoch_end', 'on_validation_epoch_end', 'on_epoch_end', 'on_validation_end', @@ -426,31 +468,16 @@ def test_trainer_model_hook_system_fit(tmpdir): 'on_train_start', 'on_epoch_start', 'on_train_epoch_start', - 'on_train_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'on_before_zero_grad', - 'on_after_backward', - 'on_train_batch_end', - 'on_train_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'on_before_zero_grad', - 'on_after_backward', - 'on_train_batch_end', + *(model.train_batch * train_batches), + 'training_epoch_end', 'on_train_epoch_end', 'on_epoch_end', 'on_validation_model_eval', 'on_validation_start', 'on_epoch_start', 'on_validation_epoch_start', - 'on_validation_batch_start', - 'on_before_batch_transfer', - 'transfer_batch_to_device', - 'on_after_batch_transfer', - 'on_validation_batch_end', + *(model.val_batch * val_batches), + 'validation_epoch_end', 'on_validation_epoch_end', 'on_epoch_end', 'on_save_checkpoint', @@ -463,14 +490,45 @@ def test_trainer_model_hook_system_fit(tmpdir): assert model.called == expected +def test_trainer_model_hook_system_fit_no_val(tmpdir): + model = HookedModel() + train_batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_val_batches=0, + limit_train_batches=train_batches, + progress_bar_refresh_rate=0, + weights_summary=None, + ) + assert model.called == [] + trainer.fit(model) + expected = [ + 'setup_fit', + 'on_fit_start', + 'on_pretrain_routine_start', + 'on_pretrain_routine_end', + 'on_train_start', + 'on_epoch_start', + 'on_train_epoch_start', + *(model.train_batch * train_batches), + 'training_epoch_end', + 'on_train_epoch_end', + 'on_epoch_end', + 'on_save_checkpoint', # from train epoch end + 'on_train_end', + 'on_fit_end', + 'teardown_fit', + ] + assert model.called == expected + + def test_trainer_model_hook_system_validate(tmpdir): model = HookedModel() trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, limit_val_batches=1, - limit_train_batches=2, - limit_test_batches=1, progress_bar_refresh_rate=0, weights_summary=None, ) @@ -487,6 +545,7 @@ def test_trainer_model_hook_system_validate(tmpdir): 'transfer_batch_to_device', 'on_after_batch_transfer', 'on_validation_batch_end', + 'validation_epoch_end', 'on_validation_epoch_end', 'on_epoch_end', 'on_validation_end', @@ -501,8 +560,6 @@ def test_trainer_model_hook_system_test(tmpdir): trainer = Trainer( default_root_dir=tmpdir, max_epochs=1, - limit_val_batches=1, - limit_train_batches=2, limit_test_batches=1, progress_bar_refresh_rate=0, weights_summary=None, @@ -644,7 +701,6 @@ def test_trainer_datamodule_hook_system(tmpdir): reload_dataloaders_every_epoch=True, ) trainer.fit(model, datamodule=dm) - expected = [ 'prepare_data', 'setup_fit', @@ -669,7 +725,6 @@ def test_trainer_datamodule_hook_system(tmpdir): dm = HookedDataModule() trainer.validate(model, datamodule=dm, verbose=False) - expected = [ 'prepare_data', 'setup_validate', @@ -683,7 +738,6 @@ def test_trainer_datamodule_hook_system(tmpdir): dm = HookedDataModule() trainer.test(model, datamodule=dm, verbose=False) - expected = [ 'prepare_data', 'setup_test', diff --git a/tests/trainer/loops/test_training_loop.py b/tests/trainer/loops/test_training_loop.py index db87a0baab..b89909a40f 100644 --- a/tests/trainer/loops/test_training_loop.py +++ b/tests/trainer/loops/test_training_loop.py @@ -18,118 +18,6 @@ from pytorch_lightning import seed_everything, Trainer from tests.helpers import BoringModel -def test_training_loop_hook_call_order(tmpdir): - """Tests that hooks / methods called in the training loop are in the correct order as detailed in the docs: - https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#hooks""" - - class HookedModel(BoringModel): - - def __init__(self): - super().__init__() - self.called = [] - - def on_epoch_start(self): - self.called.append("on_epoch_start") - super().on_epoch_start() - - def on_train_epoch_start(self): - self.called.append("on_train_epoch_start") - super().on_train_epoch_start() - - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): - self.called.append("on_train_batch_start") - super().on_train_batch_start(batch, batch_idx, dataloader_idx) - - def training_step(self, batch, batch_idx): - self.called.append("training_step") - return super().training_step(batch, batch_idx) - - def on_before_zero_grad(self, optimizer): - self.called.append("on_before_zero_grad") - super().on_before_zero_grad(optimizer) - - def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx): - self.called.append("optimizer_zero_grad") - super().optimizer_zero_grad(epoch, batch_idx, optimizer, optimizer_idx) - - def backward(self, loss, optimizer, optimizer_idx, *args, **kwargs): - self.called.append("backward") - super().backward(loss, optimizer, optimizer_idx, *args, **kwargs) - - def on_after_backward(self): - self.called.append("on_after_backward") - super().on_after_backward() - - def optimizer_step( - self, - epoch, - batch_idx, - optimizer, - optimizer_idx, - optimizer_closure, - on_tpu, - using_native_amp, - using_lbfgs, - ): - super().optimizer_step( - epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs - ) - self.called.append("optimizer_step") # append after as closure calls other methods - - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): - self.called.append("on_train_batch_end") - super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) - - def training_epoch_end(self, outputs): - self.called.append("training_epoch_end") - super().training_epoch_end(outputs) - - def on_train_epoch_end(self, outputs): - self.called.append("on_train_epoch_end") - super().on_train_epoch_end(outputs) - - def on_epoch_end(self): - self.called.append("on_epoch_end") - super().on_epoch_end() - - model = HookedModel() - - # fit model - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=1, - limit_val_batches=1, - limit_train_batches=1, - limit_test_batches=1, - progress_bar_refresh_rate=0, - weights_summary=None, - ) - - assert model.called == [] - - trainer.fit(model) - expected = [ - "on_epoch_start", # validation - "on_epoch_end", - "on_epoch_start", # training - "on_train_epoch_start", - "on_train_batch_start", - "training_step", - "on_before_zero_grad", - "optimizer_zero_grad", - "backward", - "on_after_backward", - "optimizer_step", - "on_train_batch_end", - "training_epoch_end", - "on_train_epoch_end", - "on_epoch_end", - "on_epoch_start", # validation - "on_epoch_end", - ] - assert model.called == expected - - def test_outputs_format(tmpdir): """Tests that outputs objects passed to model hooks and methods are consistent and in the correct format."""