From 73cf47112eefe440ecec31da94803aa6ec677c0e Mon Sep 17 00:00:00 2001
From: Stanislav
Date: Fri, 30 Aug 2019 17:56:14 +0300
Subject: [PATCH] Gradient accumulation callback (#150)

* Gradient accumulation callback
* little test case
* typo
* import fix
* method name fix
* fix epochs indexing from 1
* better code style
* code style fix v2 :/
* change interface
* fix Trainer new api in tests
* trainer api bug fix
* new raising error, new update method
* extensions tests
* a little better tests
* typo fix
* flake8 better
* using scheduler for int and dict
* typo
* first epoch bug fix
* test update
* empty dict exception
* floats check
* codestyle fix
* grad counting test
* someday, I will install a normal linter
* add more checks
* Update test_models.py
* Update test_models.py
* Update test_models.py
* Update test_models.py
* Update test_models.py
* Update test_models.py
* Update test_models.py
---
 pytorch_lightning/callbacks/__init__.py     |  3 +-
 pytorch_lightning/callbacks/pt_callbacks.py | 32 +++++++++
 pytorch_lightning/models/trainer.py         | 12 +++-
 tests/test_models.py                        | 78 ++++++++++++++++++++-
 4 files changed, 122 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/__init__.py b/pytorch_lightning/callbacks/__init__.py
index 035deb0681..9538036563 100644
--- a/pytorch_lightning/callbacks/__init__.py
+++ b/pytorch_lightning/callbacks/__init__.py
@@ -1,6 +1,7 @@
-from .pt_callbacks import EarlyStopping, ModelCheckpoint
+from .pt_callbacks import EarlyStopping, ModelCheckpoint, GradientAccumulationScheduler
 
 __all__ = [
     'EarlyStopping',
     'ModelCheckpoint',
+    'GradientAccumulationScheduler',
 ]
diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py
index 24d46e756d..aa256b578c 100644
--- a/pytorch_lightning/callbacks/pt_callbacks.py
+++ b/pytorch_lightning/callbacks/pt_callbacks.py
@@ -1,5 +1,6 @@
 import os
 import shutil
+import warnings
 
 import numpy as np
 
@@ -254,6 +255,37 @@ class ModelCheckpoint(Callback):
         self.save_model(filepath, overwrite=False)
 
 
+class GradientAccumulationScheduler(Callback):
+    """Change gradient accumulation factor according to scheduling.
+    # Arguments
+        scheduling: dict, scheduling in format {epoch: accumulation_factor}
+    """
+    def __init__(self, scheduling: dict):
+        if scheduling == {}:  # empty dict error
+            raise TypeError("Empty dict cannot be interpreted correctly")
+
+        for key in scheduling.keys():
+            if not isinstance(key, int) or not isinstance(scheduling[key], int):
+                raise TypeError("All epochs and accumulation factors must be integers")
+
+        minimal_epoch = min(scheduling.keys())
+        if minimal_epoch < 1:
+            msg = f"Epochs are indexed from 1, epoch {minimal_epoch} cannot be interpreted correctly"
+            raise IndexError(msg)
+        elif minimal_epoch != 1:  # if the user didn't define an accumulation factor for the first epoch
+            scheduling.update({1: 1})
+
+        self.scheduling = scheduling
+        self.epochs = sorted(scheduling.keys())
+
+    def on_epoch_begin(self, epoch, trainer):
+        epoch += 1  # indexing epochs from 1
+        for i in reversed(range(len(self.epochs))):
+            if epoch >= self.epochs[i]:
+                trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i])
+                break
+
+
 if __name__ == '__main__':
     c = EarlyStopping(min_delta=0.9, patience=2, verbose=True)
     losses = [10, 9, 8, 8, 6, 4.3, 5, 4.4, 2.8, 2.5]
diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 87125b9942..bc075282d6 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -18,6 +18,7 @@ from pytorch_lightning.root_module.memory import get_gpu_memory_map
 from pytorch_lightning.root_module.model_saving import TrainerIO
 from pytorch_lightning.pt_overrides.override_data_parallel import (
     LightningDistributedDataParallel, LightningDataParallel)
+from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 
 try:
@@ -137,7 +138,13 @@ class Trainer(TrainerIO):
         self.early_stop = early_stop_callback
         self.model = None
         self.max_nb_epochs = max_nb_epochs
-        self.accumulate_grad_batches = accumulate_grad_batches
+        if isinstance(accumulate_grad_batches, dict):
+            self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
+        elif isinstance(accumulate_grad_batches, int):
+            schedule = {1: accumulate_grad_batches}
+            self.accumulation_scheduler = GradientAccumulationScheduler(schedule)
+        else:
+            raise TypeError("Gradient accumulation supports only int and dict types")
         self.early_stop_callback = early_stop_callback
         self.min_nb_epochs = min_nb_epochs
         self.nb_sanity_val_steps = nb_sanity_val_steps
@@ -810,6 +817,9 @@ class Trainer(TrainerIO):
             if self.show_progress_bar:
                 self.progress_bar.reset(self.total_batches)
 
+            # update the gradient accumulation factor according to the accumulation_scheduler
+            self.accumulation_scheduler.on_epoch_begin(epoch_nb, self)
+
             # -----------------
             # RUN TNG EPOCH
             # -----------------
diff --git a/tests/test_models.py b/tests/test_models.py
index 9d4c5dc9b1..63af928783 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -11,7 +11,11 @@ from test_tube import Experiment, SlurmCluster
 # sys.path += [os.path.abspath('..'), os.path.abspath('../..')]
 from pytorch_lightning import Trainer
 from pytorch_lightning.testing import LightningTestModel, NoValEndTestModel, NoValModel
-from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
+from pytorch_lightning.callbacks import (
+    ModelCheckpoint,
+    EarlyStopping,
+    GradientAccumulationScheduler,
+)
 from pytorch_lightning.utilities.debugging import MisconfigurationException
 from pytorch_lightning.root_module import memory
 from pytorch_lightning.models.trainer import reduce_distributed_output
@@ -26,6 +30,78 @@ np.random.seed(SEED)
 # ------------------------------------------------------------------------
 # TESTS
 # ------------------------------------------------------------------------
+def test_gradient_accumulation_scheduling():
+    """
+    Test grad accumulation by the frequency of optimizer updates
+    """
+    # test incorrect configs
+    with pytest.raises(IndexError):
+        assert Trainer(accumulate_grad_batches={0: 3, 1: 4, 4: 6})
+        assert Trainer(accumulate_grad_batches={-2: 3})
+
+    with pytest.raises(TypeError):
+        assert Trainer(accumulate_grad_batches={})
+        assert Trainer(accumulate_grad_batches=[[2, 3], [4, 6]])
+        assert Trainer(accumulate_grad_batches={1: 2, 3.: 4})
+        assert Trainer(accumulate_grad_batches={1: 2.5, 3: 5})
+
+    # test that the optimizer call frequency matches the scheduler
+    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
+        # only test the first 12 batches in each epoch
+        if batch_nb < 12:
+            if epoch_nb == 0:
+                # reset counter when starting epoch
+                if batch_nb == 0:
+                    self.prev_called_batch_nb = 0
+
+                # use this opportunity to test once
+                assert self.trainer.accumulate_grad_batches == 1
+
+                assert batch_nb == self.prev_called_batch_nb
+                self.prev_called_batch_nb += 1
+
+            elif 1 <= epoch_nb <= 2:
+                # reset counter when starting epoch
+                if batch_nb == 1:
+                    self.prev_called_batch_nb = 1
+
+                # use this opportunity to test once
+                assert self.trainer.accumulate_grad_batches == 2
+
+                assert batch_nb == self.prev_called_batch_nb
+                self.prev_called_batch_nb += 2
+
+            else:
+                if batch_nb == 3:
+                    self.prev_called_batch_nb = 3
+
+                # use this opportunity to test once
+                assert self.trainer.accumulate_grad_batches == 4
+
+                assert batch_nb == self.prev_called_batch_nb
+                self.prev_called_batch_nb += 3
+
+        optimizer.step()
+
+        # clear gradients
+        optimizer.zero_grad()
+
+    hparams = get_hparams()
+    model = LightningTestModel(hparams)
+    schedule = {1: 2, 3: 4}
+
+    trainer = Trainer(accumulate_grad_batches=schedule,
+                      train_percent_check=0.1,
+                      val_percent_check=0.1,
+                      max_nb_epochs=4)
+
+    # for the test
+    trainer.optimizer_step = optimizer_step
+    model.prev_called_batch_nb = 0
+
+    trainer.fit(model)
+
+
 def test_multi_gpu_model_ddp():
     """
     Make sure DDP works
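
Usage sketch (illustrative, not part of the patch): how the scheduling interface introduced above would be called. Per the on_epoch_begin logic in the diff, epochs are indexed from 1 and a missing first entry defaults to an accumulation factor of 1; the model and the other Trainer arguments are placeholders from the test code.

from pytorch_lightning import Trainer

# epoch 1: accumulate 1 batch (default added by the scheduler when epoch 1 is not given),
# epochs 2-4: accumulate 4 batches, epoch 5 onwards: accumulate 8 batches
trainer = Trainer(accumulate_grad_batches={2: 4, 5: 8})

# a plain int is also accepted; the Trainer wraps it into the schedule {1: n}
trainer = Trainer(accumulate_grad_batches=4)

# trainer.fit(model)  # model: any LightningModule, e.g. LightningTestModel from the tests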