Black format pytorch_lightning/core/hooks.py (#3575)
Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter
This commit is contained in:
parent
cf1b946d4a
commit
3442b97d1f
|
@ -12,14 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Union, List
|
||||
from typing import Any, List, Union
|
||||
|
||||
import torch
|
||||
from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
|
||||
from torch import Tensor
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning.utilities import move_data_to_device, AMPType, rank_zero_warn
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
|
@ -28,7 +28,6 @@ except ImportError:
|
|||
|
||||
|
||||
class ModelHooks:
|
||||
|
||||
def setup(self, stage: str):
|
||||
"""
|
||||
Called at the beginning of fit and test.
|
||||
|
@ -113,7 +112,9 @@ class ModelHooks:
|
|||
"""
|
||||
# do something at the end of the pretrain routine
|
||||
|
||||
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
||||
def on_train_batch_start(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
"""
|
||||
Called in the training loop before anything happens for that batch.
|
||||
|
||||
|
@ -126,7 +127,9 @@ class ModelHooks:
|
|||
"""
|
||||
# do something when the batch starts
|
||||
|
||||
def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
||||
def on_train_batch_end(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
"""
|
||||
Called in the training loop after the batch.
|
||||
|
||||
|
@ -137,7 +140,9 @@ class ModelHooks:
|
|||
"""
|
||||
# do something when the batch ends
|
||||
|
||||
def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
||||
def on_validation_batch_start(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
"""
|
||||
Called in the validation loop before anything happens for that batch.
|
||||
|
||||
|
@ -148,7 +153,9 @@ class ModelHooks:
|
|||
"""
|
||||
# do something when the batch starts
|
||||
|
||||
def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
||||
def on_validation_batch_end(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
"""
|
||||
Called in the validation loop after the batch.
|
||||
|
||||
|
@ -159,7 +166,9 @@ class ModelHooks:
|
|||
"""
|
||||
# do something when the batch ends
|
||||
|
||||
def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
||||
def on_test_batch_start(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
"""
|
||||
Called in the test loop before anything happens for that batch.
|
||||
|
||||
|
@ -170,7 +179,9 @@ class ModelHooks:
|
|||
"""
|
||||
# do something when the batch starts
|
||||
|
||||
def on_test_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
||||
def on_test_batch_end(
|
||||
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||
) -> None:
|
||||
"""
|
||||
Called in the test loop after the batch.
|
||||
|
||||
|
@ -288,7 +299,9 @@ class ModelHooks:
|
|||
|
||||
"""
|
||||
|
||||
def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
|
||||
def backward(
|
||||
self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int
|
||||
) -> None:
|
||||
"""
|
||||
Override backward with your own implementation if you need to.
|
||||
|
||||
|
@ -311,7 +324,13 @@ class ModelHooks:
|
|||
"""
|
||||
loss.backward()
|
||||
|
||||
def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: AMPType):
|
||||
def amp_scale_loss(
|
||||
self,
|
||||
unscaled_loss: Tensor,
|
||||
optimizer: Optimizer,
|
||||
optimizer_idx: int,
|
||||
amp_backend: AMPType,
|
||||
):
|
||||
if amp_backend == AMPType.NATIVE:
|
||||
scaled_loss = self.trainer.scaler.scale(unscaled_loss)
|
||||
else:
|
||||
|
@ -321,7 +340,6 @@ class ModelHooks:
|
|||
|
||||
|
||||
class DataHooks:
|
||||
|
||||
def prepare_data(self) -> None:
|
||||
"""
|
||||
Use this to download and prepare data.
|
||||
|
@ -412,7 +430,9 @@ class DataHooks:
|
|||
return loader
|
||||
|
||||
"""
|
||||
rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
|
||||
rank_zero_warn(
|
||||
"`train_dataloader` must be implemented to be used with the Lightning Trainer"
|
||||
)
|
||||
|
||||
def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
|
||||
r"""
|
||||
|
|
Loading…
Reference in New Issue