Refine style guide (#11394)

Rohit Gupta 2022-02-03 11:47:18 +05:30 committed by GitHub
parent 0cb64fb8ba
commit 67438fc2f1
1 changed file with 73 additions and 50 deletions


###########
Style Guide
###########
A main goal of Lightning is to improve readability and reproducibility. Imagine looking into any GitHub repo or research project,
finding a :class:`~pytorch_lightning.core.lightning.LightningModule`, and knowing exactly where to look to find the things you care about.
The goal of this style guide is to encourage Lightning code to be structured similarly.
***************
LightningModule
***************
These are best practices about structuring your :class:`~pytorch_lightning.core.lightning.LightningModule` class:

Systems vs Models
=================
.. figure:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/model_system.png
In Lightning we differentiate between a system and a model.
A model is something like a resnet18, RNN, etc.
A system defines how a collection of models interact with each other, together with user-defined training and evaluation logic. Examples of this are:
* GANs
* Seq2Seq
* BERT
* etc.
A LightningModule can define both a system and a model:
Here's a LightningModule that defines a system. This structure is what we recommend as a best practice: keeping the model separate from the system improves
modularity, which simplifies testing, reduces dependencies on the system, and makes refactoring easier.
.. testcode::

    class Encoder(nn.Module):
        ...


    class Decoder(nn.Module):
        ...


    class AutoEncoder(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()

        def forward(self, x):
            return self.encoder(x)


    class AutoEncoderSystem(LightningModule):
        def __init__(self):
            super().__init__()
            self.auto_encoder = AutoEncoder()
For fast prototyping, it's often useful to define all the computations in a LightningModule. For reusability
and scalability, it might be better to pass in the relevant backbones, as sketched below.
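For example, here is a sketch of injecting the backbones through the constructor (``LitAutoEncoder`` is a hypothetical name; ``Encoder`` and ``Decoder`` are the modules defined above):

.. code-block:: python

    class LitAutoEncoder(LightningModule):
        def __init__(self, encoder: nn.Module, decoder: nn.Module):
            super().__init__()
            self.encoder = encoder
            self.decoder = decoder


    # swap backbones without touching the system code
    model = LitAutoEncoder(encoder=Encoder(), decoder=Decoder())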
Here's a LightningModule that defines a model, although we don't recommend defining a model like this:
.. testcode::

    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            # layer sizes are illustrative; nn.Linear needs explicit in/out dimensions
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 64)
            self.layer_3 = nn.Linear(64, 10)
Self-contained
==============
A LightningModule should be self-contained. To see how self-contained your model is, ask
yourself this question:

Can someone drop this file into a Trainer without knowing anything about the internals?

For example, we couple the optimizer with a model because the majority of models require a specific optimizer with
a specific learning rate scheduler to work well.
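A minimal sketch of that coupling, assuming ``Adam`` with a ``StepLR`` schedule (illustrative choices, not a prescription):

.. code-block:: python

    def configure_optimizers(self):
        # the module owns its optimization setup, so it stays self-contained
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]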
Init
====
The first place where LightningModules tend to stop being self-contained is in the init. Try to define sensible
defaults for all the relevant arguments in the init so that the user doesn't have to guess.
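The anti-pattern hides the values behind an opaque object; a sketch of what that looks like (a hypothetical reconstruction, since the original example is elided):

.. code-block:: python

    class LitModel(LightningModule):
        def __init__(self, params):
            super().__init__()
            # what's inside params? the user has to guess
            self.lr = params.lr
            self.coef_x = params.coef_x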
Instead, be explicit in your init:
.. testcode::

    class LitModel(LightningModule):
        def __init__(self, encoder: nn.Module, coef_x: float = 0.2, lr: float = 1e-3):
            ...
Now the user doesn't have to guess. Instead, they know the value type, and the model has a sensible default whose
value the user can see immediately.
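For example (``nn.Sequential`` stands in for any real encoder here):

.. code-block:: python

    encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU())
    model = LitModel(encoder=encoder, coef_x=0.2, lr=1e-3)  # every value is visible at the call site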
Method Order
============
At a bare minimum, the only methods required in a LightningModule to configure a training pipeline are:
* init
* training_step
* configure_optimizers

However, if you decide to implement the rest of the optional methods, the recommended order is:
* training hooks
* validation hooks
* test hooks
* predict hooks
* configure_optimizers
* any other hooks
In practice, this code looks like the skeleton below.
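(Hook bodies are elided, and ``torch.optim.Adam`` is an illustrative optimizer choice, not a prescription.)

.. code-block:: python

    class LitModel(LightningModule):
        def __init__(self, encoder: nn.Module, lr: float = 1e-3):
            super().__init__()
            self.encoder = encoder
            self.lr = lr

        def forward(self, x):
            return self.encoder(x)

        def training_step(self, batch, batch_idx):
            ...

        def validation_step(self, batch, batch_idx):
            ...

        def test_step(self, batch, batch_idx):
            ...

        def predict_step(self, batch, batch_idx):
            ...

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.lr)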
Forward vs training_step
========================
We recommend using ``forward`` for inference/predictions and keeping ``training_step`` independent.
.. code-block:: python

    def forward(self, x):
        embeddings = self.encoder(x)
        return embeddings


    def training_step(self, batch, batch_idx):
        x, y = batch
        z = self.encoder(x)
        pred = self.decoder(z)
        ...
However, when using ``DataParallel``, you will need to call ``forward`` manually:
.. code-block:: python

    def training_step(self, batch, batch_idx):
        x, y = batch
        z = self(x)  # instead of self.encoder(x); self(x) routes through forward
        pred = self.decoder(z)
        ...
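At inference time, the clean ``forward`` pays off. A small usage sketch, assuming a trained module like the ones above and an input batch ``x``:

.. code-block:: python

    model.eval()

    with torch.no_grad():
        embeddings = model(x)  # runs forward() only, no training logic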
--------------
****
Data
****
These are best practices for handling data.
Dataloaders
===========
Lightning uses :class:`~torch.utils.data.DataLoader` to handle all the data flow through the system. Whenever you structure dataloaders,
make sure to tune the number of workers for maximum efficiency.
.. warning:: Make sure not to use ``Trainer(strategy="ddp_spawn")`` with ``num_workers > 0`` in a DataLoader or you will bottleneck your code.
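For example, a tuned loader might look like this (the values are illustrative; profile to find what works best on your machine):

.. code-block:: python

    from torch.utils.data import DataLoader

    train_loader = DataLoader(
        train_dataset,  # any torch Dataset
        batch_size=64,
        shuffle=True,
        num_workers=8,  # tune this for your CPU and storage
        pin_memory=True,  # speeds up host-to-GPU transfers
    )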
DataModules
===========
Lightning introduced datamodules. The problem with dataloaders is that sharing full datasets is often still challenging
because all these questions need to be answered:
* What splits were used?
* How many samples does this dataset have?
* What transforms were used?
* etc.
The :class:`~pytorch_lightning.core.datamodule.LightningDataModule` is designed as a way of decoupling data-related
hooks from the :class:`~pytorch_lightning.core.lightning.LightningModule` so you can develop dataset-agnostic models. It makes it easy to hot-swap
datasets with your model, so you can test and benchmark it across domains. It also makes it possible to share and reuse the exact data splits and transforms across projects.
It's for this reason that we recommend you use datamodules. This is especially important when collaborating because
it will save your team a lot of time as well.
Check out the :ref:`data` document to understand data management within Lightning and its best practices.
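A minimal sketch of a datamodule that answers those questions in code (assuming MNIST via torchvision, purely for illustration):

.. code-block:: python

    from pytorch_lightning import LightningDataModule
    from torch.utils.data import DataLoader, random_split
    from torchvision import transforms
    from torchvision.datasets import MNIST


    class MNISTDataModule(LightningDataModule):
        def __init__(self, data_dir: str = ".", batch_size: int = 32):
            super().__init__()
            self.data_dir = data_dir
            self.batch_size = batch_size
            self.transform = transforms.ToTensor()  # the transforms used

        def prepare_data(self):
            MNIST(self.data_dir, train=True, download=True)

        def setup(self, stage=None):
            full = MNIST(self.data_dir, train=True, transform=self.transform)
            # the splits used, fixed in one place
            self.train_set, self.val_set = random_split(full, [55000, 5000])

        def train_dataloader(self):
            return DataLoader(self.train_set, batch_size=self.batch_size)

        def val_dataloader(self):
            return DataLoader(self.val_set, batch_size=self.batch_size)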
------------
********
Examples
********
Check out the live examples to get your hands dirty:
- `Introduction to PyTorch Lightning <https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/mnist-hello-world.html>`_
- `Introduction to DataModules <https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/datamodules.html>`_