Black format pytorch_lightning/core/hooks.py (#3575)
Split out changes from #3563 to make that PR easier to review. This formats the file according to the Black formatter
This commit is contained in:
parent
cf1b946d4a
commit
3442b97d1f
|
@ -12,14 +12,14 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Any, Union, List
|
from typing import Any, List, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pytorch_lightning.utilities import AMPType, move_data_to_device, rank_zero_warn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.optim.optimizer import Optimizer
|
from torch.optim.optimizer import Optimizer
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from pytorch_lightning.utilities import move_data_to_device, AMPType, rank_zero_warn
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
|
@ -28,7 +28,6 @@ except ImportError:
|
||||||
|
|
||||||
|
|
||||||
class ModelHooks:
|
class ModelHooks:
|
||||||
|
|
||||||
def setup(self, stage: str):
|
def setup(self, stage: str):
|
||||||
"""
|
"""
|
||||||
Called at the beginning of fit and test.
|
Called at the beginning of fit and test.
|
||||||
|
@ -113,7 +112,9 @@ class ModelHooks:
|
||||||
"""
|
"""
|
||||||
# do something at the end of the pretrain routine
|
# do something at the end of the pretrain routine
|
||||||
|
|
||||||
def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
def on_train_batch_start(
|
||||||
|
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called in the training loop before anything happens for that batch.
|
Called in the training loop before anything happens for that batch.
|
||||||
|
|
||||||
|
@ -126,7 +127,9 @@ class ModelHooks:
|
||||||
"""
|
"""
|
||||||
# do something when the batch starts
|
# do something when the batch starts
|
||||||
|
|
||||||
def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
def on_train_batch_end(
|
||||||
|
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called in the training loop after the batch.
|
Called in the training loop after the batch.
|
||||||
|
|
||||||
|
@ -137,7 +140,9 @@ class ModelHooks:
|
||||||
"""
|
"""
|
||||||
# do something when the batch ends
|
# do something when the batch ends
|
||||||
|
|
||||||
def on_validation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
def on_validation_batch_start(
|
||||||
|
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called in the validation loop before anything happens for that batch.
|
Called in the validation loop before anything happens for that batch.
|
||||||
|
|
||||||
|
@ -148,7 +153,9 @@ class ModelHooks:
|
||||||
"""
|
"""
|
||||||
# do something when the batch starts
|
# do something when the batch starts
|
||||||
|
|
||||||
def on_validation_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
def on_validation_batch_end(
|
||||||
|
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called in the validation loop after the batch.
|
Called in the validation loop after the batch.
|
||||||
|
|
||||||
|
@ -159,7 +166,9 @@ class ModelHooks:
|
||||||
"""
|
"""
|
||||||
# do something when the batch ends
|
# do something when the batch ends
|
||||||
|
|
||||||
def on_test_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
def on_test_batch_start(
|
||||||
|
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called in the test loop before anything happens for that batch.
|
Called in the test loop before anything happens for that batch.
|
||||||
|
|
||||||
|
@ -170,7 +179,9 @@ class ModelHooks:
|
||||||
"""
|
"""
|
||||||
# do something when the batch starts
|
# do something when the batch starts
|
||||||
|
|
||||||
def on_test_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
|
def on_test_batch_end(
|
||||||
|
self, batch: Any, batch_idx: int, dataloader_idx: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Called in the test loop after the batch.
|
Called in the test loop after the batch.
|
||||||
|
|
||||||
|
@ -288,7 +299,9 @@ class ModelHooks:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def backward(self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int) -> None:
|
def backward(
|
||||||
|
self, trainer, loss: Tensor, optimizer: Optimizer, optimizer_idx: int
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Override backward with your own implementation if you need to.
|
Override backward with your own implementation if you need to.
|
||||||
|
|
||||||
|
@ -311,7 +324,13 @@ class ModelHooks:
|
||||||
"""
|
"""
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: AMPType):
|
def amp_scale_loss(
|
||||||
|
self,
|
||||||
|
unscaled_loss: Tensor,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
optimizer_idx: int,
|
||||||
|
amp_backend: AMPType,
|
||||||
|
):
|
||||||
if amp_backend == AMPType.NATIVE:
|
if amp_backend == AMPType.NATIVE:
|
||||||
scaled_loss = self.trainer.scaler.scale(unscaled_loss)
|
scaled_loss = self.trainer.scaler.scale(unscaled_loss)
|
||||||
else:
|
else:
|
||||||
|
@ -321,7 +340,6 @@ class ModelHooks:
|
||||||
|
|
||||||
|
|
||||||
class DataHooks:
|
class DataHooks:
|
||||||
|
|
||||||
def prepare_data(self) -> None:
|
def prepare_data(self) -> None:
|
||||||
"""
|
"""
|
||||||
Use this to download and prepare data.
|
Use this to download and prepare data.
|
||||||
|
@ -412,7 +430,9 @@ class DataHooks:
|
||||||
return loader
|
return loader
|
||||||
|
|
||||||
"""
|
"""
|
||||||
rank_zero_warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
|
rank_zero_warn(
|
||||||
|
"`train_dataloader` must be implemented to be used with the Lightning Trainer"
|
||||||
|
)
|
||||||
|
|
||||||
def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
|
def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
|
||||||
r"""
|
r"""
|
||||||
|
|
Loading…
Reference in New Issue