More explicit exception message when testing with fast_dev_run=True (#6667)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Łukasz Zalewski 2021-03-29 15:29:54 +02:00 committed by GitHub
parent dcf6e4e310
commit cca0eca5f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 20 deletions

View File

@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))
- Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))

View File

@ -955,31 +955,38 @@ class Trainer(
model,
ckpt_path: Optional[str] = None,
) -> Optional[str]:
# if user requests the best checkpoint but we don't have it, error
if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
if ckpt_path is None:
return
fn = self.state.value
if ckpt_path == 'best':
# if user requests the best checkpoint but we don't have it, error
if not self.checkpoint_callback.best_model_path:
if self.fast_dev_run:
raise MisconfigurationException(
f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do'
f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.'
)
raise MisconfigurationException(
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
)
# load best weights
ckpt_path = self.checkpoint_callback.best_model_path
if not ckpt_path:
raise MisconfigurationException(
'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.'
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
)
# load best weights
if ckpt_path is not None:
# ckpt_path is 'best' so load the best model
if ckpt_path == 'best':
ckpt_path = self.checkpoint_callback.best_model_path
# only one process running at this point for TPUs, as spawn isn't triggered yet
if self._device_type != DeviceType.TPU:
self.training_type_plugin.barrier()
if not ckpt_path:
fn = self.state.value
raise MisconfigurationException(
f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
)
ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
# only one process running at this point for TPUs, as spawn isn't triggered yet
if not self._device_type == DeviceType.TPU:
self.training_type_plugin.barrier()
ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
model.load_state_dict(ckpt['state_dict'])
return ckpt_path
def predict(

View File

@ -1777,3 +1777,12 @@ def test_trainer_attach_data_pipeline_to_model(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()])
trainer.fit(model, datamodule=dm)
def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"):
trainer.validate()
with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
trainer.test()