lightning/tests/models/test_grad_norm.py

import numpy as np
import pytest

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
from tests.base import EvalModelTemplate
from tests.base.develop_utils import reset_seed


class OnlyMetricsListLogger(LightningLoggerBase):
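    """Minimal logger that keeps every logged metrics dict in a list for later inspection."""
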
    def __init__(self):
        super().__init__()
        self.metrics = []

    @rank_zero_only
    def log_metrics(self, metrics, step):
        self.metrics.append(metrics)

    @property
    def experiment(self):
        return 'test'

    @rank_zero_only
    def log_hyperparams(self, params):
        pass

    @rank_zero_only
    def finalize(self, status):
        pass

    @property
    def name(self):
        return 'name'

    @property
    def version(self):
        return '1'


class ModelWithManualGradTracker(EvalModelTemplate):
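    """Recomputes per-parameter and total grad norms in ``on_after_backward``
    so they can be compared against what the Trainer logs via ``track_grad_norm``."""
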
    def __init__(self, norm_type, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.stored_grad_norms, self.norm_type = [], float(norm_type)

    # validation spoils logger's metrics with `val_loss` records
    validation_step = None
    val_dataloader = None

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        # just return a loss, no log or progress bar meta
        x, y = batch
        loss_val = self.loss(y, self(x.flatten(1, -1)))
        return {'loss': loss_val}

    def on_after_backward(self):
        out, norms = {}, []
        prefix = f'grad_{self.norm_type}_norm_'
        for name, p in self.named_parameters():
            if p.grad is None:
                continue

            # `np.linalg.norm` implementation likely uses fp64 intermediates
            flat = p.grad.data.cpu().numpy().ravel()
            norm = np.linalg.norm(flat, self.norm_type)
            norms.append(norm)

            out[prefix + name] = round(norm, 3)

        # handle total norm
        norm = np.linalg.norm(norms, self.norm_type)
        out[prefix + 'total'] = round(norm, 3)
        self.stored_grad_norms.append(out)


@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
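    """The grad norms logged through ``track_grad_norm`` should match the manually tracked ones."""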
    # rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above
    reset_seed()

    # use a custom grad tracking module and a list logger
    model = ModelWithManualGradTracker(norm_type)
    logger = OnlyMetricsListLogger()

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        logger=logger,
        track_grad_norm=norm_type,
        row_log_interval=1,  # request grad_norms every batch
    )
    result = trainer.fit(model)

    assert result == 1, "Training failed"
    assert len(logger.metrics) == len(model.stored_grad_norms)

    # compare the logged metrics against tracked norms on `.backward`
    for mod, log in zip(model.stored_grad_norms, logger.metrics):
        common = mod.keys() & log.keys()
        log, mod = [log[k] for k in common], [mod[k] for k in common]
        assert np.allclose(log, mod, rtol=rtol)