From 4decbc0d95f9d1f5c3f437fbafb167d6ec3beb40 Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Thu, 7 Oct 2021 15:48:11 +0530 Subject: [PATCH] Deprecate `dataloader_idx` from `on_train_batch_start/end` (#9816) * deprecate hooks * dep todo * explicit * Apply suggestions from code review * Apply suggestions from code review * code review * base --- pytorch_lightning/accelerators/accelerator.py | 5 ++- pytorch_lightning/callbacks/base.py | 9 ++++- .../callbacks/gpu_stats_monitor.py | 3 +- .../callbacks/model_checkpoint.py | 5 +-- pytorch_lightning/callbacks/progress/base.py | 6 +-- .../callbacks/progress/rich_progress.py | 4 +- .../callbacks/progress/tqdm_progress.py | 4 +- pytorch_lightning/core/hooks.py | 8 ++-- .../loops/batch/training_batch_loop.py | 10 ++++- .../loops/epoch/training_epoch_loop.py | 11 ++++- .../plugins/training_type/ipu.py | 2 +- .../training_type/training_type_plugin.py | 2 +- pytorch_lightning/trainer/callback_hook.py | 17 ++++++-- .../trainer/configuration_validator.py | 17 ++++++++ pytorch_lightning/tuner/lr_finder.py | 2 +- tests/accelerators/test_tpu_backend.py | 4 +- tests/callbacks/test_callback_hook_outputs.py | 4 +- tests/callbacks/test_progress_bar.py | 8 ++-- tests/core/test_lightning_optimizer.py | 4 +- tests/deprecated_api/test_remove_1-7.py | 40 +++++++++++++++++++ tests/loggers/test_all.py | 2 +- tests/loops/test_training_loop.py | 4 +- tests/models/test_hooks.py | 20 +++++----- tests/plugins/test_deepspeed_plugin.py | 6 +-- .../logging_/test_train_loop_logging.py | 2 +- .../optimization/test_manual_optimization.py | 2 +- tests/trainer/optimization/test_optimizers.py | 2 +- tests/trainer/test_dataloaders.py | 4 +- tests/trainer/test_trainer.py | 4 +- tests/utilities/test_dtype_device_mixin.py | 2 +- tests/utilities/test_fetching.py | 4 +- 31 files changed, 150 insertions(+), 67 deletions(-) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py index d858153a38..cfed45e1db 100644 --- a/pytorch_lightning/accelerators/accelerator.py +++ b/pytorch_lightning/accelerators/accelerator.py @@ -487,6 +487,7 @@ class Accelerator: """Called when train ends.""" return self.training_type_plugin.on_train_end() - def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + # TODO: Update this in v1.7 (deprecation: #9816) + def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Called in the training loop before anything happens for that batch.""" - return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx) + return self.training_type_plugin.on_train_batch_start(batch, batch_idx) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 97cf4a5ddb..c041cdbc5b 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -97,7 +97,12 @@ class Callback(abc.ABC): pass def on_train_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + batch: Any, + batch_idx: int, + unused: Optional[int] = 0, ) -> None: """Called when the train batch begins.""" pass @@ -109,7 +114,7 @@ class Callback(abc.ABC): outputs: STEP_OUTPUT, batch: Any, batch_idx: int, - dataloader_idx: int, + unused: Optional[int] = 0, ) -> None: """Called when the train batch ends.""" pass diff --git a/pytorch_lightning/callbacks/gpu_stats_monitor.py b/pytorch_lightning/callbacks/gpu_stats_monitor.py index e09af1ea57..8e9e671949 100644 --- a/pytorch_lightning/callbacks/gpu_stats_monitor.py +++ b/pytorch_lightning/callbacks/gpu_stats_monitor.py @@ -135,7 +135,7 @@ class GPUStatsMonitor(Callback): @rank_zero_only def on_train_batch_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int ) -> None: if self._log_stats.intra_step_time: self._snap_intra_step_time = time.time() @@ -161,7 +161,6 @@ class GPUStatsMonitor(Callback): outputs: STEP_OUTPUT, batch: Any, batch_idx: int, - dataloader_idx: int, ) -> None: if self._log_stats.inter_step_time: self._snap_inter_step_time = time.time() diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index cff2116446..a56a4e7eea 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -279,7 +279,6 @@ class ModelCheckpoint(Callback): outputs: STEP_OUTPUT, batch: Any, batch_idx: int, - dataloader_idx: int, ) -> None: """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`""" if self._should_skip_saving_checkpoint(trainer): @@ -304,9 +303,7 @@ class ModelCheckpoint(Callback): self.save_checkpoint(trainer) - def on_train_epoch_end( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None - ) -> None: + def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: """Save a checkpoint at the end of the training epoch.""" # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates trainer.fit_loop.global_step -= 1 diff --git a/pytorch_lightning/callbacks/progress/base.py b/pytorch_lightning/callbacks/progress/base.py index 334dd05ab3..07cc3136fc 100644 --- a/pytorch_lightning/callbacks/progress/base.py +++ b/pytorch_lightning/callbacks/progress/base.py @@ -35,8 +35,8 @@ class ProgressBarBase(Callback): def disable(self): self.enable = False - def on_train_batch_end(self, trainer, pl_module, outputs): - super().on_train_batch_end(trainer, pl_module, outputs) # don't forget this :) + def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch_idx) # don't forget this :) percent = (self.train_batch_idx / self.total_train_batches) * 100 sys.stdout.flush() sys.stdout.write(f'{percent:.01f} percent complete \r') @@ -161,7 +161,7 @@ class ProgressBarBase(Callback): def on_train_epoch_start(self, trainer, pl_module): self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self._train_batch_idx += 1 def on_validation_start(self, trainer, pl_module): diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 4da35c7c7a..8507396370 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -369,8 +369,8 @@ class RichProgressBar(ProgressBarBase): super().on_predict_epoch_start(trainer, pl_module) self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) self._update(self.main_progress_bar_id) def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): diff --git a/pytorch_lightning/callbacks/progress/tqdm_progress.py b/pytorch_lightning/callbacks/progress/tqdm_progress.py index 7f69115883..7f3b902925 100644 --- a/pytorch_lightning/callbacks/progress/tqdm_progress.py +++ b/pytorch_lightning/callbacks/progress/tqdm_progress.py @@ -231,8 +231,8 @@ class ProgressBar(ProgressBarBase): reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx) self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) total_batches = self.total_train_batches + self.total_val_batches total_batches = convert_inf(total_batches) if self._should_update(self.train_batch_idx, total_batches): diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index 4f2161fd03..9903e29793 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -79,7 +79,7 @@ class ModelHooks: - training_start """ - def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_start(self, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None: """Called in the training loop before anything happens for that batch. If you return -1 here, you will skip training for the rest of the current epoch. @@ -87,17 +87,17 @@ class ModelHooks: Args: batch: The batched data as it is returned by the training DataLoader. batch_idx: the index of the batch - dataloader_idx: the index of the dataloader + unused: Deprecated argument. Will be removed in v1.7. """ - def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None: """Called in the training loop after the batch. Args: outputs: The outputs of training_step_end(training_step(x)) batch: The batched data as it is returned by the training DataLoader. batch_idx: the index of the batch - dataloader_idx: the index of the dataloader + unused: Deprecated argument. Will be removed in v1.7. """ def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: diff --git a/pytorch_lightning/loops/batch/training_batch_loop.py b/pytorch_lightning/loops/batch/training_batch_loop.py index faf6966ca4..93e156070d 100644 --- a/pytorch_lightning/loops/batch/training_batch_loop.py +++ b/pytorch_lightning/loops/batch/training_batch_loop.py @@ -24,6 +24,7 @@ from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop from pytorch_lightning.loops.utilities import _get_active_optimizers from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import AttributeDict +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.warnings import WarningCache _OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]] @@ -76,7 +77,14 @@ class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]): return AttributeDict(signal=-1) # hook - response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0) + # TODO: Update this in v1.7 (deprecation: #9816) + model_fx = self.trainer.lightning_module.on_train_batch_start + extra_kwargs = ( + {"dataloader_idx": 0} + if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True) + else {} + ) + response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs) if response == -1: return AttributeDict(signal=-1) diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index 188891db8f..fe3a2dc743 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -27,6 +27,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import AbstractDataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature _OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE] @@ -170,7 +171,15 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): automatic=self.trainer.lightning_module.trainer.lightning_module.automatic_optimization, num_optimizers=len(self.trainer.optimizers), ) - self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, self.batch_idx, 0) + + # TODO: Update this in v1.7 (deprecation: #9816) + model_fx = self.trainer.lightning_module.on_train_batch_end + extra_kwargs = ( + {"dataloader_idx": 0} + if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True) + else {} + ) + self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs) self.trainer.call_hook("on_batch_end") self.trainer.logger_connector.on_batch_end() diff --git a/pytorch_lightning/plugins/training_type/ipu.py b/pytorch_lightning/plugins/training_type/ipu.py index 8849c22777..daa704e8a8 100644 --- a/pytorch_lightning/plugins/training_type/ipu.py +++ b/pytorch_lightning/plugins/training_type/ipu.py @@ -285,7 +285,7 @@ class IPUPlugin(ParallelPlugin): def on_predict_end(self): self._detach_models() - def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: # Updates optimizer stats if LR scheduler modified the optimizer state optimizer = self.lightning_module.trainer.optimizers[0] self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index b1de04c9bb..cf36a35027 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -345,7 +345,7 @@ class TrainingTypePlugin(ABC): """Called when predict ends.""" pass - def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: """Called in the training loop before anything happens for that batch.""" pass diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index b8931c4155..d2a4608985 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -21,6 +21,7 @@ from packaging.version import Version import pytorch_lightning as pl from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_warn +from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -161,15 +162,23 @@ class TrainerCallbackHookMixin(ABC): for callback in self.callbacks: callback.on_batch_end(self, self.lightning_module) - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + # TODO: Update this in v1.7 (deprecation: #9816) + def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0): """Called when the training batch begins.""" for callback in self.callbacks: - callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx) + if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True): + callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, 0) + else: + callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx) - def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx): + # TODO: Update this in v1.7 (deprecation: #9816) + def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0): """Called when the training batch ends.""" for callback in self.callbacks: - callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx) + if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True): + callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0) + else: + callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx) def on_validation_batch_start(self, batch, batch_idx, dataloader_idx): """Called when the validation batch begins.""" diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 6a450fb8c0..3da05d69c1 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -50,6 +50,8 @@ class ConfigValidator: self._check_on_post_move_to_device(model) # TODO: Delete _check_on_keyboard_interrupt in v1.7 self._check_on_keyboard_interrupt() + # TODO: Remove this in v1.7 (deprecation: #9816) + self._check_dl_idx_in_on_train_batch_hooks(model) def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None: # ----------------------------------- @@ -261,3 +263,18 @@ class ConfigValidator: "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7." " Please use the `on_exception` callback hook instead." ) + + def _check_dl_idx_in_on_train_batch_hooks(self, model: "pl.LightningModule") -> None: + for hook in ("on_train_batch_start", "on_train_batch_end"): + if is_param_in_hook_signature(getattr(model, hook), "dataloader_idx", explicit=True): + rank_zero_deprecation( + f"Base `LightningModule.{hook}` hook signature has changed in v1.5." + " The `dataloader_idx` argument will be removed in v1.7." + ) + + for cb in self.trainer.callbacks: + if is_param_in_hook_signature(getattr(cb, hook), "dataloader_idx", explicit=True): + rank_zero_deprecation( + f"Base `Callback.{hook}` hook signature has changed in v1.5." + " The `dataloader_idx` argument will be removed in v1.7." + ) diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py index 93aa253d96..6ba9364e86 100644 --- a/pytorch_lightning/tuner/lr_finder.py +++ b/pytorch_lightning/tuner/lr_finder.py @@ -344,7 +344,7 @@ class _LRCallback(Callback): self.lrs.append(trainer.lr_schedulers[0]["scheduler"].lr[0]) - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): """Called when the training batch ends, logs the calculated loss.""" if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0: return diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py index 297b63f103..d9b9e67aeb 100644 --- a/tests/accelerators/test_tpu_backend.py +++ b/tests/accelerators/test_tpu_backend.py @@ -165,7 +165,7 @@ def test_manual_optimization_tpus(tmpdir): def should_update(self): return self.count % 2 == 0 - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, batch, batch_idx): self.called["on_train_batch_start"] += 1 self.weight_before = self.layer.weight.clone() @@ -181,7 +181,7 @@ def test_manual_optimization_tpus(tmpdir): opt.zero_grad() return loss - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, outputs, batch, batch_idx): self.called["on_train_batch_end"] += 1 after_before = self.layer.weight.clone() if self.should_update: diff --git a/tests/callbacks/test_callback_hook_outputs.py b/tests/callbacks/test_callback_hook_outputs.py index 45a0c364e1..7e52c4f49e 100644 --- a/tests/callbacks/test_callback_hook_outputs.py +++ b/tests/callbacks/test_callback_hook_outputs.py @@ -22,7 +22,7 @@ def test_train_step_no_return(tmpdir, single_cb: bool): """Tests that only training_step can be used.""" class CB(Callback): - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): assert "loss" in outputs def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): @@ -32,7 +32,7 @@ def test_train_step_no_return(tmpdir, single_cb: bool): assert "x" in outputs class TestModel(BoringModel): - def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None: assert "loss" in outputs def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None: diff --git a/tests/callbacks/test_progress_bar.py b/tests/callbacks/test_progress_bar.py index fe40150f92..746a9717db 100644 --- a/tests/callbacks/test_progress_bar.py +++ b/tests/callbacks/test_progress_bar.py @@ -185,12 +185,12 @@ def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int): val_batches_seen = 0 test_batches_seen = 0 - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): - super().on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx) + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): + super().on_train_batch_start(trainer, pl_module, batch, batch_idx) assert self.train_batch_idx == trainer.fit_loop.batch_idx - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): + super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) assert self.train_batch_idx == trainer.fit_loop.batch_idx + 1 if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0: assert self.main_progress_bar.n == self.train_batch_idx diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index a41adb1ad1..3f9ac37c2c 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -331,12 +331,12 @@ def test_lightning_optimizer_keeps_hooks(tmpdir): def configure_optimizers(self): return OptimizerWithHooks(self) - def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_start(self, batch: Any, batch_idx: int) -> None: self.count_on_train_batch_start += 1 optimizer = self.optimizers(use_pl_optimizer=False) assert len(optimizer._fwd_handles) == 1 - def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None: self.count_on_train_batch_end += 1 del self.trainer._lightning_optimizers gc.collect() # not necessary, just in case diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 6a982af37b..30cb0269d7 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -257,6 +257,46 @@ def test_v1_7_0_deprecate_lightning_distributed(tmpdir): _ = LightningDistributed() +def test_v1_7_0_old_on_train_batch_start(tmpdir): + class OldSignature(Callback): + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + ... + + class OldSignatureModel(BoringModel): + def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + ... + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature()) + with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."): + trainer.fit(model) + + model = OldSignatureModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."): + trainer.fit(model) + + +def test_v1_7_0_old_on_train_batch_end(tmpdir): + class OldSignature(Callback): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + ... + + class OldSignatureModel(BoringModel): + def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + ... + + model = BoringModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature(), fast_dev_run=True) + with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."): + trainer.fit(model) + + model = OldSignatureModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature(), fast_dev_run=True) + with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."): + trainer.fit(model) + + def test_v1_7_0_deprecate_on_post_move_to_device(tmpdir): class TestModel(BoringModel): def on_post_move_to_device(self): diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 885e6625d8..6cc8daa75f 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -308,7 +308,7 @@ class RankZeroLoggerCheck(Callback): # this class has to be defined outside the test function, otherwise we get pickle error # due to the way ddp process is launched - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): is_dummy = isinstance(trainer.logger.experiment, DummyExperiment) if trainer.is_global_zero: assert not is_dummy diff --git a/tests/loops/test_training_loop.py b/tests/loops/test_training_loop.py index 5e0d00d3df..a6bc414e0e 100644 --- a/tests/loops/test_training_loop.py +++ b/tests/loops/test_training_loop.py @@ -35,9 +35,9 @@ def test_outputs_format(tmpdir): assert "foo" in output assert output["foo"] == 123 - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, outputs, batch, batch_idx): HookedModel._check_output(outputs) - super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx) + super().on_train_batch_end(outputs, batch, batch_idx) def training_epoch_end(self, outputs): assert len(outputs) == 2 diff --git a/tests/models/test_hooks.py b/tests/models/test_hooks.py index de849f90a0..9af883f73c 100644 --- a/tests/models/test_hooks.py +++ b/tests/models/test_hooks.py @@ -91,14 +91,14 @@ def test_training_epoch_end_metrics_collection_on_override(tmpdir): def training_epoch_end(self, outputs): self.len_outputs = len(outputs) - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, outputs, batch, batch_idx): self.num_train_batches += 1 class NotOverriddenModel(BoringModel): def on_train_epoch_start(self): self.num_train_batches = 0 - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, outputs, batch, batch_idx): self.num_train_batches += 1 overridden_model = OverriddenModel() @@ -289,8 +289,8 @@ class HookedModel(BoringModel): dict(name="on_after_batch_transfer", args=(ANY, 0)), # TODO: `on_batch_{start,end}` dict(name="Callback.on_batch_start", args=(trainer, model)), - dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i, 0)), - dict(name="on_train_batch_start", args=(ANY, i, 0)), + dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)), + dict(name="on_train_batch_start", args=(ANY, i)), # without a precision plugin, we execute the closure inside the `optimizer.step` *([] if using_plugin else on_before_optimizer_step), dict(name="forward", args=(ANY,)), @@ -311,8 +311,8 @@ class HookedModel(BoringModel): args=(current_epoch, i, ANY, 0, ANY), kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp), ), - dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i, 0)), - dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i, 0)), + dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)), + dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i)), dict(name="Callback.on_batch_end", args=(trainer, model)), ] ) @@ -331,8 +331,8 @@ class HookedModel(BoringModel): dict(name="on_after_batch_transfer", args=(ANY, 0)), # TODO: `on_batch_{start,end}` dict(name="Callback.on_batch_start", args=(trainer, model)), - dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i, 0)), - dict(name="on_train_batch_start", args=(ANY, i, 0)), + dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)), + dict(name="on_train_batch_start", args=(ANY, i)), dict(name="forward", args=(ANY,)), dict(name="Callback.on_before_backward", args=(trainer, model, ANY)), dict(name="on_before_backward", args=(ANY,)), @@ -349,8 +349,8 @@ class HookedModel(BoringModel): *([] if using_plugin else [dict(name="closure")]), dict(name="training_step", args=(ANY, i)), dict(name="training_step_end", args=(dict(loss=ANY),)), - dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i, 0)), - dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i, 0)), + dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)), + dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i)), dict(name="Callback.on_batch_end", args=(trainer, model)), ] ) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index 9afcbcd517..ca02e9b9b0 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -649,7 +649,7 @@ def test_deepspeed_multigpu_stage_3_resume_training(tmpdir): class TestCallback(Callback): def on_train_batch_start( - self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int + self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int ) -> None: original_deepspeed_plugin = initial_trainer.accelerator.training_type_plugin current_deepspeed_plugin = trainer.accelerator.training_type_plugin @@ -707,9 +707,7 @@ def _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimiz def __init__(self): self.on_train_batch_start_called = False - def on_train_batch_start( - self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int - ) -> None: + def on_train_batch_start(self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None: deepspeed_engine = trainer.training_type_plugin.model assert trainer.global_step == deepspeed_engine.global_steps self.on_train_batch_start_called = True diff --git a/tests/trainer/logging_/test_train_loop_logging.py b/tests/trainer/logging_/test_train_loop_logging.py index 3c72a78331..83a07ed7ad 100644 --- a/tests/trainer/logging_/test_train_loop_logging.py +++ b/tests/trainer/logging_/test_train_loop_logging.py @@ -501,7 +501,7 @@ def test_logging_in_callbacks_with_log_function(tmpdir): def on_train_epoch_start(self, trainer, pl_module): self.log("on_train_epoch_start", 2) - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): self.log("on_train_batch_end", 3) def on_batch_end(self, trainer, pl_module): diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index c288dee5a5..7fc1fba9c1 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -232,7 +232,7 @@ class ManualOptimizationExtendedModel(BoringModel): def should_update(self): return self.count % 2 == 0 - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, batch, batch_idx): self.called["on_train_batch_start"] += 1 self.weight_before = self.layer.weight.clone() diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py index ba17929693..86499e2d8c 100644 --- a/tests/trainer/optimization/test_optimizers.py +++ b/tests/trainer/optimization/test_optimizers.py @@ -334,7 +334,7 @@ def test_multiple_optimizers_callbacks(tmpdir): """Tests that multiple optimizers can be used with callbacks.""" class CB(Callback): - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): pass def on_train_epoch_start(self, trainer, pl_module): diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 949db371e4..5c9aacc92a 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -220,7 +220,7 @@ class Counter(Callback): self.val_batches_seen = 0 self.test_batches_seen = 0 - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): self.train_batches_seen += 1 def on_train_epoch_start(self, trainer, pl_module): @@ -1482,7 +1482,7 @@ def test_request_dataloader(tmpdir): self.train_dataloader = DataLoaderFunc(DataLoaderWrapper(loader)) self.on_train_dataloader_called = True - def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_start(self, batch, batch_idx: int) -> None: assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper) self.on_train_batch_start_called = True diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index a1e6ce0110..021812464e 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -299,7 +299,7 @@ def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_bat self.start_state_dict = self.state_dict() self.opt_step_called = False - def on_train_batch_end(self, outputs, batch, batch_idx, *_): + def on_train_batch_end(self, outputs, batch, batch_idx): end_state_dict = self.state_dict() is_last_batch = (batch_idx + 1) == self.trainer.num_training_batches @@ -966,7 +966,7 @@ def test_on_exception_hook(tmpdir): def __init__(self): super().__init__() - def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, trainer, pl_module, batch, batch_idx): raise KeyboardInterrupt def on_test_start(self, trainer, pl_module): diff --git a/tests/utilities/test_dtype_device_mixin.py b/tests/utilities/test_dtype_device_mixin.py index b30d013d88..9140b965aa 100644 --- a/tests/utilities/test_dtype_device_mixin.py +++ b/tests/utilities/test_dtype_device_mixin.py @@ -38,7 +38,7 @@ class TopModule(BoringModel): class DeviceAssertCallback(Callback): - def on_train_batch_start(self, trainer, model, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, trainer, model, batch, batch_idx): rank = trainer.local_rank assert isinstance(model, TopModule) # index = None also means first device diff --git a/tests/utilities/test_fetching.py b/tests/utilities/test_fetching.py index fd3df1df0c..88c232b76a 100644 --- a/tests/utilities/test_fetching.py +++ b/tests/utilities/test_fetching.py @@ -356,7 +356,7 @@ def test_on_train_batch_start_overridden(tmpdir) -> None: `LightningModule`.""" class InvalidModel(AsyncBoringModel): - def on_train_batch_start(self, batch, batch_idx, dataloader_idx): + def on_train_batch_start(self, batch, batch_idx): pass trainer = Trainer(max_epochs=1, default_root_dir=tmpdir) @@ -370,7 +370,7 @@ def test_on_train_batch_end_overridden(tmpdir) -> None: `LightningModule`.""" class InvalidModel(AsyncBoringModel): - def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, outputs, batch, batch_idx): pass trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)