From c61e1e697d2e4ed5cc97c187e6704a44f2f77aff Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 23 Sep 2020 11:22:10 -0700 Subject: [PATCH] Add stronger typing to gradient accumulation scheduler callback (#3558) * Update gradient_accumulation_scheduler.py add types for gradient accumulation scheduler callback * Update gradient_accumulation_scheduler.py --- .../callbacks/gradient_accumulation_scheduler.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py index 97ea029619..7b723c3fc9 100644 --- a/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py +++ b/pytorch_lightning/callbacks/gradient_accumulation_scheduler.py @@ -21,6 +21,8 @@ Trainer also calls ``optimizer.step()`` for the last indivisible step number. """ +from typing import Dict + from pytorch_lightning.callbacks.base import Callback @@ -44,7 +46,7 @@ class GradientAccumulationScheduler(Callback): >>> trainer = Trainer(accumulate_grad_batches={5: 2}) """ - def __init__(self, scheduling: dict): + def __init__(self, scheduling: Dict[int, int]): super().__init__() if not scheduling: # empty dict error @@ -56,7 +58,9 @@ class GradientAccumulationScheduler(Callback): minimal_epoch = min(scheduling.keys()) if minimal_epoch < 0: - raise IndexError(f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct") + raise IndexError( + f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct" + ) if minimal_epoch != 0: # if user didnt define first epoch accumulation factor scheduling.update({0: 1})