Black format pytorch_lightning/core/hooks.py (#3575)

Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter.
ananthsub 2020-09-20 19:59:42 -07:00 committed by GitHub
parent cf1b946d4a
commit 3442b97d1f
1 changed file with 33 additions and 13 deletions
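The rewrapped signatures in the diff below are consistent with Black's default 88-character line limit (the import reordering at the top of the file is a separate, isort-style change that Black itself does not perform). As a rough illustration only — the exact Black version and configuration used for this commit are not recorded here — the line-wrapping part of the change could be reproduced through Black's Python API instead of the usual command-line invocation:

    # Minimal sketch, not taken from the commit: reformat hooks.py in place,
    # assuming Black's default 88-character line length.
    import black

    path = "pytorch_lightning/core/hooks.py"
    with open(path) as f:
        source = f.read()

    # format_str() returns the Black-formatted source text.
    formatted = black.format_str(source, mode=black.FileMode(line_length=88))

    with open(path, "w") as f:
        f.write(formatted)

In practice the file would simply be run through the black CLI; the API form is shown only to make the assumed line-length setting explicit.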

@@ -12,14 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Union, List
+from typing import Any, List, Union
 import torch
+from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
 from torch import Tensor
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
-from pytorch_lightning.utilities import move_data_to_device, AMPType, rank_zero_warn
 try:
     from apex import amp
@@ -28,7 +28,6 @@ except ImportError:
 class ModelHooks:
     def setup(self, stage: str):
         """
         Called at the beginning of fit and test.
@@ -113,7 +112,9 @@ class ModelHooks:
         """
         # do something at the end of the pretrain routine
-    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the training loop before anything happens for that batch.
@@ -126,7 +127,9 @@ class ModelHooks:
         """
         # do something when the batch starts
-    def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_train_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the training loop after the batch.
@@ -137,7 +140,9 @@ class ModelHooks:
         """
         # do something when the batch ends
-    def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_validation_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the validation loop before anything happens for that batch.
@@ -148,7 +153,9 @@ class ModelHooks:
         """
         # do something when the batch starts
-    def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_validation_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the validation loop after the batch.
@@ -159,7 +166,9 @@ class ModelHooks:
         """
         # do something when the batch ends
-    def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_test_batch_start(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the test loop before anything happens for that batch.
@@ -170,7 +179,9 @@ class ModelHooks:
         """
         # do something when the batch starts
-    def on_test_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
+    def on_test_batch_end(
+        self, batch: Any, batch_idx: int, dataloader_idx: int
+    ) -> None:
         """
         Called in the test loop after the batch.
@@ -288,7 +299,9 @@ class ModelHooks:
         """
-    def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
+    def backward(
+        self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int
+    ) -> None:
         """
         Override backward with your own implementation if you need to.
@@ -311,7 +324,13 @@ class ModelHooks:
         """
         loss.backward()
-    def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: AMPType):
+    def amp_scale_loss(
+        self,
+        unscaled_loss: Tensor,
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        amp_backend: AMPType,
+    ):
         if amp_backend == AMPType.NATIVE:
             scaled_loss = self.trainer.scaler.scale(unscaled_loss)
         else:
@@ -321,7 +340,6 @@ class ModelHooks:
 class DataHooks:
     def prepare_data(self) -> None:
         """
         Use this to download and prepare data.
@@ -412,7 +430,9 @@ class DataHooks:
                    return loader
         """
-        rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
+        rank_zero_warn(
+            "`train_dataloader` must be implemented to be used with the Lightning Trainer"
+        )
     def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
         r"""