added docs page
commit 56f7ebf07e (parent 0f44c5067a)
README.md | 68

@@ -26,7 +26,7 @@ pip install pytorch-lightning
This is a research tool I built for myself internally while doing my PhD. The API is not 100% production quality, but my hope is that by open-sourcing, we can all get it there (I don't have too much time nowadays to write production-level code).

## What is it?
Keras is too abstract for researchers. Lightning abstracts the full training loop but gives you control at the critical points.
To use lightning do 2 things:

1. Define a model with the lightning interface.
@@ -69,12 +69,68 @@ Pytorch
<-- Lightning

Your model.
## Why do I want to use lightning?

Because you want to use best practices and get GPU training, multi-node training, checkpointing, mixed-precision, etc... for free.
## What are some key lightning features?

- Automatic training loop

```python
# define what happens for training here
def training_step(self, data_batch, batch_nb):
    ...
```
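Under the hood, the loop Lightning automates around `training_step` looks roughly like this. A toy, framework-free sketch: the 1-D linear model, gradient formula, and `fit` function are illustrative stand-ins, not Lightning's API.

```python
def training_step(w, x, y):
    # forward pass: linear model with squared-error loss,
    # plus the gradient dloss/dw computed by hand
    pred = w * x
    loss = (pred - y) ** 2
    grad = 2 * (pred - y) * x
    return loss, grad

def fit(w, data, lr=0.1, epochs=50):
    # the loop Lightning runs for you: iterate epochs and batches,
    # call training_step, apply the gradient update
    for _ in range(epochs):
        for x, y in data:
            loss, grad = training_step(w, x, y)
            w -= lr * grad
    return w
```

With real models, the framework swaps the hand-derived gradient for autograd and the float `w` for module parameters; the loop shape stays the same.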

- Automatic validation loop

```python
# define what happens for validation here
def validation_step(self, data_batch, batch_nb):
    ...
```
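Same idea for validation: the loop iterates the validation data with no gradient updates and aggregates whatever `validation_step` returns. A framework-free sketch (the linear model and mean-loss aggregation are illustrative assumptions):

```python
def validation_step(w, x, y):
    # forward pass only: compute the loss, never update w
    return (w * x - y) ** 2

def validate(w, val_data):
    # the loop the trainer runs: average validation_step outputs over batches
    losses = [validation_step(w, x, y) for x, y in val_data]
    return sum(losses) / len(losses)
```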

- Automatic early stopping

```python
callback = EarlyStopping(...)
Trainer(early_stopping=callback)
```
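What such a callback tracks can be sketched in a few lines. This is a hypothetical minimal version of the idea (best-loss-so-far plus a patience counter), not the actual `EarlyStopping` class:

```python
class EarlyStopper:
    # stop when the validation loss hasn't improved for `patience` epochs
    def __init__(self, patience=3):
        self.patience = patience
        self.best = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss):
        # call once per epoch; returns True when training should stop
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```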

- Learning rate annealing

```python
# anneal at epochs 100 and 200
Trainer(lr_scheduler_milestones=[100, 200])
```
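Milestone-based annealing just multiplies the learning rate by a decay factor at each milestone the current epoch has passed. A plain-Python sketch; the `gamma=0.1` decay factor is an assumption for illustration, not a documented default:

```python
def annealed_lr(base_lr, epoch, milestones, gamma=0.1):
    # decay by gamma once for every milestone epoch already reached
    n_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** n_decays
```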

- 16-bit precision training

```--use_amp```
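16-bit training stores values in IEEE-754 half precision, which halves memory but also drops precision (which is why amp-style mixed precision keeps a full-precision copy of the weights). You can see the rounding with a stdlib round-trip through the half-precision format:

```python
import struct

def to_fp16(x):
    # pack a Python float into IEEE-754 half precision and back,
    # mimicking what storing it in a 16-bit tensor does
    return struct.unpack('e', struct.pack('e', x))[0]
```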

- Multi-GPU training

```python
# train on 4 gpus
Trainer(gpus=[0, 1, 2, 3])
```

- Automatic checkpointing

```python
# do 3 things:

# 1. pass a checkpoint callback to the Trainer
Trainer(checkpoint_callback=ModelCheckpoint)

# 2. return what to save in a checkpoint
def get_save_dict(self):
    return {'state_dict': self.state_dict()}

# 3. use the checkpoint to reset your model state
def load_model_specific(self, checkpoint):
    self.load_state_dict(checkpoint['state_dict'])
```
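Hooks 2 and 3 amount to a serialize/restore round-trip of whatever state you choose to save. A toy, framework-free version: the function names mirror the hooks above, but they take plain dicts instead of a module, and JSON stands in for the real checkpoint format:

```python
import json

def get_save_dict(weights):
    # step 2: choose what goes into the checkpoint
    return {'state_dict': weights}

def load_model_specific(checkpoint):
    # step 3: pull the saved state back out of a checkpoint
    return checkpoint['state_dict']

def roundtrip(weights):
    # serialize then restore, as writing/reading a checkpoint file would
    blob = json.dumps(get_save_dict(weights))
    return load_model_specific(json.loads(blob))
```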

**Lightning will do the following for you:**

1. Run the training loop.
2. Run the validation loop.
3. Run the testing loop.
4. Early stopping.
5. Learning rate annealing.
6. Can train complex models like GANs or anything with multiple optimizers.
7. Weight checkpointing.