import math
import os
from abc import ABC, abstractmethod
from typing import Optional

import torch
from torch import Tensor

from pytorch_lightning import _logger as log
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.memory import is_oom_error, garbage_collection_cuda

EPSILON = 1e-6
EPSILON_FP16 = 1e-5

class TrainerTrainingTricksMixin(ABC):

    # this is just a summary of the variables used in this abstract class;
    # the proper values/initialisation should be done in the child class
    gradient_clip_val: ...
    precision: int
    default_root_dir: str
    progress_bar_callback: ...
    on_gpu: bool

    @abstractmethod
    def get_model(self):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def save_checkpoint(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def restore(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""

    @abstractmethod
    def fit(self, *args):
        """Warning: this is just an empty shell for code implemented in another class."""
    def clip_gradients(self):
        # this code is a modification of torch.nn.utils.clip_grad_norm_
        # with TPU support, based on https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
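        # In short, for the 2-norm used below:
        #   total_norm = (sum_p ||p.grad||_2 ** 2) ** 0.5
        # and every gradient is scaled by min(1, max_norm / (total_norm + eps)).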
        if self.gradient_clip_val > 0:
            model = self.get_model()
            parameters = model.parameters()
            max_norm = float(self.gradient_clip_val)
            norm_type = float(2.0)
            if isinstance(parameters, torch.Tensor):
                parameters = [parameters]
            parameters = list(filter(lambda p: p.grad is not None, parameters))
            # pick the device up front so it is defined for both norm branches
            device = parameters[0].device if parameters else torch.device('cpu')
            if norm_type == math.inf:
                total_norm = max(p.grad.data.abs().max() for p in parameters)
            else:
                total_norm = torch.zeros([], device=device)
                for p in parameters:
                    param_norm = p.grad.data.abs().pow(norm_type).sum()
                    total_norm.add_(param_norm)
                total_norm = total_norm ** (1. / norm_type)
            eps = EPSILON_FP16 if self.precision == 16 else EPSILON
            clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps)
            for p in parameters:
                p.grad.data.mul_(torch.where(clip_coef < 1, clip_coef, torch.tensor(1., device=device)))

    def print_nan_gradients(self) -> None:
        model = self.get_model()
        for name, param in model.named_parameters():
            if (param.grad is not None) and torch.isnan(param.grad.float()).any():
                log.info('nan gradient detected in %s: %s', name, param.grad)

    def detect_nan_tensors(self, loss: Tensor) -> None:
        model = self.get_model()

        # check if loss is nan or inf
        if not torch.isfinite(loss).all():
            raise ValueError(
                'The loss returned in `training_step` is nan or inf.'
            )
        # check if any network weight is nan or inf
        for name, param in model.named_parameters():
            if not torch.isfinite(param).all():
                self.print_nan_gradients()
                raise ValueError(
                    f'Detected nan and/or inf values in `{name}`.'
                    ' Check your forward pass for numerically unstable operations.'
                )

    def configure_accumulated_gradients(self, accumulate_grad_batches):
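        """Set up the gradient accumulation scheduler from an int or a dict.

        Illustrative sketch of the two accepted forms (the exact epoch semantics
        are those of ``GradientAccumulationScheduler``; the int form is shorthand
        for a one-entry schedule, as the code below shows):
        ``accumulate_grad_batches=4`` behaves like ``{1: 4}``, i.e. accumulate 4
        batches per optimizer step from epoch 1 onward, while ``{1: 4, 10: 8}``
        would switch to 8 batches from epoch 10.
        """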
        if isinstance(accumulate_grad_batches, dict):
            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
        elif isinstance(accumulate_grad_batches, int):
            schedule = {1: accumulate_grad_batches}
            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
        else:
            raise TypeError("Gradient accumulation supports only int and dict types")

    def scale_batch_size(self,
                         model: LightningModule,
                         mode: str = 'power',
                         steps_per_trial: int = 3,
                         init_val: int = 2,
                         max_trials: int = 25,
                         batch_arg_name: str = 'batch_size'):
r"""
|
|
Will iteratively try to find the largest batch size for a given model
|
|
that does not give an out of memory (OOM) error.
|
|
|
|
Args:
|
|
model: Model to fit.
|
|
|
|
mode: string setting the search mode. Either `power` or `binsearch`.
|
|
If mode is `power` we keep multiplying the batch size by 2, until
|
|
we get an OOM error. If mode is 'binsearch', we will initially
|
|
also keep multiplying by 2 and after encountering an OOM error
|
|
do a binary search between the last successful batch size and the
|
|
batch size that failed.
|
|
|
|
steps_per_trial: number of steps to run with a given batch size.
|
|
Idealy 1 should be enough to test if a OOM error occurs,
|
|
however in practise a few are needed
|
|
|
|
init_val: initial batch size to start the search with
|
|
|
|
max_trials: max number of increase in batch size done before
|
|
algorithm is terminated
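
        Example::

            # illustrative usage (not from the original docs); assumes `model` is
            # a LightningModule exposing a `batch_size` attribute and `trainer`
            # is a Trainer built on this mixin
            new_size = trainer.scale_batch_size(model, mode='binsearch')
            trainer.fit(model)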
        """
        if not hasattr(model, batch_arg_name):
            raise MisconfigurationException(f'Field {batch_arg_name} not found in `model`')

        if hasattr(model.train_dataloader, 'patch_loader_code'):
            raise MisconfigurationException('The batch scaling feature cannot be used with dataloaders'
                                            ' passed directly to `.fit()`. Please disable the feature or'
                                            ' incorporate the dataloader into the model.')

        # Arguments we adjust during the batch size finder, saved for restoring
        self.__scale_batch_dump_params()

        # Set to values that are required by the algorithm
        self.__scale_batch_reset_params(model, steps_per_trial)

        # Save the initial model; it is reloaded after the batch size is found
        save_path = os.path.join(self.default_root_dir, 'temp_model.ckpt')
        self.save_checkpoint(str(save_path))

        if self.progress_bar_callback:
            self.progress_bar_callback.disable()

        # Initially we just double in size until an OOM is encountered
        new_size = _adjust_batch_size(self, value=init_val)  # initially set to init_val
        if mode == 'power':
            new_size = _run_power_scaling(self, model, new_size, batch_arg_name, max_trials)
        elif mode == 'binsearch':
            new_size = _run_binsearch_scaling(self, model, new_size, batch_arg_name, max_trials)
        else:
            raise ValueError('mode in method `scale_batch_size` can only be `power` or `binsearch`')

        garbage_collection_cuda()
        log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

        # Restore initial state of model
        self.restore(str(save_path), on_gpu=self.on_gpu)
        os.remove(save_path)

        # Finish by resetting variables so trainer is ready to fit model
        self.__scale_batch_restore_params()
        if self.progress_bar_callback:
            self.progress_bar_callback.enable()

        return new_size

    def __scale_batch_dump_params(self):
        # Save the trainer attributes the finder mutates, so they can be restored
        self.__dumped_params = {
            'max_steps': self.max_steps,
            'weights_summary': self.weights_summary,
            'logger': self.logger,
            'callbacks': self.callbacks,
            'checkpoint_callback': self.checkpoint_callback,
            'early_stop_callback': self.early_stop_callback,
            'enable_early_stop': self.enable_early_stop,
            'auto_scale_batch_size': self.auto_scale_batch_size,
            'train_percent_check': self.train_percent_check,
            'model': self.model,
        }

    def __scale_batch_reset_params(self, model, steps_per_trial):
        self.auto_scale_batch_size = None  # prevent recursion
        self.max_steps = steps_per_trial  # take a few steps
        self.weights_summary = None  # not needed before full run
        self.logger = DummyLogger()
        self.callbacks = []  # not needed before full run
        self.checkpoint_callback = False  # required for saving
        self.early_stop_callback = None
        self.enable_early_stop = False
        self.train_percent_check = 1.0
        self.optimizers, self.schedulers = [], []  # required for saving
        self.model = model  # required for saving

    def __scale_batch_restore_params(self):
        self.max_steps = self.__dumped_params['max_steps']
        self.weights_summary = self.__dumped_params['weights_summary']
        self.logger = self.__dumped_params['logger']
        self.callbacks = self.__dumped_params['callbacks']
        self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
        self.auto_scale_batch_size = self.__dumped_params['auto_scale_batch_size']
        self.early_stop_callback = self.__dumped_params['early_stop_callback']
        self.enable_early_stop = self.__dumped_params['enable_early_stop']
        self.train_percent_check = self.__dumped_params['train_percent_check']
        self.model = self.__dumped_params['model']
        del self.__dumped_params


def _adjust_batch_size(trainer,
                       batch_arg_name: str = 'batch_size',
                       factor: float = 1.0,
                       value: Optional[int] = None,
                       desc: Optional[str] = None):
""" Function for adjusting the batch size. It is expected that the user
|
|
has provided a model that has a hparam field called `batch_size` i.e.
|
|
`model.hparams.batch_size` should exist.
|
|
|
|
Args:
|
|
trainer: instance of pytorch_lightning.Trainer
|
|
|
|
batch_arg_name: field where batch_size is stored in `model.hparams`
|
|
|
|
factor: value which the old batch size is multiplied by to get the
|
|
new batch size
|
|
|
|
value: if a value is given, will override the batch size with this value.
|
|
Note that the value of `factor` will not have an effect in this case
|
|
|
|
desc: either `succeeded` or `failed`. Used purely for logging
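
    Example::

        # illustrative only: doubling versus overriding the current batch size
        _adjust_batch_size(trainer, factor=2.0, desc='succeeded')  # e.g. 32 -> 64
        _adjust_batch_size(trainer, value=48, desc='failed')       # forces 48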
    """
    model = trainer.get_model()
    batch_size = getattr(model, batch_arg_name)
    if value:
        setattr(model, batch_arg_name, value)
        new_size = value
        if desc:
            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
    else:
        new_size = int(batch_size * factor)
        if desc:
            log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')
        setattr(model, batch_arg_name, new_size)
    return new_size


def _run_power_scaling(trainer, model, new_size, batch_arg_name, max_trials):
    """ Batch scaling mode where the size is doubled at each iteration until an
    OOM error is encountered. """
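    # Illustrative trace (hypothetical sizes): 2 -> 4 -> 8 -> 16 fit, 32 raises
    # an OOM error, so the failing size is halved and 16 is returned.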
    for _ in range(max_trials):
        garbage_collection_cuda()
        trainer.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.fit(model)
            # Double in size
            new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
        except RuntimeError as exception:
            # Only OOM errors should trigger a batch size adjustment
            if is_oom_error(exception):
                # If we fail in power mode, halve the size and return
                garbage_collection_cuda()
                new_size = _adjust_batch_size(trainer, batch_arg_name, factor=0.5, desc='failed')
                break
            else:
                raise  # some other error not memory related
    return new_size


def _run_binsearch_scaling(trainer, model, new_size, batch_arg_name, max_trials):
    """ Batch scaling mode where the size is initially doubled at each iteration
    until an OOM error is encountered. Hereafter, the batch size is further
    refined using a binary search. """
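    # Illustrative trace (hypothetical sizes): 2 -> 4 -> 8 fit, 16 raises an OOM
    # error, so the search probes midpoints between 8 and 16 until the gap
    # between the largest fitting and smallest failing size closes to 1.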
    high = None
    low = 1  # largest batch size known to fit; updated after each successful fit
    count = 0
    while True:
        garbage_collection_cuda()
        trainer.global_step = 0  # reset after each try
        try:
            # Try fit
            trainer.fit(model)
            count += 1
            if count > max_trials:
                break
            # Double in size
            low = new_size
            if high:
                if high - low <= 1:
                    break
                midval = (high + low) // 2
                new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='succeeded')
            else:
                new_size = _adjust_batch_size(trainer, batch_arg_name, factor=2.0, desc='succeeded')
        except RuntimeError as exception:
            # Only OOM errors should trigger a batch size adjustment
            if is_oom_error(exception):
                # If we fail, move the upper bound of the search window down
                garbage_collection_cuda()
                high = new_size
                midval = (high + low) // 2
                new_size = _adjust_batch_size(trainer, batch_arg_name, value=midval, desc='failed')
                if high - low <= 1:
                    break
            else:
                raise  # some other error not memory related
    return new_size