diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py
index 733199c932..285ed5afbf 100644
--- a/pytorch_lightning/trainer/connectors/training_trick_connector.py
+++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Union
+from typing import Dict, Union
 
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.utilities import GradClipAlgorithmType
@@ -27,7 +27,7 @@ class TrainingTricksConnector:
         gradient_clip_val: float,
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
-        accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
+        accumulate_grad_batches: Union[int, Dict[int, int]],
         terminate_on_nan: bool,
     ):
 
@@ -48,7 +48,7 @@ class TrainingTricksConnector:
         self.trainer.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
-    def configure_accumulated_gradients(self, accumulate_grad_batches):
+    def configure_accumulated_gradients(self, accumulate_grad_batches: Union[int, Dict[int, int]]) -> None:
         if isinstance(accumulate_grad_batches, dict):
             self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 6b39bb5159..ac66b1083f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -123,7 +123,7 @@ class Trainer(
         track_grad_norm: Union[int, float, str] = -1,
         check_val_every_n_epoch: int = 1,
         fast_dev_run: Union[int, bool] = False,
-        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
+        accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
         max_epochs: Optional[int] = None,
         min_epochs: Optional[int] = None,
         max_steps: Optional[int] = None,
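
For reference, a minimal sketch (not part of the diff) of the two `accumulate_grad_batches` forms that remain supported after this change; the epoch and factor values below are illustrative:

```python
from pytorch_lightning import Trainer

# Constant factor: accumulate gradients over 4 batches for the whole run.
trainer = Trainer(accumulate_grad_batches=4)

# Scheduled factors: a dict mapping epoch -> accumulation factor, which
# configure_accumulated_gradients hands to GradientAccumulationScheduler.
# Illustrative schedule: no accumulation before epoch 5, then 4 batches per step.
trainer = Trainer(accumulate_grad_batches={0: 1, 5: 4})
```

The now-removed `List[list]` form is dropped from both the type hints and the `Trainer` signature, so only the `int` and `Dict[int, int]` branches in `configure_accumulated_gradients` are reachable.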