Gradient accumulation callback (#150)
* Gradient accumulation callback
* little test case
* typo
* import fix
* method name fix
* fix epochs indexing from 1
* better code style
* code style fix v2 :/
* change interface
* fix Trainer new API in tests
* trainer API bug fix
* new raising error, new update method
* extensions tests
* a little better tests
* typo fix
* flake8 better
* using scheduler for int and dict
* typo
* first epoch bug fix
* test update
* empty dict exception
* floats check
* code style fix
* grad counting test
* someday, I will install a normal linter
* add more checks
* Update test_models.py (×7)
This commit is contained in:
parent
c2247350bb
commit
73cf47112e
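In short, the Trainer's `accumulate_grad_batches` argument now accepts either an int (a constant accumulation factor) or a dict scheduling the factor per epoch, with epochs indexed from 1. A minimal usage sketch, assuming the remaining Trainer arguments keep their defaults; the dict schedule is the same one the new test below uses:

```python
from pytorch_lightning import Trainer

# constant factor: accumulate gradients over 4 batches in every epoch
trainer = Trainer(accumulate_grad_batches=4)

# scheduled factor: accumulate 2 batches from epoch 1, then 4 batches from epoch 3
# (epochs are 1-indexed in the schedule)
trainer = Trainer(accumulate_grad_batches={1: 2, 3: 4})
```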
@@ -1,6 +1,7 @@
from .pt_callbacks import EarlyStopping, ModelCheckpoint
from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler

__all__ = [
    'EarlyStopping',
    'ModelCheckpoint',
    'GradientAccumulationScheduler',
]
@@ -1,5 +1,6 @@
import os
import shutil
import warnings

import numpy as np
@@ -254,6 +255,37 @@ class ModelCheckpoint(Callback):
            self.save_model(filepath, overwrite=False)


class GradientAccumulationScheduler(Callback):
    """Change gradient accumulation factor according to scheduling.
    # Arguments
        scheduling: dict, scheduling in format {epoch: accumulation_factor}
    """
    def __init__(self, scheduling: dict):
        if scheduling == {}:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correctly")

        for key in scheduling.keys():
            if not isinstance(key, int) or not isinstance(scheduling[key], int):
                raise TypeError("All epochs and accumulation factors must be integers")

        minimal_epoch = min(scheduling.keys())
        if minimal_epoch < 1:
            msg = f"Epochs are indexed from 1, epoch {minimal_epoch} cannot be interpreted correctly"
            raise IndexError(msg)
        elif minimal_epoch != 1:  # if user didn't define first epoch accumulation factor
            scheduling.update({1: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def on_epoch_begin(self, epoch, trainer):
        epoch += 1  # indexing epochs from 1
        for i in reversed(range(len(self.epochs))):
            if epoch >= self.epochs[i]:
                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                break


if __name__ == '__main__':
    c = EarlyStopping(min_delta=0.9, patience=2, verbose=True)
    losses = [10, 9, 8, 8, 6, 4.3, 5, 4.4, 2.8, 2.5]
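To make the scheduling rule above concrete: `on_epoch_begin` picks the largest scheduled epoch that is less than or equal to the (1-indexed) current epoch and writes its factor onto the trainer. A minimal sketch, assuming this patched pytorch_lightning is importable; the `SimpleNamespace` stand-in and the `{4: 2, 8: 4}` schedule are illustrative, not from this commit:

```python
from types import SimpleNamespace

from pytorch_lightning.callbacks import GradientAccumulationScheduler

# epoch 1 is not scheduled, so __init__ silently adds {1: 1}
scheduler = GradientAccumulationScheduler({4: 2, 8: 4})

trainer_stub = SimpleNamespace(accumulate_grad_batches=1)  # stand-in for a Trainer
for epoch_nb in range(10):                                 # the hook receives 0-indexed epochs
    scheduler.on_epoch_begin(epoch_nb, trainer_stub)
    print(epoch_nb, trainer_stub.accumulate_grad_batches)
# factor 1 for epoch_nb 0-2, 2 for epoch_nb 3-6, 4 from epoch_nb 7 onwards
```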
@@ -18,6 +18,7 @@ from pytorch_lightning.root_module.memory import get_gpu_memory_map
from pytorch_lightning.root_module.model_saving import TrainerIO
from pytorch_lightning.pt_overrides.override_data_parallel import (
    LightningDistributedDataParallel, LightningDataParallel)
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities.debugging import MisconfigurationException

try:
@@ -137,7 +138,13 @@ class Trainer(TrainerIO):
        self.early_stop = early_stop_callback
        self.model = None
        self.max_nb_epochs = max_nb_epochs
        self.accumulate_grad_batches = accumulate_grad_batches
        if isinstance(accumulate_grad_batches, dict):
            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
        elif isinstance(accumulate_grad_batches, int):
            schedule = {1: accumulate_grad_batches}
            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
        else:
            raise TypeError("Gradient accumulation supports only int and dict types")
        self.early_stop_callback = early_stop_callback
        self.min_nb_epochs = min_nb_epochs
        self.nb_sanity_val_steps = nb_sanity_val_steps
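The branch above means an int is silently wrapped into the single-entry schedule `{1: n}`, while anything other than an int or dict fails fast at construction time. A hedged sketch of that behaviour, mirroring the new test further down:

```python
import pytest

from pytorch_lightning import Trainer

Trainer(accumulate_grad_batches=4)             # normalized internally to the schedule {1: 4}
Trainer(accumulate_grad_batches={1: 2, 3: 4})  # explicit per-epoch schedule

with pytest.raises(TypeError):
    Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])  # lists are rejected
```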
@@ -810,6 +817,9 @@ class Trainer(TrainerIO):
            if self.show_progress_bar:
                self.progress_bar.reset(self.total_batches)

            # change gradient accumulation factor according to accumulation_scheduler
            self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)

            # -----------------
            # RUN TNG EPOCH
            # -----------------
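For context, the hook runs once per epoch right before the training epoch starts, so the factor it writes is the one the batch loop sees for that whole epoch. The batch loop itself is outside this diff; the sketch below is a generic illustration of how such a factor typically gates `optimizer.step()`, not Lightning's actual implementation:

```python
def run_epoch(compute_loss, optimizer, dataloader, accumulate_grad_batches):
    """Generic sketch: step the optimizer once every `accumulate_grad_batches` batches."""
    optimizer.zero_grad()
    for batch_nb, batch in enumerate(dataloader):
        loss = compute_loss(batch)
        # dividing by the factor keeps the accumulated gradient an average;
        # a common choice, not necessarily what Lightning does here
        (loss / accumulate_grad_batches).backward()
        if (batch_nb + 1) % accumulate_grad_batches == 0:
            optimizer.step()
            optimizer.zero_grad()
```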
@@ -11,7 +11,11 @@ from test_tube import Experiment, SlurmCluster
# sys.path += [os.path.abspath('..'), os.path.abspath('../..')]
from pytorch_lightning import Trainer
from pytorch_lightning.testing import LightningTestModel, NoValEndTestModel, NoValModel
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    GradientAccumulationScheduler,
)
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.root_module import memory
from pytorch_lightning.models.trainer import reduce_distributed_output
@@ -26,6 +30,78 @@ np.random.seed(SEED)
# ------------------------------------------------------------------------
# TESTS
# ------------------------------------------------------------------------
def test_gradient_accumulation_scheduling():
    """
    Test grad accumulation by the freq of optimizer updates
    """
    # test incorrect configs
    with pytest.raises(IndexError):
        assert Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6})
        assert Trainer(accumulate_grad_batches={-2: 3})

    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={})
        assert Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])
        assert Trainer(accumulate_grad_batches={1: 2, 3.: 4})
        assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

    # test optimizer call freq matches scheduler
    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
        # only test the first 12 batches in epoch
        if batch_nb < 12:
            if epoch_nb == 0:
                # reset counter when starting epoch
                if batch_nb == 0:
                    self.prev_called_batch_nb = 0

                # use this opportunity to test once
                assert self.trainer.accumulate_grad_batches == 1

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 1

            elif 1 <= epoch_nb <= 2:
                # reset counter when starting epoch
                if batch_nb == 1:
                    self.prev_called_batch_nb = 1

                # use this opportunity to test once
                assert self.trainer.accumulate_grad_batches == 2

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 2

            else:
                if batch_nb == 3:
                    self.prev_called_batch_nb = 3

                # use this opportunity to test once
                assert self.trainer.accumulate_grad_batches == 4

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 3

        optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    hparams = get_hparams()
    model = LightningTestModel(hparams)
    schedule = {1: 2, 3: 4}

    trainer = Trainer(accumulate_grad_batches=schedule,
                      train_percent_check=0.1,
                      val_percent_check=0.1,
                      max_nb_epochs=4)

    # for the test
    trainer.optimizer_step = optimizer_step
    model.prev_called_batch_nb = 0

    trainer.fit(model)

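The new test can be run on its own with `pytest test_models.py -k test_gradient_accumulation_scheduling` (the exact path to test_models.py depends on the repository layout).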
def test_multi_gpu_model_ddp():
    """
    Make sure DDP works