More explicit exception message when testing with fast_dev_run=True (#6667)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent dcf6e4e310 · commit cca0eca5f3
CHANGELOG.md

@@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))
+
+
 - Trigger warning when non-metric logged value with multi processes hasn't been reduced ([#6417](https://github.com/PyTorchLightning/pytorch-lightning/pull/6417))
 
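To make the changelog entry concrete: `fast_dev_run=True` disables checkpointing, so the default `ckpt_path='best'` used by `trainer.test()` and `trainer.validate()` can never be resolved. Below is a minimal sketch of the failure mode, assuming nothing beyond the public Lightning API of this release; `DemoModel` and `RandomDataset` are illustrative stand-ins, not part of the commit.

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Illustrative dataset: 64 random feature vectors, enough for a smoke test."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class DemoModel(LightningModule):
    """Illustrative stand-in for any user LightningModule."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def test_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)

    def test_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=8)


trainer = Trainer(fast_dev_run=True)
trainer.fit(DemoModel())
# fast_dev_run skips checkpointing, so `.test()` (whose ckpt_path defaults to
# 'best') now fails with the explicit message this commit adds:
#   You cannot execute `.test()` with `fast_dev_run=True` unless you do
#   `.test(ckpt_path=PATH)` as no checkpoint path was generated during fitting.
trainer.test()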
pytorch_lightning/trainer/trainer.py

@@ -955,31 +955,38 @@ class Trainer(
         model,
         ckpt_path: Optional[str] = None,
     ) -> Optional[str]:
-        # if user requests the best checkpoint but we don't have it, error
-        if ckpt_path == 'best' and not self.checkpoint_callback.best_model_path:
-            raise MisconfigurationException(
-                'ckpt_path is "best", but `ModelCheckpoint` is not configured to save the best model.'
-            )
-        # load best weights
-        if ckpt_path is not None:
-            # ckpt_path is 'best' so load the best model
-            if ckpt_path == 'best':
-                ckpt_path = self.checkpoint_callback.best_model_path
-
-            if not ckpt_path:
-                fn = self.state.value
-                raise MisconfigurationException(
-                    f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
-                    ' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
-                )
-
-            # only one process running at this point for TPUs, as spawn isn't triggered yet
-            if not self._device_type == DeviceType.TPU:
-                self.training_type_plugin.barrier()
-
-            ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
-            model.load_state_dict(ckpt['state_dict'])
+        if ckpt_path is None:
+            return
+
+        fn = self.state.value
+
+        if ckpt_path == 'best':
+            # if user requests the best checkpoint but we don't have it, error
+            if not self.checkpoint_callback.best_model_path:
+                if self.fast_dev_run:
+                    raise MisconfigurationException(
+                        f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do'
+                        f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.'
+                    )
+                raise MisconfigurationException(
+                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
+                )
+            # load best weights
+            ckpt_path = self.checkpoint_callback.best_model_path
+
+        if not ckpt_path:
+            raise MisconfigurationException(
+                f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
+                f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
+            )
+
+        # only one process running at this point for TPUs, as spawn isn't triggered yet
+        if self._device_type != DeviceType.TPU:
+            self.training_type_plugin.barrier()
+
+        ckpt = pl_load(ckpt_path, map_location=lambda storage, loc: storage)
+        model.load_state_dict(ckpt['state_dict'])
         return ckpt_path
 
     def predict(
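The rewrite turns the tangled checkpoint handling into a linear decision tree: bail out early when no checkpoint is requested, resolve 'best' (erroring loudly when nothing was saved, with a dedicated message for the `fast_dev_run` case), validate the final path, synchronize non-TPU processes, then load. A standalone sketch of that flow, where the function name and the plain `ValueError` are stand-ins for the private `Trainer` method and `MisconfigurationException`:

from typing import Optional


def resolve_ckpt_path(ckpt_path: Optional[str], best_model_path: str, fast_dev_run: bool, fn: str) -> Optional[str]:
    # Sketch only: mirrors the control flow of the diff above, not the real API.
    # No checkpoint requested: keep the in-memory weights.
    if ckpt_path is None:
        return None
    if ckpt_path == 'best':
        if not best_model_path:
            # `fn` is 'test' or 'validate'; fast_dev_run disables checkpointing,
            # so 'best' can never resolve, and the new branch says so explicitly.
            if fast_dev_run:
                raise ValueError(
                    f'You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do'
                    f' `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting.'
                )
            raise ValueError(
                f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
            )
        # Substitute the tracked best checkpoint for the 'best' placeholder.
        ckpt_path = best_model_path
    # Catches e.g. an explicit empty string.
    if not ckpt_path:
        raise ValueError(
            f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
            f' specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`'
        )
    return ckpt_path

The diff also quietly fixes two small issues in the old code: the second line of the "found no path" message lacked an `f` prefix, so the literal text `{fn}` was printed instead of the method name, and `if not self._device_type == DeviceType.TPU` becomes the more idiomatic `!=` comparison.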
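The escape hatch the new message points to is an explicit checkpoint path. Continuing the illustrative snippet after the changelog hunk (`/tmp/manual.ckpt` is a made-up location, and `Trainer.save_checkpoint` is used here only to have a file to load):

trainer = Trainer(fast_dev_run=True)
trainer.fit(DemoModel())
trainer.save_checkpoint('/tmp/manual.ckpt')  # write a checkpoint by hand
trainer.test(ckpt_path='/tmp/manual.ckpt')   # explicit path, so no exception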
tests/trainer/test_trainer.py

@@ -1777,3 +1777,12 @@ def test_trainer_attach_data_pipeline_to_model(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[TestCallback()])
     trainer.fit(model, datamodule=dm)
+
+
+def test_exception_when_testing_or_validating_with_fast_dev_run(tmpdir):
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+
+    with pytest.raises(MisconfigurationException, match=r"\.validate\(\)` with `fast_dev_run=True"):
+        trainer.validate()
+    with pytest.raises(MisconfigurationException, match=r"\.test\(\)` with `fast_dev_run=True"):
+        trainer.test()
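A detail worth knowing when reading the test: the `match` argument of `pytest.raises` is applied with `re.search` against the string of the raised exception, which is why the dots and parentheses of the expected message are escaped. A self-contained illustration, with `ValueError` standing in for `MisconfigurationException`:

import pytest

# The pattern only needs to match a substring of str(exception).
with pytest.raises(ValueError, match=r"\.validate\(\)` with `fast_dev_run=True"):
    raise ValueError(
        'You cannot execute `.validate()` with `fast_dev_run=True` unless you do'
        ' `.validate(ckpt_path=PATH)` as no checkpoint path was generated during fitting.'
    )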