From 376734a1e2fbb0ae41d8046a0491a4c0ba7f5657 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Tue, 24 Aug 2021 10:12:36 -0700
Subject: [PATCH] Fix `accumulated_grad_batches` typehint (#9071)

* Fix `accumulated_grad_batches` typehint
---
 .../trainer/connectors/training_trick_connector.py | 6 +++---
 pytorch_lightning/trainer/trainer.py               | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index 733199c932..285ed5afbf 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Union
+from typing import Dict, Union
 
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.utilities import GradClipAlgorithmType
@@ -27,7 +27,7 @@ class TrainingTricksConnector:
         gradient_clip_val: float,
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
-        accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
+        accumulate_grad_batches: Union[int, Dict[int, int]],
         terminate_on_nan: bool,
     ):
 
@@ -48,7 +48,7 @@ class TrainingTricksConnector:
         self.trainer.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
-    def configure_accumulated_gradients(self, accumulate_grad_batches):
+    def configure_accumulated_gradients(self, accumulate_grad_batches: Union[int, Dict[int, int]]) -> None:
         if isinstance(accumulate_grad_batches, dict):
             self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6b39bb5159..ac66b1083f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -123,7 +123,7 @@ class Trainer(
         track_grad_norm: Union[int, float, str] = -1,
         check_val_every_n_epoch: int = 1,
         fast_dev_run: Union[int, bool] = False,
-        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
+        accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
         max_epochs: Optional[int] = None,
         min_epochs: Optional[int] = None,
         max_steps: Optional[int] = None,
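
Usage note: after this patch, `accumulate_grad_batches` is typed as the two forms the connector actually handles, an `int` (a fixed accumulation factor) or a `Dict[int, int]` (an epoch-to-factor schedule consumed by `GradientAccumulationScheduler`). A minimal sketch of both forms, assuming a PyTorch Lightning `Trainer` from around this release; the specific values are illustrative only:

    from pytorch_lightning import Trainer

    # int form: accumulate gradients over 4 batches for the entire run.
    trainer = Trainer(accumulate_grad_batches=4)

    # Dict[int, int] form: keys are epochs, values are accumulation factors.
    # Here: accumulate over 8 batches starting at epoch 0, over 4 starting
    # at epoch 4, and disable accumulation (factor 1) from epoch 8 onward.
    trainer = Trainer(accumulate_grad_batches={0: 8, 4: 4, 8: 1})

The removed `List[list]` member matched neither branch of the `isinstance` checks in `configure_accumulated_gradients`, which is why narrowing the annotation changes no runtime behavior.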