Deprecate `dataloader_idx` from `on_train_batch_start/end` (#9816)

* deprecate hooks

* dep todo

* explicit

* Apply suggestions from code review

* Apply suggestions from code review

* code review

* base
Rohit Gupta 2021-10-07 15:48:11 +05:30 committed by GitHub
parent 0561fd6925
commit 4decbc0d95
31 changed files with 150 additions and 67 deletions

@@ -487,6 +487,7 @@ class Accelerator:
        """Called when train ends."""
        return self.training_type_plugin.on_train_end()

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    # TODO: Update this in v1.7 (deprecation: #9816)
+    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Called in the training loop before anything happens for that batch."""
-        return self.training_type_plugin.on_train_batch_start(batch, batch_idx, dataloader_idx)
+        return self.training_type_plugin.on_train_batch_start(batch, batch_idx)

@@ -97,7 +97,12 @@ class Callback(abc.ABC):
        pass

    def on_train_batch_start(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+        self,
+        trainer: "pl.Trainer",
+        pl_module: "pl.LightningModule",
+        batch: Any,
+        batch_idx: int,
+        unused: Optional[int] = 0,
    ) -> None:
        """Called when the train batch begins."""
        pass

@@ -109,7 +114,7 @@ class Callback(abc.ABC):
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
-        dataloader_idx: int,
+        unused: Optional[int] = 0,
    ) -> None:
        """Called when the train batch ends."""
        pass
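For downstream code, the practical upshot of the new `Callback` signature is that `dataloader_idx` simply disappears from the two training-batch hooks. A minimal before/after sketch (the callback classes and print statements below are illustrative, not part of this commit):

from pytorch_lightning import Callback

# Deprecated in v1.5, scheduled for removal in v1.7:
class OldStyleCallback(Callback):
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        print(f"batch {batch_idx} starting")

# New signature: drop `dataloader_idx`, since training uses a single dataloader index here.
class NewStyleCallback(Callback):
    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        print(f"batch {batch_idx} starting")

Both styles keep working during the deprecation window; the trainer decides at call time which form to use, as the `callback_hook.py` and loop changes further down show.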

@@ -135,7 +135,7 @@ class GPUStatsMonitor(Callback):
    @rank_zero_only
    def on_train_batch_start(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int, dataloader_idx: int
+        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
    ) -> None:
        if self._log_stats.intra_step_time:
            self._snap_intra_step_time = time.time()

@@ -161,7 +161,6 @@ class GPUStatsMonitor(Callback):
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
-        dataloader_idx: int,
    ) -> None:
        if self._log_stats.inter_step_time:
            self._snap_inter_step_time = time.time()

@@ -279,7 +279,6 @@ class ModelCheckpoint(Callback):
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
-        dataloader_idx: int,
    ) -> None:
        """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
        if self._should_skip_saving_checkpoint(trainer):

@@ -304,9 +303,7 @@ class ModelCheckpoint(Callback):
        self.save_checkpoint(trainer)

-    def on_train_epoch_end(
-        self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", unused: Optional = None
-    ) -> None:
+    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Save a checkpoint at the end of the training epoch."""
        # as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
        trainer.fit_loop.global_step -= 1

@@ -35,8 +35,8 @@ class ProgressBarBase(Callback):
            def disable(self):
                self.enable = False

-            def on_train_batch_end(self, trainer, pl_module, outputs):
-                super().on_train_batch_end(trainer, pl_module, outputs)  # don't forget this :)
+            def on_train_batch_end(self, trainer, pl_module, outputs, batch_idx):
+                super().on_train_batch_end(trainer, pl_module, outputs, batch_idx)  # don't forget this :)
                percent = (self.train_batch_idx / self.total_train_batches) * 100
                sys.stdout.flush()
                sys.stdout.write(f'{percent:.01f} percent complete \r')

@@ -161,7 +161,7 @@ class ProgressBarBase(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        self._train_batch_idx = trainer.fit_loop.epoch_loop.batch_progress.current.completed

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._train_batch_idx += 1

    def on_validation_start(self, trainer, pl_module):

@@ -369,8 +369,8 @@ class RichProgressBar(ProgressBarBase):
        super().on_predict_epoch_start(trainer, pl_module)
        self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description)

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        self._update(self.main_progress_bar_id)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):

@@ -231,8 +231,8 @@ class ProgressBar(ProgressBarBase):
        reset(self.main_progress_bar, total=total_batches, current=self.train_batch_idx)
        self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
        total_batches = self.total_train_batches + self.total_val_batches
        total_batches = convert_inf(total_batches)
        if self._should_update(self.train_batch_idx, total_batches):

@@ -79,7 +79,7 @@ class ModelHooks:
        - training_start
        """

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None:
        """Called in the training loop before anything happens for that batch.

        If you return -1 here, you will skip training for the rest of the current epoch.

@@ -87,17 +87,17 @@ class ModelHooks:
        Args:
            batch: The batched data as it is returned by the training DataLoader.
            batch_idx: the index of the batch
-            dataloader_idx: the index of the dataloader
+            unused: Deprecated argument. Will be removed in v1.7.
        """

-    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch: Any, batch_idx: int, unused: Optional[int] = 0) -> None:
        """Called in the training loop after the batch.

        Args:
            outputs: The outputs of training_step_end(training_step(x))
            batch: The batched data as it is returned by the training DataLoader.
            batch_idx: the index of the batch
-            dataloader_idx: the index of the dataloader
+            unused: Deprecated argument. Will be removed in v1.7.
        """

    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
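The same migration applies on the LightningModule side; the `unused` default keeps old call sites working during the deprecation window. A minimal sketch of an updated module (illustrative only, method bodies elided):

from pytorch_lightning import LightningModule

class MyModel(LightningModule):
    # old, deprecated: def on_train_batch_start(self, batch, batch_idx, dataloader_idx): ...
    def on_train_batch_start(self, batch, batch_idx):
        # returning -1 here still skips training for the rest of the current epoch
        ...

    # old, deprecated: def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx): ...
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # `outputs` is whatever training_step_end(training_step(x)) returned
        ...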

@@ -24,6 +24,7 @@ from pytorch_lightning.loops.optimization.optimizer_loop import OptimizerLoop
from pytorch_lightning.loops.utilities import _get_active_optimizers
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import AttributeDict
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache

_OUTPUTS_TYPE = List[Union[_OPTIMIZER_LOOP_OUTPUTS_TYPE, _MANUAL_LOOP_OUTPUTS_TYPE]]

@@ -76,7 +77,14 @@ class TrainingBatchLoop(Loop[_OUTPUTS_TYPE]):
            return AttributeDict(signal=-1)

        # hook
-        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, 0)
+        # TODO: Update this in v1.7 (deprecation: #9816)
+        model_fx = self.trainer.lightning_module.on_train_batch_start
+        extra_kwargs = (
+            {"dataloader_idx": 0}
+            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
+            else {}
+        )
+        response = self.trainer.call_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
        if response == -1:
            return AttributeDict(signal=-1)
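The loop only forwards `dataloader_idx` when the user's override still names it explicitly, so old signatures keep receiving a value while new ones are never handed the extra argument. `is_param_in_hook_signature` lives in `pytorch_lightning.utilities.signature_utils`; a rough standalone sketch of the idea (not the library's actual implementation) looks like this:

import inspect
from typing import Callable

def names_param_explicitly(hook: Callable, name: str) -> bool:
    # True only if `name` is a declared parameter, not something swallowed by *args/**kwargs.
    params = inspect.signature(hook).parameters
    return name in params and params[name].kind not in (
        inspect.Parameter.VAR_POSITIONAL,
        inspect.Parameter.VAR_KEYWORD,
    )

def new_hook(batch, batch_idx): ...
def old_hook(batch, batch_idx, dataloader_idx): ...

assert not names_param_explicitly(new_hook, "dataloader_idx")
assert names_param_explicitly(old_hook, "dataloader_idx")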

@@ -27,6 +27,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

@@ -170,7 +171,15 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
                automatic=self.trainer.lightning_module.automatic_optimization,
                num_optimizers=len(self.trainer.optimizers),
            )
-            self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, self.batch_idx, 0)
+            # TODO: Update this in v1.7 (deprecation: #9816)
+            model_fx = self.trainer.lightning_module.on_train_batch_end
+            extra_kwargs = (
+                {"dataloader_idx": 0}
+                if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
+                else {}
+            )
+            self.trainer.call_hook("on_train_batch_end", batch_end_outputs, batch, self.batch_idx, **extra_kwargs)
            self.trainer.call_hook("on_batch_end")
            self.trainer.logger_connector.on_batch_end()

@@ -285,7 +285,7 @@ class IPUPlugin(ParallelPlugin):
    def on_predict_end(self):
        self._detach_models()

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        # Updates optimizer stats if LR scheduler modified the optimizer state
        optimizer = self.lightning_module.trainer.optimizers[0]
        self.poptorch_models[RunningStage.TRAINING].setOptimizer(optimizer)

@@ -345,7 +345,7 @@ class TrainingTypePlugin(ABC):
        """Called when predict ends."""
        pass

-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
        """Called in the training loop before anything happens for that batch."""
        pass

@@ -21,6 +21,7 @@ from packaging.version import Version
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -161,15 +162,23 @@ class TrainerCallbackHookMixin(ABC):
        for callback in self.callbacks:
            callback.on_batch_end(self, self.lightning_module)

-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    # TODO: Update this in v1.7 (deprecation: #9816)
+    def on_train_batch_start(self, batch, batch_idx, dataloader_idx=0):
        """Called when the training batch begins."""
        for callback in self.callbacks:
-            callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)
+            if is_param_in_hook_signature(callback.on_train_batch_start, "dataloader_idx", explicit=True):
+                callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, 0)
+            else:
+                callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx)

-    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx):
+    # TODO: Update this in v1.7 (deprecation: #9816)
+    def on_train_batch_end(self, outputs: STEP_OUTPUT, batch, batch_idx, dataloader_idx=0):
        """Called when the training batch ends."""
        for callback in self.callbacks:
-            callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
+            if is_param_in_hook_signature(callback.on_train_batch_end, "dataloader_idx", explicit=True):
+                callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, 0)
+            else:
+                callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx)

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
        """Called when the validation batch begins."""

@@ -50,6 +50,8 @@ class ConfigValidator:
        self._check_on_post_move_to_device(model)
        # TODO: Delete _check_on_keyboard_interrupt in v1.7
        self._check_on_keyboard_interrupt()
+        # TODO: Remove this in v1.7 (deprecation: #9816)
+        self._check_dl_idx_in_on_train_batch_hooks(model)

    def __verify_train_loop_configuration(self, model: "pl.LightningModule") -> None:
        # -----------------------------------

@@ -261,3 +263,18 @@ class ConfigValidator:
                "The `on_keyboard_interrupt` callback hook was deprecated in v1.5 and will be removed in v1.7."
                " Please use the `on_exception` callback hook instead."
            )
+
+    def _check_dl_idx_in_on_train_batch_hooks(self, model: "pl.LightningModule") -> None:
+        for hook in ("on_train_batch_start", "on_train_batch_end"):
+            if is_param_in_hook_signature(getattr(model, hook), "dataloader_idx", explicit=True):
+                rank_zero_deprecation(
+                    f"Base `LightningModule.{hook}` hook signature has changed in v1.5."
+                    " The `dataloader_idx` argument will be removed in v1.7."
+                )
+
+            for cb in self.trainer.callbacks:
+                if is_param_in_hook_signature(getattr(cb, hook), "dataloader_idx", explicit=True):
+                    rank_zero_deprecation(
+                        f"Base `Callback.{hook}` hook signature has changed in v1.5."
+                        " The `dataloader_idx` argument will be removed in v1.7."
+                    )
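Keeping the old signature does not break anything yet: fitting still runs, and the validator only emits this deprecation warning for each offending hook. A hedged usage sketch, mirroring the deprecation tests further down (assumes the `BoringModel` test helper is importable like this; the warning is Lightning's deprecation warning, which subclasses DeprecationWarning):

import pytest
from pytorch_lightning import Trainer
from tests.helpers import BoringModel  # test-only helper, assumed available on this import path

class StillOldModel(BoringModel):
    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):  # old signature
        ...

def test_old_signature_still_runs_but_warns(tmpdir):
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
        trainer.fit(StillOldModel())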

@@ -344,7 +344,7 @@ class _LRCallback(Callback):
            self.lrs.append(trainer.lr_schedulers[0]["scheduler"].lr[0])

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Called when the training batch ends, logs the calculated loss."""
        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return

@@ -165,7 +165,7 @@ def test_manual_optimization_tpus(tmpdir):
        def should_update(self):
            return self.count % 2 == 0

-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        def on_train_batch_start(self, batch, batch_idx):
            self.called["on_train_batch_start"] += 1
            self.weight_before = self.layer.weight.clone()

@@ -181,7 +181,7 @@ def test_manual_optimization_tpus(tmpdir):
            opt.zero_grad()
            return loss

-        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, outputs, batch, batch_idx):
            self.called["on_train_batch_end"] += 1
            after_before = self.layer.weight.clone()
            if self.should_update:

@@ -22,7 +22,7 @@ def test_train_step_no_return(tmpdir, single_cb: bool):
    """Tests that only training_step can be used."""

    class CB(Callback):
-        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            assert "loss" in outputs

        def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):

@@ -32,7 +32,7 @@ def test_train_step_no_return(tmpdir, single_cb: bool):
            assert "x" in outputs

    class TestModel(BoringModel):
-        def on_train_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:
+        def on_train_batch_end(self, outputs, batch, batch_idx: int) -> None:
            assert "loss" in outputs

        def on_validation_batch_end(self, outputs, batch, batch_idx: int, dataloader_idx: int) -> None:

@@ -185,12 +185,12 @@ def test_progress_bar_progress_refresh(tmpdir, refresh_rate: int):
        val_batches_seen = 0
        test_batches_seen = 0

-        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
-            super().on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
+        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+            super().on_train_batch_start(trainer, pl_module, batch, batch_idx)
            assert self.train_batch_idx == trainer.fit_loop.batch_idx

-        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
-            super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+            super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
            assert self.train_batch_idx == trainer.fit_loop.batch_idx + 1
            if not self.is_disabled and self.train_batch_idx % self.refresh_rate == 0:
                assert self.main_progress_bar.n == self.train_batch_idx

@@ -331,12 +331,12 @@ def test_lightning_optimizer_keeps_hooks(tmpdir):
        def configure_optimizers(self):
            return OptimizerWithHooks(self)

-        def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+        def on_train_batch_start(self, batch: Any, batch_idx: int) -> None:
            self.count_on_train_batch_start += 1
            optimizer = self.optimizers(use_pl_optimizer=False)
            assert len(optimizer._fwd_handles) == 1

-        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+        def on_train_batch_end(self, outputs: Any, batch: Any, batch_idx: int) -> None:
            self.count_on_train_batch_end += 1
            del self.trainer._lightning_optimizers
            gc.collect()  # not necessary, just in case

@@ -257,6 +257,46 @@ def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
        _ = LightningDistributed()

+def test_v1_7_0_old_on_train_batch_start(tmpdir):
+    class OldSignature(Callback):
+        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+            ...
+
+    class OldSignatureModel(BoringModel):
+        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+            ...
+
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature())
+    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
+        trainer.fit(model)
+
+    model = OldSignatureModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1)
+    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
+        trainer.fit(model)
+
+
+def test_v1_7_0_old_on_train_batch_end(tmpdir):
+    class OldSignature(Callback):
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+            ...
+
+    class OldSignatureModel(BoringModel):
+        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+            ...
+
+    model = BoringModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature(), fast_dev_run=True)
+    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
+        trainer.fit(model)
+
+    model = OldSignatureModel()
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature(), fast_dev_run=True)
+    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
+        trainer.fit(model)
+

def test_v1_7_0_deprecate_on_post_move_to_device(tmpdir):
    class TestModel(BoringModel):
        def on_post_move_to_device(self):

@@ -308,7 +308,7 @@ class RankZeroLoggerCheck(Callback):
    # this class has to be defined outside the test function, otherwise we get pickle error
    # due to the way ddp process is launched

-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        is_dummy = isinstance(trainer.logger.experiment, DummyExperiment)
        if trainer.is_global_zero:
            assert not is_dummy

@@ -35,9 +35,9 @@ def test_outputs_format(tmpdir):
            assert "foo" in output
            assert output["foo"] == 123

-        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, outputs, batch, batch_idx):
            HookedModel._check_output(outputs)
-            super().on_train_batch_end(outputs, batch, batch_idx, dataloader_idx)
+            super().on_train_batch_end(outputs, batch, batch_idx)

        def training_epoch_end(self, outputs):
            assert len(outputs) == 2

@@ -91,14 +91,14 @@ def test_training_epoch_end_metrics_collection_on_override(tmpdir):
        def training_epoch_end(self, outputs):
            self.len_outputs = len(outputs)

-        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, outputs, batch, batch_idx):
            self.num_train_batches += 1

    class NotOverriddenModel(BoringModel):
        def on_train_epoch_start(self):
            self.num_train_batches = 0

-        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, outputs, batch, batch_idx):
            self.num_train_batches += 1

    overridden_model = OverriddenModel()

@@ -289,8 +289,8 @@ class HookedModel(BoringModel):
                dict(name="on_after_batch_transfer", args=(ANY, 0)),
                # TODO: `on_batch_{start,end}`
                dict(name="Callback.on_batch_start", args=(trainer, model)),
-                dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i, 0)),
-                dict(name="on_train_batch_start", args=(ANY, i, 0)),
+                dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)),
+                dict(name="on_train_batch_start", args=(ANY, i)),
                # without a precision plugin, we execute the closure inside the `optimizer.step`
                *([] if using_plugin else on_before_optimizer_step),
                dict(name="forward", args=(ANY,)),

@@ -311,8 +311,8 @@ class HookedModel(BoringModel):
                    args=(current_epoch, i, ANY, 0, ANY),
                    kwargs=dict(on_tpu=False, using_lbfgs=False, using_native_amp=using_native_amp),
                ),
-                dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
-                dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i, 0)),
+                dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)),
+                dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i)),
                dict(name="Callback.on_batch_end", args=(trainer, model)),
            ]
        )

@@ -331,8 +331,8 @@ class HookedModel(BoringModel):
                dict(name="on_after_batch_transfer", args=(ANY, 0)),
                # TODO: `on_batch_{start,end}`
                dict(name="Callback.on_batch_start", args=(trainer, model)),
-                dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i, 0)),
-                dict(name="on_train_batch_start", args=(ANY, i, 0)),
+                dict(name="Callback.on_train_batch_start", args=(trainer, model, ANY, i)),
+                dict(name="on_train_batch_start", args=(ANY, i)),
                dict(name="forward", args=(ANY,)),
                dict(name="Callback.on_before_backward", args=(trainer, model, ANY)),
                dict(name="on_before_backward", args=(ANY,)),

@@ -349,8 +349,8 @@ class HookedModel(BoringModel):
                *([] if using_plugin else [dict(name="closure")]),
                dict(name="training_step", args=(ANY, i)),
                dict(name="training_step_end", args=(dict(loss=ANY),)),
-                dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i, 0)),
-                dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i, 0)),
+                dict(name="Callback.on_train_batch_end", args=(trainer, model, dict(loss=ANY), ANY, i)),
+                dict(name="on_train_batch_end", args=(dict(loss=ANY), ANY, i)),
                dict(name="Callback.on_batch_end", args=(trainer, model)),
            ]
        )

@@ -649,7 +649,7 @@ def test_deepspeed_multigpu_stage_3_resume_training(tmpdir):
    class TestCallback(Callback):
        def on_train_batch_start(
-            self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
+            self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
        ) -> None:
            original_deepspeed_plugin = initial_trainer.accelerator.training_type_plugin
            current_deepspeed_plugin = trainer.accelerator.training_type_plugin

@@ -707,9 +707,7 @@ def _deepspeed_multigpu_stage_2_accumulated_grad_batches(tmpdir, offload_optimiz
        def __init__(self):
            self.on_train_batch_start_called = False

-        def on_train_batch_start(
-            self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int
-        ) -> None:
+        def on_train_batch_start(self, trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
            deepspeed_engine = trainer.training_type_plugin.model
            assert trainer.global_step == deepspeed_engine.global_steps
            self.on_train_batch_start_called = True

@@ -501,7 +501,7 @@ def test_logging_in_callbacks_with_log_function(tmpdir):
        def on_train_epoch_start(self, trainer, pl_module):
            self.log("on_train_epoch_start", 2)

-        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            self.log("on_train_batch_end", 3)

        def on_batch_end(self, trainer, pl_module):

@@ -232,7 +232,7 @@ class ManualOptimizationExtendedModel(BoringModel):
    def should_update(self):
        return self.count % 2 == 0

-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, batch, batch_idx):
        self.called["on_train_batch_start"] += 1
        self.weight_before = self.layer.weight.clone()

@@ -334,7 +334,7 @@ def test_multiple_optimizers_callbacks(tmpdir):
    """Tests that multiple optimizers can be used with callbacks."""

    class CB(Callback):
-        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            pass

        def on_train_epoch_start(self, trainer, pl_module):

@@ -220,7 +220,7 @@ class Counter(Callback):
        self.val_batches_seen = 0
        self.test_batches_seen = 0

-    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        self.train_batches_seen += 1

    def on_train_epoch_start(self, trainer, pl_module):

@@ -1482,7 +1482,7 @@ def test_request_dataloader(tmpdir):
            self.train_dataloader = DataLoaderFunc(DataLoaderWrapper(loader))
            self.on_train_dataloader_called = True

-        def on_train_batch_start(self, batch, batch_idx: int, dataloader_idx: int) -> None:
+        def on_train_batch_start(self, batch, batch_idx: int) -> None:
            assert isinstance(self.trainer.train_dataloader.loaders, DataLoaderWrapper)
            self.on_train_batch_start_called = True

@@ -299,7 +299,7 @@ def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_bat
            self.start_state_dict = self.state_dict()
            self.opt_step_called = False

-        def on_train_batch_end(self, outputs, batch, batch_idx, *_):
+        def on_train_batch_end(self, outputs, batch, batch_idx):
            end_state_dict = self.state_dict()
            is_last_batch = (batch_idx + 1) == self.trainer.num_training_batches

@@ -966,7 +966,7 @@ def test_on_exception_hook(tmpdir):
        def __init__(self):
            super().__init__()

-        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
+        def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
            raise KeyboardInterrupt

        def on_test_start(self, trainer, pl_module):

@@ -38,7 +38,7 @@ class TopModule(BoringModel):

class DeviceAssertCallback(Callback):
-    def on_train_batch_start(self, trainer, model, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, trainer, model, batch, batch_idx):
        rank = trainer.local_rank
        assert isinstance(model, TopModule)
        # index = None also means first device

@@ -356,7 +356,7 @@ def test_on_train_batch_start_overridden(tmpdir) -> None:
    `LightningModule`."""

    class InvalidModel(AsyncBoringModel):
-        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+        def on_train_batch_start(self, batch, batch_idx):
            pass

    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)

@@ -370,7 +370,7 @@ def test_on_train_batch_end_overridden(tmpdir) -> None:
    `LightningModule`."""

    class InvalidModel(AsyncBoringModel):
-        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
+        def on_train_batch_end(self, outputs, batch, batch_idx):
            pass

    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)