Add stronger typing to gradient accumulation scheduler callback (#3558)

* Update gradient_accumulation_scheduler.py

add types for gradient accumulation scheduler callback

* Update gradient_accumulation_scheduler.py
Author: ananthsub, 2020-09-23 11:22:10 -07:00, committed by GitHub
parent 3affa0e49a
commit c61e1e697d
1 changed file with 6 additions and 2 deletions

pytorch_lightning/callbacks/gradient_accumulation_scheduler.py

@@ -21,6 +21,8 @@ Trainer also calls ``optimizer.step()`` for the last indivisible step number.
 
 """
+
+from typing import Dict
 
 from pytorch_lightning.callbacks.base import Callback
@@ -44,7 +46,7 @@ class GradientAccumulationScheduler(Callback):
         >>> trainer = Trainer(accumulate_grad_batches={5: 2})
     """
 
-    def __init__(self, scheduling: dict):
+    def __init__(self, scheduling: Dict[int, int]):
         super().__init__()
 
         if not scheduling:  # empty dict error
@@ -56,7 +58,9 @@ class GradientAccumulationScheduler(Callback):
 
         minimal_epoch = min(scheduling.keys())
         if minimal_epoch < 0:
-            raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct")
+            raise IndexError(
+                f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct"
+            )
 
         if minimal_epoch != 0:  # if user didnt define first epoch accumulation factor
             scheduling.update({0: 1})
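
For context, a short usage sketch of the newly typed constructor. This is not part of the commit; the schedule values and the `accumulator` / `trainer` names are illustrative, and the imports assume the public pytorch_lightning API of that release:

from typing import Dict

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# Illustrative schedule: keys are epochs, values are the number of batches
# to accumulate starting at that epoch.
scheduling: Dict[int, int] = {4: 2, 8: 4}

# Epoch 0 is not in the dict, so the callback inserts {0: 1} itself.
accumulator = GradientAccumulationScheduler(scheduling=scheduling)
trainer = Trainer(callbacks=[accumulator])

# Equivalent shorthand from the class docstring:
trainer = Trainer(accumulate_grad_batches={5: 2})

With the annotation `Dict[int, int]` in place, a type checker can flag mistakes such as string keys (e.g. `{"5": 2}`) at analysis time instead of leaving them to fail inside the callback's runtime validation.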