2020-09-07 20:45:31 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License
|
2021-03-02 09:47:55 +00:00
|
|
|
import logging
|
2020-09-07 20:45:31 +00:00
|
|
|
import os
|
2020-10-19 20:20:17 +00:00
|
|
|
from typing import Optional, Tuple
|
|
|
|
|
2021-04-29 12:40:51 +00:00
|
|
|
import pytorch_lightning as pl
|
2020-09-07 20:45:31 +00:00
|
|
|
from pytorch_lightning.loggers.base import DummyLogger
|
2021-01-15 22:44:27 +00:00
|
|
|
from pytorch_lightning.utilities import DeviceType, rank_zero_warn
|
2020-11-03 09:39:40 +00:00
|
|
|
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
2021-01-15 22:44:27 +00:00
|
|
|
from pytorch_lightning.utilities.data import has_len
|
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
|
|
|
from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error
|
|
|
|
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
|
2020-09-07 20:45:31 +00:00
|
|
|
|
2021-03-02 09:47:55 +00:00
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
2020-09-07 20:45:31 +00:00
|
|
|
|
2021-02-08 19:28:38 +00:00
|
|
|
def scale_batch_size(
    trainer: 'pl.Trainer',
    model: 'pl.LightningModule',
    mode: str = 'power',
    steps_per_trial: int = 3,
    init_val: int = 2,
    max_trials: int = 25,
    batch_arg_name: str = 'batch_size',
) -> Optional[int]:
    """Search for a batch size that fits in memory by running short trial fits.

    See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size` for the user-facing documentation.

    The trainer settings touched by the search are snapshotted and restored, the initial model state is
    checkpointed to ``trainer.default_root_dir`` and restored afterwards, and the temporary checkpoint is
    removed. The cleanup also runs when a trial raises, so a failed search does not leave the trainer in its
    tuning configuration.

    Args:
        trainer: the trainer used to run the trial fits.
        model: the model whose ``batch_arg_name`` field (or ``hparams`` entry) is tuned.
        mode: ``'power'`` (keep doubling until OOM) or ``'binsearch'`` (double, then binary search).
        steps_per_trial: number of training steps per trial (used as ``max_steps`` during the search).
        init_val: batch size the search starts from.
        max_trials: trial limit passed on to the selected search mode.
        batch_arg_name: name of the field that holds the batch size.

    Returns:
        The batch size found, or ``None`` when ``fast_dev_run`` is enabled (the search is skipped).

    Raises:
        MisconfigurationException: if ``batch_arg_name`` is found neither on the model nor in its hparams, or
            the dataloaders were passed directly to ``.fit()``.
        ValueError: if ``mode`` is neither ``'power'`` nor ``'binsearch'``.
    """
    if trainer.fast_dev_run:
        rank_zero_warn('Skipping batch size scaler since fast_dev_run is enabled.', UserWarning)
        return

    if not lightning_hasattr(model, batch_arg_name):
        raise MisconfigurationException(f'Field {batch_arg_name} not found in both `model` and `model.hparams`')
    if hasattr(model, batch_arg_name) and hasattr(model, "hparams") and batch_arg_name in model.hparams:
        rank_zero_warn(
            f'Field `model.{batch_arg_name}` and `model.hparams.{batch_arg_name}` are mutually exclusive!'
            f' `model.{batch_arg_name}` will be used as the initial batch size for scaling.'
            f' If this is not the intended behavior, please remove either one.'
        )

    if hasattr(model.train_dataloader, 'patch_loader_code'):
        raise MisconfigurationException(
            'The batch scaling feature cannot be used with dataloaders passed directly to `.fit()`.'
            ' Please disable the feature or incorporate the dataloader into the model.'
        )

    # Validate the mode before any trainer state is modified, so that an invalid mode does not leave a
    # reconfigured trainer, a disabled progress bar and a temporary checkpoint behind.
    if mode not in ('power', 'binsearch'):
        raise ValueError('mode in method `scale_batch_size` could either be `power` or `binsearch`')

    # Arguments we adjust during the batch size finder, save for restoring
    __scale_batch_dump_params(trainer)
    try:
        # Set to values that are required by the algorithm
        __scale_batch_reset_params(trainer, model, steps_per_trial)

        # Save initial model, that is loaded after batch size is found
        save_path = os.path.join(trainer.default_root_dir, 'scale_batch_size_temp_model.ckpt')
        trainer.save_checkpoint(str(save_path))

        if trainer.progress_bar_callback:
            trainer.progress_bar_callback.disable()

        try:
            # Initially we just double in size until an OOM is encountered
            new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)  # initially set to init_val
            if mode == 'power':
                new_size = _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials)
            else:
                new_size = _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials)

            garbage_collection_cuda()
            log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')
        finally:
            # Restore initial state of model and remove the temporary checkpoint, also when a trial raised
            if trainer.is_global_zero:
                trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer._device_type == DeviceType.GPU)
                fs = get_filesystem(str(save_path))
                if fs.exists(save_path):
                    fs.rm(save_path)
    finally:
        # Finish by resetting variables so trainer is ready to fit model
        __scale_batch_restore_params(trainer)
        if trainer.progress_bar_callback:
            trainer.progress_bar_callback.enable()

    return new_size
