Add stronger typing to gradient accumulation scheduler callback (#3558)
* Update gradient_accumulation_scheduler.py add types for gradient accumulation scheduler callback * Update gradient_accumulation_scheduler.py
This commit is contained in:
parent
3affa0e49a
commit
c61e1e697d
|
@ -21,6 +21,8 @@ Trainer also calls ``optimizer.step()`` for the last indivisible step number.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from pytorch_lightning.callbacks.base import Callback
|
from pytorch_lightning.callbacks.base import Callback
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,7 +46,7 @@ class GradientAccumulationScheduler(Callback):
|
||||||
>>> trainer = Trainer(accumulate_grad_batches={5: 2})
|
>>> trainer = Trainer(accumulate_grad_batches={5: 2})
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, scheduling: dict):
|
def __init__(self, scheduling: Dict[int, int]):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if not scheduling: # empty dict error
|
if not scheduling: # empty dict error
|
||||||
|
@ -56,7 +58,9 @@ class GradientAccumulationScheduler(Callback):
|
||||||
|
|
||||||
minimal_epoch = min(scheduling.keys())
|
minimal_epoch = min(scheduling.keys())
|
||||||
if minimal_epoch < 0:
|
if minimal_epoch < 0:
|
||||||
raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
|
raise IndexError(
|
||||||
|
f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
|
||||||
|
)
|
||||||
if minimal_epoch != 0: # if user didnt define first epoch accumulation factor
|
if minimal_epoch != 0: # if user didnt define first epoch accumulation factor
|
||||||
scheduling.update({0: 1})
|
scheduling.update({0: 1})
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue