added docs page

This commit is contained in:
William Falcon 2019-06-26 19:44:41 -04:00
parent 0f44c5067a
commit 56f7ebf07e
1 changed files with 62 additions and 6 deletions


@ -26,7 +26,7 @@ pip install pytorch-lightning
This is a research tool I built for myself internally while doing my PhD. The API is not 100% production quality, but my hope is that by open-sourcing, we can all get it there (I don't have too much time nowadays to write production-level code).
## What is it?
Keras is too abstract for researchers. Lightning makes it so you only have to define your model but still control all details of training if you need to.
Keras is too abstract for researchers. Lightning abstracts the full training loop but gives you control at the critical points.
To use Lightning, do 2 things:
1. Define a model with the lightning interface.
@ -69,12 +69,68 @@ Pytorch
<-- Lightning
Your model.
**Lightning will do the following for you:**
## Why do I want to use Lightning?
Because you want to use best practices and get GPU training, multi-node training, checkpointing, mixed-precision training, etc... for free.
## What are some key lightning features?
- Automatic training loop
```python
# define what happens for one batch of training here
def training_step(self, data_batch, batch_nb):
    x, y = data_batch
    out = self.forward(x)
    loss = my_loss(out, y)  # my_loss: whatever loss you define
    return {'loss': loss}
```
- Automatic validation loop
```python
# define what happens for one batch of validation here
def validation_step(self, data_batch, batch_nb):
    x, y = data_batch
    out = self.forward(x)
    loss = my_loss(out, y)
    return {'val_loss': loss}
```
- Automatic early stopping
```python
callback = EarlyStopping(...)
Trainer(early_stopping=callback)
```
- Learning rate annealing
```python
# anneal at 100 and 200 epochs
Trainer(lr_scheduler_milestones=[100, 200])
```
- 16-bit precision training
```--use_amp```
- Multi-GPU training
```python
# train on 4 gpus
Trainer(gpus=[0, 1, 2, 3])
```
- Automatic checkpointing
```python
# do 3 things:
# 1. pass a checkpoint callback to the trainer
Trainer(checkpoint_callback=ModelCheckpoint(...))

# 2. return what to save in a checkpoint
def get_save_dict(self):
    return {'state_dict': self.state_dict()}

# 3. use the checkpoint to restore your model state
def load_model_specific(self, checkpoint):
    self.load_state_dict(checkpoint['state_dict'])
```
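Taken together, the interface above is the classic hook pattern: the framework owns the loops and calls your model's per-batch hooks, adding extras like early stopping around them. Here is a minimal, stdlib-only sketch of that idea (none of these classes are pytorch-lightning's actual code; `ToyModule`, `ToyTrainer`, and the abs-difference "loss" are all illustrative):

```python
# Toy sketch of the hook pattern: user code defines only the per-batch
# hooks; the trainer owns the epoch/batch loops and early stopping.

class ToyModule:
    """User code: define only what happens per batch."""
    def training_step(self, data_batch, batch_nb):
        x, y = data_batch
        return {'loss': abs(x - y)}          # stand-in for a real loss

    def validation_step(self, data_batch, batch_nb):
        x, y = data_batch
        return {'val_loss': abs(x - y)}


class ToyTrainer:
    """Framework code: owns the loops, calls the hooks."""
    def __init__(self, patience=2):
        self.patience = patience

    def fit(self, model, train_data, val_data, max_epochs=10):
        best, bad_epochs = float('inf'), 0
        for epoch in range(max_epochs):
            # training loop: call the user's hook on each batch
            for nb, batch in enumerate(train_data):
                model.training_step(batch, nb)
            # validation loop: average the user's per-batch metric
            val = sum(model.validation_step(b, i)['val_loss']
                      for i, b in enumerate(val_data)) / len(val_data)
            # early stopping: halt after `patience` epochs without improvement
            if val < best:
                best, bad_epochs = val, 0
            else:
                bad_epochs += 1
                if bad_epochs > self.patience:
                    return epoch
        return max_epochs
```

Because the trainer only ever calls the hooks, features like checkpointing or multi-GPU dispatch can be layered into `fit` without the model code changing.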
1. Run the training loop.
2. Run the validation loop.
3. Run the testing loop.
4. Early stopping.
5. Learning rate annealing.
6. Can train complex models like GANs or anything with multiple optimizers.
7. Weight checkpointing.
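Point 6 (multiple optimizers) usually means alternating which optimizer steps on each batch, GAN-style: the generator's optimizer on one batch, the discriminator's on the next. A hypothetical, stdlib-only sketch of that schedule (`ToyOptimizer` and `run_epoch` are illustrative names, not the library's API):

```python
# Hypothetical sketch: round-robin alternation between two optimizers,
# the GAN-style pattern referred to above.

class ToyOptimizer:
    def __init__(self):
        self.steps = 0

    def step(self):
        self.steps += 1          # stand-in for a real parameter update

def run_epoch(training_step, batches, optimizers):
    """batch 0 -> optimizer 0 (e.g. generator), batch 1 -> optimizer 1
    (e.g. discriminator), and so on; returns the schedule used."""
    schedule = []
    for batch_nb, batch in enumerate(batches):
        opt_idx = batch_nb % len(optimizers)
        training_step(batch, batch_nb, opt_idx)  # hook sees the active optimizer
        optimizers[opt_idx].step()               # only that optimizer updates
        schedule.append(opt_idx)
    return schedule
```

Passing the active optimizer index into the hook lets the user compute a different loss per optimizer (generator loss vs. discriminator loss) while the trainer keeps owning the loop.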