Gradient accumulation callback (#150)

* Gradient accumulation callback

* little test case

* typo

* import fix

* method name fix

* fix epochs indexing from 1

* better code style

* code style fix v2 :/

* change interface

* fix Trainer new API in tests

* trainer api bug fix

* new raising error, new update method

* extensions tests

* a little better tests

* typo fix

* flake8 fixes

* using scheduler for int and dict

* typo

* first epoch bug fix

* test update

* empty dict exception

* floats check

* codestyle fix

* grad counting test

* someday, I will install a proper linter

* add more checks

* Update test_models.py

* Update test_models.py

* Update test_models.py

* Update test_models.py

* Update test_models.py

* Update test_models.py

* Update test_models.py
This commit is contained in:
Stanislav 2019-08-30 17:56:14 +03:00 committed by William Falcon
parent c2247350bb
commit 73cf47112e
4 changed files with 122 additions and 3 deletions

pytorch_lightning/callbacks/__init__.py

@@ -1,6 +1,7 @@
from .pt_callbacks import EarlyStopping, ModelCheckpoint
from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler

__all__ = [
    'EarlyStopping',
    'ModelCheckpoint',
    'GradientAccumulationScheduler',
]
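
With the export added above, the callback can be pulled straight from the package namespace. A minimal usage sketch (the schedule values below are made up for illustration):

from pytorch_lightning.callbacks import GradientAccumulationScheduler

# accumulate one batch at a time until epoch 5, then 4 batches at a time
# (when epoch 1 is omitted, the callback fills it in with a factor of 1)
accumulator = GradientAccumulationScheduler(scheduling={5: 4})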

pytorch_lightning/callbacks/pt_callbacks.py

@@ -1,5 +1,6 @@
import os
import shutil
import warnings
import numpy as np
@@ -254,6 +255,37 @@ class ModelCheckpoint(Callback):
        self.save_model(filepath, overwrite=False)


class GradientAccumulationScheduler(Callback):
    """Change gradient accumulation factor according to scheduling.

    # Arguments
        scheduling: dict, scheduling in format {epoch: accumulation_factor}
    """

    def __init__(self, scheduling: dict):
        if scheduling == {}:  # empty dict error
            raise TypeError("Empty dict cannot be interpreted correctly")

        for key in scheduling.keys():
            if not isinstance(key, int) or not isinstance(scheduling[key], int):
                raise TypeError("All epochs and accumulation factors must be integers")

        minimal_epoch = min(scheduling.keys())
        if minimal_epoch < 1:
            msg = f"Epochs are indexed from 1, epoch {minimal_epoch} cannot be interpreted correctly"
            raise IndexError(msg)
        elif minimal_epoch != 1:  # if the user did not define an accumulation factor for the first epoch
            scheduling.update({1: 1})

        self.scheduling = scheduling
        self.epochs = sorted(scheduling.keys())

    def on_epoch_begin(self, epoch, trainer):
        epoch += 1  # indexing epochs from 1
        for i in reversed(range(len(self.epochs))):
            if epoch >= self.epochs[i]:
                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
                break


if __name__ == '__main__':
    c = EarlyStopping(min_delta=0.9, patience=2, verbose=True)
    losses = [10, 9, 8, 8, 6, 4.3, 5, 4.4, 2.8, 2.5]
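
To see GradientAccumulationScheduler's scheduling logic in isolation, here is a small sketch (not part of the diff) that drives on_epoch_begin with a stand-in trainer object; the schedule and the _DummyTrainer class are made up for illustration:

from pytorch_lightning.callbacks import GradientAccumulationScheduler

class _DummyTrainer:
    # the only attribute the callback touches
    accumulate_grad_batches = 1

scheduler = GradientAccumulationScheduler({1: 1, 3: 2, 6: 4})
trainer = _DummyTrainer()

for epoch_nb in range(8):  # the Trainer passes 0-based epoch numbers
    scheduler.on_epoch_begin(epoch_nb, trainer)
    print(epoch_nb, trainer.accumulate_grad_batches)
# prints factor 1 for epoch_nb 0-1, 2 for epoch_nb 2-4, 4 for epoch_nb 5-7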

pytorch_lightning/models/trainer.py

@@ -18,6 +18,7 @@ from pytorch_lightning.root_module.memory import get_gpu_memory_map
from pytorch_lightning.root_module.model_saving import TrainerIO
from pytorch_lightning.pt_overrides.override_data_parallel import (
    LightningDistributedDataParallel, LightningDataParallel)
from pytorch_lightning.callbacks import GradientAccumulationScheduler
from pytorch_lightning.utilities.debugging import MisconfigurationException
try:
@@ -137,7 +138,13 @@ class Trainer(TrainerIO):
        self.early_stop = early_stop_callback
        self.model = None
        self.max_nb_epochs = max_nb_epochs
        self.accumulate_grad_batches = accumulate_grad_batches
        if isinstance(accumulate_grad_batches, dict):
            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
        elif isinstance(accumulate_grad_batches, int):
            schedule = {1: accumulate_grad_batches}
            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
        else:
            raise TypeError("Gradient accumulation supports only int and dict types")
        self.early_stop_callback = early_stop_callback
        self.min_nb_epochs = min_nb_epochs
        self.nb_sanity_val_steps = nb_sanity_val_steps
@@ -810,6 +817,9 @@ class Trainer(TrainerIO):
            if self.show_progress_bar:
                self.progress_bar.reset(self.total_batches)

            # change gradient accumulation factor according to the accumulation_scheduler
            self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)

            # -----------------
            # RUN TNG EPOCH
            # -----------------
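
With the constructor change above, accumulate_grad_batches accepts either a constant or a per-epoch schedule. A short illustration with made-up values (other Trainer arguments left at their defaults):

from pytorch_lightning import Trainer

# an int is wrapped into {1: n}, i.e. a constant factor from the first epoch on
trainer_constant = Trainer(accumulate_grad_batches=4)

# a dict schedules the factor per epoch: 2 batches from epoch 1, 4 from epoch 5
trainer_scheduled = Trainer(accumulate_grad_batches={1: 2, 5: 4})

# anything else is rejected by the new type check
# Trainer(accumulate_grad_batches="4")  -> TypeError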

test_models.py

@@ -11,7 +11,11 @@ from test_tube import Experiment, SlurmCluster
# sys.path += [os.path.abspath('..'), os.path.abspath('../..')]
from pytorch_lightning import Trainer
from pytorch_lightning.testing import LightningTestModel, NoValEndTestModel, NoValModel
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    GradientAccumulationScheduler,
)
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.root_module import memory
from pytorch_lightning.models.trainer import reduce_distributed_output
@@ -26,6 +30,78 @@ np.random.seed(SEED)
# ------------------------------------------------------------------------
# TESTS
# ------------------------------------------------------------------------
def test_gradient_accumulation_scheduling():
    """
    Test grad accumulation by the freq of optimizer updates
    """
    # test incorrect configs
    with pytest.raises(IndexError):
        assert Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6})
        assert Trainer(accumulate_grad_batches={-2: 3})

    with pytest.raises(TypeError):
        assert Trainer(accumulate_grad_batches={})
        assert Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])
        assert Trainer(accumulate_grad_batches={1: 2, 3.: 4})
        assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})

    # test optimizer call freq matches scheduler
    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
        # only test the first 12 batches in epoch
        if batch_nb < 12:
            if epoch_nb == 0:
                # reset counter when starting epoch
                if batch_nb == 0:
                    self.prev_called_batch_nb = 0

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 1

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 1

            elif 1 <= epoch_nb <= 2:
                # reset counter when starting epoch
                if batch_nb == 1:
                    self.prev_called_batch_nb = 1

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 2

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 2

            else:
                if batch_nb == 3:
                    self.prev_called_batch_nb = 3

                    # use this opportunity to test once
                    assert self.trainer.accumulate_grad_batches == 4

                assert batch_nb == self.prev_called_batch_nb
                self.prev_called_batch_nb += 3

        optimizer.step()

        # clear gradients
        optimizer.zero_grad()

    hparams = get_hparams()
    model = LightningTestModel(hparams)
    schedule = {1: 2, 3: 4}

    trainer = Trainer(accumulate_grad_batches=schedule,
                      train_percent_check=0.1,
                      val_percent_check=0.1,
                      max_nb_epochs=4)

    # for the test
    trainer.optimizer_step = optimizer_step
    model.prev_called_batch_nb = 0

    trainer.fit(model)
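
The counting in the hook above rests on a simple invariant: if the trainer steps the optimizer once every n batches, consecutive optimizer_step calls within an epoch see batch numbers exactly n apart. A standalone sketch of that arithmetic (the helper below is hypothetical and not part of the test):

def expected_step_batches(first_step_batch_nb, factor, nb_batches=12):
    # batch numbers on which the optimizer is expected to step, assuming
    # one step every `factor` batches starting at `first_step_batch_nb`
    return list(range(first_step_batch_nb, nb_batches, factor))

assert expected_step_batches(0, 1) == list(range(12))      # factor 1: every batch
assert expected_step_batches(1, 2) == [1, 3, 5, 7, 9, 11]  # factor 2
assert expected_step_batches(3, 4) == [3, 7, 11]           # factor 4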
def test_multi_gpu_model_ddp():
"""
Make sure DDP works