|
|
|
|
|
|
|
|
|
2021-04-29 12:40:51 +00:00
|
|
|
def __scale_batch_dump_params(trainer: 'pl.Trainer') -> None:
    """Snapshot the trainer settings that the batch size finder is about to override.

    The snapshot is a dict stored on the trainer as the attribute ``__dumped_params``. This function is
    module-level, so no name mangling applies. ``__scale_batch_restore_params`` reads the snapshot back
    and deletes it.
    """
    # Keys are attribute names on the trainer; order matches the order in which they are read.
    snapshot_fields = (
        'auto_lr_find',
        'current_epoch',
        'max_steps',
        'weights_summary',
        'logger',
        'callbacks',
        'checkpoint_callback',
        'auto_scale_batch_size',  # kept so the finder can be switched off during the search (no recursion)
        'limit_train_batches',
        'model',
    )
    trainer.__dumped_params = {field: getattr(trainer, field) for field in snapshot_fields}
|
|
|
|
|
|
|
|
|
2021-04-29 12:40:51 +00:00
|
|
|
def __scale_batch_reset_params(trainer: 'pl.Trainer', model: 'pl.LightningModule', steps_per_trial: int) -> None:
    """Reconfigure ``trainer`` for the short trial fits run by the batch size finder.

    Both tuners are switched off, each trial is limited to ``steps_per_trial`` steps starting from epoch 0,
    logging/callbacks/weights summary are replaced by no-op values, and the fields needed to write the
    temporary checkpoint are set.
    """
    # Turn the tuners off while the trial fits run
    trainer.auto_scale_batch_size = None  # prevent recursion
    trainer.auto_lr_find = False  # avoid lr find being called multiple times

    loop = trainer.train_loop
    loop.current_epoch = 0
    loop.max_steps = steps_per_trial  # only a few steps per trial

    # Not needed before the full run
    trainer.weights_summary = None
    trainer.logger = DummyLogger()
    trainer.callbacks = []
    trainer.limit_train_batches = 1.0

    # Required for saving the temporary checkpoint
    trainer.optimizers = []
    trainer.schedulers = []
    trainer.model = model
|
|
|
|
|
|
|
|
|
2021-04-29 12:40:51 +00:00
|
|
|
def __scale_batch_restore_params(trainer: 'pl.Trainer') -> None:
    """Put back the trainer settings snapshotted by ``__scale_batch_dump_params`` and drop the snapshot.

    ``current_epoch`` and ``max_steps`` are written to ``trainer.train_loop``. The other restored fields are
    set on the trainer itself. The snapshot's ``checkpoint_callback`` entry is not written back.
    """
    snapshot = trainer.__dumped_params

    trainer.auto_lr_find = snapshot['auto_lr_find']
    for loop_field in ('current_epoch', 'max_steps'):
        setattr(trainer.train_loop, loop_field, snapshot[loop_field])

    trainer_fields = (
        'weights_summary',
        'logger',
        'callbacks',
        'auto_scale_batch_size',
        'limit_train_batches',
        'model',
    )
    for trainer_field in trainer_fields:
        setattr(trainer, trainer_field, snapshot[trainer_field])

    del trainer.__dumped_params
|
|
|
|
|
|
|
|
|
2021-04-29 12:40:51 +00:00
|
|
|
def _run_power_scaling(
    trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int
) -> int:
    """ Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.

    If trials run out of memory before any trial has succeeded, the size keeps being halved (within
    ``max_trials``) rather than returning a halved size that was never tried. Whenever the batch size
    changes, the train dataloader is reset so the next trial (or the full run) uses the new size.

    Args:
        trainer: trainer used to run each trial fit.
        model: model passed to ``trainer.tuner._run`` and ``trainer.reset_train_dataloader``.
        new_size: batch size for the first trial (already set on the model).
        batch_arg_name: name of the field that holds the batch size.
        max_trials: maximum number of trials.

    Returns:
        The batch size to continue with.

    Raises:
        RuntimeError: re-raised unchanged when a trial fails with an error that is not an OOM error.
    """
    # Becomes True after the first successful trial. From then on, the first OOM means the size just before
    # it was the largest that fits, so halving once and stopping is correct.
    any_success = False
    for _ in range(max_trials):
        garbage_collection_cuda()
        trainer.train_loop.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.tuner._run(model)
            # Double in size
            new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if not is_oom_error(exception):
                raise  # some other error not memory related
            garbage_collection_cuda()
            new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
            if not changed or new_size < 1:
                # Nothing smaller is left to try
                break
            # Force the train dataloader to reset as the batch size has changed
            trainer.reset_train_dataloader(model)
            if any_success:
                break
            continue

        if not changed:
            break
        # Force the train dataloader to reset as the batch size has changed
        trainer.reset_train_dataloader(model)
        any_success = True

    return new_size
