Add stronger typing to gradient accumulation scheduler callback (#3558)
* Update gradient_accumulation_scheduler.py: add types for the gradient accumulation scheduler callback
* Update gradient_accumulation_scheduler.py
This commit is contained in:
parent 3affa0e49a
commit c61e1e697d
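For context, the substance of the change is the constructor annotation: scheduling goes from a bare dict to Dict[int, int], documenting that the schedule maps an epoch index to the number of batches to accumulate. A minimal usage sketch, assuming only the API visible in this diff plus the standard callback import path:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# scheduling maps epoch -> accumulation factor, matching Dict[int, int]:
# from epoch 5 onward, accumulate gradients over 2 batches before stepping.
accumulator = GradientAccumulationScheduler(scheduling={5: 2})
trainer = Trainer(callbacks=[accumulator])

# Equivalent shorthand, taken from the doctest shown in the diff below:
trainer = Trainer(accumulate_grad_batches={5: 2})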
@@ -21,6 +21,8 @@ Trainer also calls ``optimizer.step()`` for the last indivisible step number.
 """
 
+from typing import Dict
+
 from pytorch_lightning.callbacks.base import Callback
 
 
@@ -44,7 +46,7 @@ class GradientAccumulationScheduler(Callback):
         >>> trainer = Trainer(accumulate_grad_batches={5: 2})
     """
 
-    def __init__(self, scheduling: dict):
+    def __init__(self, scheduling: Dict[int, int]):
         super().__init__()
 
         if not scheduling:  # empty dict error
@@ -56,7 +58,9 @@ class GradientAccumulationScheduler(Callback):
 
         minimal_epoch = min(scheduling.keys())
         if minimal_epoch < 0:
-            raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
+            raise IndexError(
+                f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
+            )
         if minimal_epoch != 0:  # if user didnt define first epoch accumulation factor
             scheduling.update({0: 1})
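The last hunk only reflows the IndexError over several lines, but it makes the constructor's validation behaviour visible. A short sketch of what that logic implies, assuming the callback is constructed directly (hypothetical standalone use; message text copied from the diff):

from pytorch_lightning.callbacks import GradientAccumulationScheduler

# Negative epoch keys are rejected in __init__ with the IndexError above.
try:
    GradientAccumulationScheduler(scheduling={-1: 2})
except IndexError as err:
    print(err)  # Epochs indexing from 1, epoch -1 cannot be interpreted correct

# If no factor is given for epoch 0, the callback fills in a default of 1,
# so {5: 2} is effectively treated as {0: 1, 5: 2}.
scheduler = GradientAccumulationScheduler(scheduling={5: 2})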