Fix num_sanity_val_steps is clipped to limit_val_batches (#2917)
* Fix num_sanity_val_steps according to limit_val_batches
* fix test
* add num_sanity_batches
* pep
* update docstring in test
* add more test
* chlog
* update comments and docstring in test

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Adrian Wälchli <adrian.waelchli@inf.unibe.ch>
Co-authored-by: Ananya Harsh Jha <ananya@pytorchlightning.ai>
parent bcdb750976
commit 7cca3859a7
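
In short, the fix stops clipping `num_sanity_val_steps` against `limit_val_batches` when the `Trainer` is constructed; the number of sanity-check batches is instead computed per validation dataloader right before the sanity check runs. A minimal sketch of the resulting behaviour (the argument values below are illustrative, not taken from the diff):

    from pytorch_lightning import Trainer

    # Trainer now keeps the requested value instead of min(num_sanity_val_steps, limit_val_batches)
    trainer = Trainer(num_sanity_val_steps=4, limit_val_batches=0.5, max_steps=1)
    assert trainer.num_sanity_val_steps == 4  # no longer clipped at construction time
    # during the sanity check, each dataloader is clipped individually:
    # num_sanity_val_batches = [min(4, n) for n in trainer.num_val_batches]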
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Fixed

+- Fixed `num_sanity_val_steps` is clipped to `limit_val_batches` ([#2917](https://github.com/PyTorchLightning/pytorch-lightning/pull/2917))

 ## [0.9.0] - YYYY-MM-DD

@@ -121,7 +122,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed automatic batch scaling not working with half precision ([#3045](https://github.com/PyTorchLightning/pytorch-lightning/pull/3045))
 - Fixed setting device to root gpu ([#3042](https://github.com/PyTorchLightning/pytorch-lightning/pull/3042))

 ## [0.8.5] - 2020-07-09

 ### Added

@@ -307,7 +307,7 @@ class ProgressBar(ProgressBarBase):

     def on_sanity_check_start(self, trainer, pl_module):
         super().on_sanity_check_start(trainer, pl_module)
         self.val_progress_bar = self.init_sanity_tqdm()
-        self.val_progress_bar.total = convert_inf(trainer.num_sanity_val_steps * len(trainer.val_dataloaders))
+        self.val_progress_bar.total = convert_inf(sum(trainer.num_sanity_val_batches))
         self.main_progress_bar = tqdm(disable=True)  # dummy progress bar

     def on_sanity_check_end(self, trainer, pl_module):

@@ -377,6 +377,7 @@ class Trainer(
         self.logged_metrics = {}
         self.num_training_batches = 0
         self.num_val_batches = []
+        self.num_sanity_val_batches = []
         self.num_test_batches = []
         self.train_dataloader = None
         self.test_dataloaders = None

@@ -463,9 +464,9 @@ class Trainer(
         self.min_steps = min_steps

         if num_sanity_val_steps == -1:
-            self.num_sanity_val_steps = float("inf")
+            self.num_sanity_val_steps = float('inf')
         else:
-            self.num_sanity_val_steps = min(num_sanity_val_steps, limit_val_batches)
+            self.num_sanity_val_steps = num_sanity_val_steps

         self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch

@@ -1239,7 +1240,6 @@ class Trainer(
         self.train()

     def _run_sanity_check(self, ref_model, model):
         using_val_step = ref_model.val_dataloader is not None and self.is_overridden('validation_step')
         should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

@@ -1247,14 +1247,15 @@ class Trainer(
         # to make sure program won't crash during val
         if should_sanity_check:
             self.reset_val_dataloader(ref_model)
+            self.num_sanity_val_batches = [
+                min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
+            ]

             # hook and callback
             self.running_sanity_check = True
             self.on_sanity_check_start()

-            num_loaders = len(self.val_dataloaders)
-            max_batches = [self.num_sanity_val_steps] * num_loaders
-            eval_results = self._evaluate(model, self.val_dataloaders, max_batches, False)
+            eval_results = self._evaluate(model, self.val_dataloaders, self.num_sanity_val_batches, False)

             # allow no returns from eval
             if eval_results is not None and len(eval_results) > 0:
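
To make the per-dataloader clipping above concrete, a hypothetical example (the batch counts are made up, not part of the diff):

    # two validation dataloaders with 3 and 7 batches, num_sanity_val_steps=4
    num_val_batches = [3, 7]
    num_sanity_val_steps = 4
    num_sanity_val_batches = [min(num_sanity_val_steps, n) for n in num_val_batches]  # [3, 4]
    # _evaluate then receives [3, 4], and the sanity progress bar total is sum([3, 4]) == 7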
@@ -907,12 +907,42 @@ def test_tpu_choice(tmpdir, tpu_cores, expected_tpu_id, error_expected):
     pytest.param(0.0),  # this should run no sanity checks
     pytest.param(1),
     pytest.param(1.0),
     pytest.param(0.3),
+    pytest.param(0.5),
+    pytest.param(5),
 ])
 def test_num_sanity_val_steps(tmpdir, limit_val_batches):
+    """ Test that the number of sanity check batches is clipped to limit_val_batches. """
+    model = EvalModelTemplate()
+    model.validation_step = model.validation_step__multiple_dataloaders
+    model.validation_epoch_end = model.validation_epoch_end__multiple_dataloaders
+    num_sanity_val_steps = 4
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        num_sanity_val_steps=num_sanity_val_steps,
+        limit_val_batches=limit_val_batches,
+        max_steps=1,
+    )
+    assert trainer.num_sanity_val_steps == num_sanity_val_steps
+    val_dataloaders = model.val_dataloader__multiple_mixed_length()
+
+    with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
+        trainer.fit(model, val_dataloaders=val_dataloaders)
+        assert mocked.call_count == sum(
+            min(num_sanity_val_steps, num_batches) for num_batches in trainer.num_val_batches
+        )
+
+
+@pytest.mark.parametrize(['limit_val_batches'], [
+    pytest.param(0.0),  # this should run no sanity checks
+    pytest.param(1),
+    pytest.param(1.0),
+    pytest.param(0.3),
+])
+def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
     """
-    Test that num_sanity_val_steps=-1 runs through all validation data once.
-    Makes sure this setting is independent of limit_val_batches.
+    Test that num_sanity_val_steps=-1 runs through all validation data once, and as many batches as
+    limited by "limit_val_batches" Trainer argument.
     """
     model = EvalModelTemplate()
     model.validation_step = model.validation_step__multiple_dataloaders

@@ -920,7 +950,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):
     trainer = Trainer(
         default_root_dir=tmpdir,
         num_sanity_val_steps=-1,
-        limit_val_batches=limit_val_batches,  # should have no influence
+        limit_val_batches=limit_val_batches,
         max_steps=1,
     )
     assert trainer.num_sanity_val_steps == float('inf')

@@ -928,7 +958,7 @@ def test_num_sanity_val_steps(tmpdir, limit_val_batches):

     with patch.object(trainer, 'evaluation_forward', wraps=trainer.evaluation_forward) as mocked:
         trainer.fit(model, val_dataloaders=val_dataloaders)
-        assert mocked.call_count == sum(len(dl) * (limit_val_batches > 0) for dl in val_dataloaders)
+        assert mocked.call_count == sum(trainer.num_val_batches)


 @pytest.mark.parametrize("trainer_kwargs,expected", [