From 3442b97d1fa86bb871b85f1b5869969429243ba3 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Sun, 20 Sep 2020 19:59:42 -0700
Subject: [PATCH] Black format pytorch_lightning/core/hooks.py (#3575)

Split out changes from #3563 to make that PR easier to review. This
formats the file according to the Black formatter.
---
 pytorch_lightning/core/hooks.py | 46 +++++++++++++++++++++++----------
 1 file changed, 33 insertions(+), 13 deletions(-)

diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index 0694b9923b..b4cfd50819 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Union, List
+from typing import Any, List, Union
 
 import torch
+from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 
-from pytorch_lightning.utilities import move_data_to_device, AMPType, rank_zero_warn
 
 try:
     from apex import amp
@@ -28,7 +28,6 @@ except ImportError:
 
 
 class ModelHooks:
-
     def setup(self, stage: str):
         """
         Called at the beginning of fit and test.
@@ -113,7 +112,9 @@ class ModelHooks:
         """
         # do something at the end of the pretrain routine
 
-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the training loop before anything happens for that batch.
 
@@ -126,7 +127,9 @@ class ModelHooks:
         """
         # do something when the batch starts
 
-    def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the training loop after the batch.
 
@@ -137,7 +140,9 @@ class ModelHooks:
         """
         # do something when the batch ends
 
-    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_validation_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the validation loop before anything happens for that batch.
 
@@ -148,7 +153,9 @@ class ModelHooks:
         """
         # do something when the batch starts
 
-    def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_validation_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the validation loop after the batch.
 
@@ -159,7 +166,9 @@ class ModelHooks:
         """
         # do something when the batch ends
 
-    def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_test_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the test loop before anything happens for that batch.
 
@@ -170,7 +179,9 @@ class ModelHooks:
         """
         # do something when the batch starts
 
-    def on_test_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_test_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the test loop after the batch.
@@ -288,7 +299,9 @@ class ModelHooks:
 
         """
 
-    def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
+    def backward(
+        self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int
+    ) -> None:
         """
         Override backward with your own implementation if you need to.
 
@@ -311,7 +324,13 @@ class ModelHooks:
         """
         loss.backward()
 
-    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: AMPType):
+    def amp_scale_loss(
+        self,
+        unscaled_loss: Tensor,
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        amp_backend: AMPType,
+    ):
         if amp_backend == AMPType.NATIVE:
             scaled_loss = self.trainer.scaler.scale(unscaled_loss)
         else:
@@ -321,7 +340,6 @@ class ModelHooks:
 
 
 class DataHooks:
-
     def prepare_data(self) -> None:
         """
         Use this to download and prepare data.
@@ -412,7 +430,9 @@ class DataHooks:
             return loader
 
         """
-        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn(
+            "`train_dataloader` must be implemented to be used with the Lightning Trainer"
+        )
 
     def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
         r"""
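
Note (not part of the patch): every hunk above only re-wraps signatures and
strings to Black's line length; no hook gains, loses, or reorders arguments.
As a minimal sketch of how the reformatted hooks are consumed, assuming the
0.9-era API in which LightningModule mixes in ModelHooks, and with LitModel
as a purely illustrative name, a user override is unchanged:

    import pytorch_lightning as pl

    class LitModel(pl.LightningModule):
        def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
            # Runs before anything happens for this batch; the signature
            # matches the Black-wrapped definition in the diff above.
            print(f"starting batch {batch_idx} of dataloader {dataloader_idx}")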