|
|
|
|
|
|
|
|
|
2021-04-29 12:40:51 +00:00
|
|
|
def _run_binsearch_scaling(
    trainer: 'pl.Trainer', model: 'pl.LightningModule', new_size: int, batch_arg_name: str, max_trials: int
) -> int:
    """ Batch scaling mode where the size is initially doubled at each iteration
    until an OOM error is encountered. Hereafter, the batch size is further
    refined using a binary search between the last size that fit and the
    first size that ran out of memory.

    Args:
        trainer: trainer used to run each trial fit.
        model: model passed to ``trainer.tuner._run`` and ``trainer.reset_train_dataloader``.
        new_size: batch size for the first trial (already set on the model).
        batch_arg_name: name of the field that holds the batch size.
        max_trials: the search stops once more than this many trials have succeeded.

    Returns:
        The batch size to continue with.

    Raises:
        RuntimeError: re-raised unchanged when a trial fails with an error that is not an OOM error.
    """
    # Largest size known to fit. It starts at 1 so that an OOM on the very first trial still has a lower
    # bound to bisect towards; otherwise ``low`` would be unbound in the OOM branch.
    low = 1
    # Smallest size known to run out of memory; None until the first OOM.
    high = None
    count = 0
    while True:
        garbage_collection_cuda()
        trainer.train_loop.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.tuner._run(model)
            count += 1
            if count > max_trials:
                break
            # The current size fits
            low = new_size
            if high is not None:
                if high - low <= 1:
                    break
                midval = (high + low) // 2
                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
            else:
                # No OOM seen yet: keep doubling
                new_size, changed = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')

            if changed:
                # Force the train dataloader to reset as the batch size has changed
                trainer.reset_train_dataloader(model)
            else:
                break

        except RuntimeError as exception:
            # Only these errors should trigger an adjustment
            if not is_oom_error(exception):
                raise  # some other error not memory related
            garbage_collection_cuda()
            high = new_size
            midval = (high + low) // 2
            new_size, changed = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='failed')
            if changed:
                # The next trial must see the reduced batch size, just like after a successful trial
                trainer.reset_train_dataloader(model)
            if high - low <= 1:
                break

    return new_size
|
|
|
|
|
|
|
|
|
2021-02-08 19:28:38 +00:00
|
|
|
def _adjust_batch_size(
    trainer: 'pl.Trainer',
    batch_arg_name: str = 'batch_size',
    factor: float = 1.0,
    value: Optional[int] = None,
    desc: Optional[str] = None
) -> Tuple[int, bool]:
    """ Helper function for adjusting the batch size.

    Args:
        trainer: instance of pytorch_lightning.Trainer

        batch_arg_name: name of the field where batch_size is stored.

        factor: value which the old batch size is multiplied by to get the
            new batch size (the product is truncated with ``int``)

        value: if a value is given, will override the batch size with this value.
            Note that the value of `factor` will not have an effect in this case

        desc: either `succeeded` or `failed`. Used purely for logging

    Returns:
        The new batch size for the next trial and a bool that signals whether the
        new value is different than the previous batch size. The new size is also
        written to the model via ``lightning_setattr``.
    """
    model = trainer.lightning_module
    batch_size = lightning_getattr(model, batch_arg_name)
    new_size = value if value is not None else int(batch_size * factor)

    # When the cheap dataloader check rejects the size, cap it at the dataset length
    if not _is_valid_batch_size(new_size, trainer.train_dataloader):
        new_size = min(new_size, len(trainer.train_dataloader.dataset))

    # Log after capping so the message reports the size that will actually be tried
    if desc:
        log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')

    changed = new_size != batch_size
    lightning_setattr(model, batch_arg_name, new_size)
    return new_size, changed
|
|
|
|
|
|
|
|
|
|
|
|
def _is_valid_batch_size(current_size, dataloader):
    """Cheap check of a candidate batch size against ``dataloader``.

    Returns ``True`` when ``has_len`` reports that the dataloader has no usable length. Otherwise returns
    whether ``current_size`` is at most ``len(dataloader)``.

    NOTE(review): for a torch ``DataLoader``, ``len`` is the number of batches rather than the number of
    samples. Callers treat a ``False`` result only as a hint to cap the size at the dataset length.
    """
    if has_len(dataloader):
        return current_size <= len(dataloader)
    return True
|