added warnings to unimplemented methods (#1317)

* added warnings and removed default optimizer

* opt

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
This commit is contained in:
William Falcon 2020-04-03 15:06:51 -04:00 committed by GitHub
parent 3c5530c29d
commit e68ba1c836
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 2 deletions

View File

@ -37,8 +37,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199)) - Added model configuration checking ([#1199](https://github.com/PyTorchLightning/pytorch-lightning/pull/1199))
- On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319)) - On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319))
- Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318)) - Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
- Remove default Adam optimizer ([#1317](https://github.com/PyTorchLightning/pytorch-lightning/pull/1317))
- Give warnings for unimplemented required lightning methods ([#1317](https://github.com/PyTorchLightning/pytorch-lightning/pull/1317))
- Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307)) - Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
- Made `evalaute` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260)) - Made `evaluate` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))
### Deprecated ### Deprecated

View File

@ -269,7 +269,6 @@ In PyTorch we do it as follows:
In Lightning we do the same but organize it under the configure_optimizers method. In Lightning we do the same but organize it under the configure_optimizers method.
If you don't define this, Lightning will automatically use `Adam(self.parameters(), lr=1e-3)`.
.. code-block:: python .. code-block:: python
@ -278,6 +277,17 @@ If you don't define this, Lightning will automatically use `Adam(self.parameters
def configure_optimizers(self): def configure_optimizers(self):
return Adam(self.parameters(), lr=1e-3) return Adam(self.parameters(), lr=1e-3)
.. note:: The LightningModule itself has the parameters, so pass in self.parameters()
However, if you have multiple optimizers use the matching parameters
.. code-block:: python
class LitMNIST(pl.LightningModule):
def configure_optimizers(self):
return Adam(self.generator(), lr=1e-3), Adam(self.discriminator(), lr=1e-3)
Training step Training step
^^^^^^^^^^^^^ ^^^^^^^^^^^^^

View File

@ -224,6 +224,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
The presented loss value in progress bar is smooth (average) over last values, The presented loss value in progress bar is smooth (average) over last values,
so it differs from values set in train/validation step. so it differs from values set in train/validation step.
""" """
warnings.warn('`training_step` must be implemented to be used with the Lightning Trainer')
def training_end(self, *args, **kwargs): def training_end(self, *args, **kwargs):
""" """
@ -1079,6 +1080,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
} }
""" """
warnings.warn('`configure_optimizers` must be implemented to be used with the Lightning Trainer')
def optimizer_step( def optimizer_step(
self, self,
@ -1280,6 +1282,7 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
return loader return loader
""" """
warnings.warn('`train_dataloader` must be implemented to be used with the Lightning Trainer')
def tng_dataloader(self): # todo: remove in v1.0.0 def tng_dataloader(self): # todo: remove in v1.0.0
"""Implement a PyTorch DataLoader. """Implement a PyTorch DataLoader.