diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
index 97ea029619..7b723c3fc9 100644
--- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
+++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py
@@ -21,6 +21,8 @@
 Trainer also calls ``optimizer.step()`` for the last indivisible step number.
 
 """
+from typing import Dict
+
 from pytorch_lightning.callbacks.base import Callback
 
 
@@ -44,7 +46,7 @@ class GradientAccumulationScheduler(Callback):
         >>> trainer = Trainer(accumulate_grad_batches={5: 2})
     """
 
-    def __init__(self, scheduling: dict):
+    def __init__(self, scheduling: Dict[int, int]):
         super().__init__()
 
         if not scheduling:  # empty dict error
@@ -56,7 +58,9 @@ class GradientAccumulationScheduler(Callback):
 
         minimal_epoch = min(scheduling.keys())
         if minimal_epoch < 0:
-            raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
+            raise IndexError(
+                f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
+            )
         if minimal_epoch != 0:  # if user didnt define first epoch accumulation factor
             scheduling.update({0: 1})
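
For reviewers trying out the tightened signature, a minimal usage sketch; the schedule values here are illustrative, not part of this patch. It assumes the public import paths at the time of this change (`pytorch_lightning.callbacks.GradientAccumulationScheduler` and `pytorch_lightning.Trainer`):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# scheduling maps an epoch to the number of batches to accumulate before each
# optimizer.step() from that epoch onward. Per the constructor above, {0: 1}
# is inserted automatically when no factor is given for the first epoch.
accumulator = GradientAccumulationScheduler(scheduling={5: 4, 10: 8})
trainer = Trainer(callbacks=[accumulator])

# Equivalent Trainer shorthand, matching the doctest in the diff:
trainer = Trainer(accumulate_grad_batches={5: 4, 10: 8})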