2019-03-31 00:50:32 +00:00
# Pytorch-lightning
The Keras for ML research in PyTorch: simple to use, but without abstracting away as much.
## Usage
To use Lightning, first set up a trainer script:
```python
# trainer.py
from pytorch_lightning.models.trainer import Trainer
from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint
from my_project import My_Model

# --------------------
# CALLBACKS
# stop training once val_loss has not improved for 3 validation checks
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=True,
    mode='min'
)

model_save_path = 'PATH/TO/SAVE'
# keep only the checkpoint with the best validation accuracy
checkpoint = ModelCheckpoint(
    filepath=model_save_path,
    save_function=None,
    save_best_only=True,
    verbose=True,
    monitor='val_acc',
    mode='max'  # higher accuracy is better, so track the maximum
)

# --------------------
# configure trainer
trainer = Trainer(
    on_gpu=False,
    enable_tqdm=True,
    overfit_pct=None,
    track_grad_norm=-1,
    fast_dev_run=False,
    check_val_every_n_epoch=1,
    accumulate_grad_batches=2,
    process_position=0,
    current_gpu_name=0,
    checkpoint_callback=checkpoint,
    early_stop_callback=early_stop,
    enable_early_stop=True,
    max_nb_epochs=12,
    min_nb_epochs=2,
    train_percent_check=1.0,
    val_percent_check=0.5,
    test_percent_check=0.5,
    val_check_interval=0.95,
    log_save_interval=0.95,
    add_log_row_interval=20,
    lr_scheduler_milestones=None
)

# init model and start training
model = My_Model()
trainer.fit(model)
```
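
`EarlyStopping(monitor='val_loss', patience=3, mode='min')` is meant to end training once the monitored metric stops improving. The class below is a rough, framework-free sketch of such a patience rule; it is an assumption about the behavior, not Lightning's actual class, and `SimpleEarlyStopping` and `should_stop` are hypothetical names.

```python
class SimpleEarlyStopping:
    """Hypothetical patience-based early stopping (not Lightning's implementation)."""

    def __init__(self, patience=3, mode='min'):
        self.patience = patience
        self.mode = mode
        self.best = None
        self.wait = 0  # validation checks since the last improvement

    def _improved(self, value):
        if self.best is None:
            return True
        return value < self.best if self.mode == 'min' else value > self.best

    def should_stop(self, value):
        """Record one validation result; return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


if __name__ == '__main__':
    stopper = SimpleEarlyStopping(patience=3, mode='min')
    val_losses = [0.9, 0.7, 0.72, 0.71, 0.75, 0.6]
    for epoch, loss in enumerate(val_losses):
        if stopper.should_stop(loss):
            # prints: stopping after epoch 4, best val_loss=0.7
            print(f'stopping after epoch {epoch}, best val_loss={stopper.best}')
            break
```

With `mode='max'` the comparison flips, which suits metrics such as accuracy where higher is better.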
Next, define a model that implements these 10 functions:
```python
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from test_tube import HyperOptArgumentParser
# RootModule is Lightning's base model class; this import path is an assumption and may differ by version
from pytorch_lightning.root_module.root_module import RootModule


class ExampleModel(RootModule):
    def __init__(self):
        # call the base-class constructor (pass hparams here if your RootModule version expects them)
        super().__init__()
        # define model (layer sizes are only illustrative)
        self.l1 = nn.Linear(28 * 28, 10)

    # ---------------
    # TRAINING
    def training_step(self, data_batch):
        x, y = data_batch
        y_hat = self.l1(x)
        loss = F.cross_entropy(y_hat, y)
        # return the loss to optimize plus a dict of values to log
        return loss, {'train_loss': loss}

    def validation_step(self, data_batch):
        x, y = data_batch
        y_hat = self.l1(x)
        loss = F.cross_entropy(y_hat, y)
        val_acc = (y_hat.argmax(dim=1) == y).float().mean()
        return loss, {'val_loss': loss, 'val_acc': val_acc}

    def validation_end(self, outputs):
        # collate the metrics from every validation step
        total_losses = []
        total_accs = []
        for output in outputs:
            total_losses.append(output['val_loss'].item())
            total_accs.append(output['val_acc'].item())
        # return a dict (keys named to match the metrics monitored in trainer.py)
        return {'val_loss': np.mean(total_losses), 'val_acc': np.mean(total_accs)}

    # ---------------
    # SAVING
    def get_save_dict(self):
        # lightning saves for you. Here's your chance to say what you want to save
        checkpoint = {'state_dict': self.state_dict()}
        return checkpoint

    def load_model_specific(self, checkpoint):
        # lightning loads for you. Here's your chance to say what you want to load
        self.load_state_dict(checkpoint['state_dict'])

    # ---------------
    # TRAINING CONFIG
    def configure_optimizers(self):
        # give lightning the list of optimizers you want to use.
        # lightning will call this automatically
        optimizer = self.choose_optimizer('adam', self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
        return [optimizer]

    # pytorch_dataloader is a placeholder for your own DataLoader factory
    @property
    def tng_dataloader(self):
        return pytorch_dataloader('train')

    @property
    def val_dataloader(self):
        return pytorch_dataloader('val')

    @property
    def test_dataloader(self):
        return pytorch_dataloader('test')

    # ---------------
    # MODIFY YOUR COMMAND LINE ARGS
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
        parser.add_argument('--out_features', type=int, default=20)
        return parser
```
### Details
#### Model definition
| Name | Description | Input | Return |
|---|---|---|---|
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
| validation_end | Collate metrics from all validation steps | outputs: list with one entry per validation step | dict: metrics to log |
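
`validation_end` reduces the per-step results into epoch-level metrics. A minimal, framework-free sketch of that reduction, assuming each entry in `outputs` is a dict of scalar metrics (the exact shape Lightning passes may differ); `collate_metrics` is a hypothetical helper:

```python
from statistics import mean


def collate_metrics(outputs):
    """Average each metric across validation steps (hypothetical helper).

    Assumes every item in `outputs` is a dict of metric name -> number.
    """
    if not outputs:
        return {}
    return {key: mean(step[key] for step in outputs) for key in outputs[0]}


step_outputs = [
    {'val_loss': 0.5, 'val_acc': 0.75},
    {'val_loss': 0.25, 'val_acc': 1.0},
]
print(collate_metrics(step_outputs))  # {'val_loss': 0.375, 'val_acc': 0.875}
```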
#### Model training
| Name | Description | Input | Return |
|---|---|---|---|
| configure_optimizers | called during training setup | None | list: optimizers you want to use |
| tng_dataloader | called during training | None | PyTorch DataLoader |
| val_dataloader | called during validation | None | PyTorch DataLoader |
| test_dataloader | called during testing | None | PyTorch DataLoader |
| add_model_specific_args | called with the argument parser defined in your main script; lets each model add its own args while main stays the same | parent_parser: argparse parser | argparse parser |
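
`add_model_specific_args` follows argparse's parent-parser pattern: the main script defines shared flags once, and each model layers its own on top. The sketch below swaps in the standard-library `argparse` for test_tube's `HyperOptArgumentParser` so it runs without extra dependencies; `build_main_parser` and the `--learning_rate` flag are hypothetical. The parent is built with `add_help=False` because argparse raises an error when both parent and child define `-h`.

```python
import argparse


def build_main_parser():
    # shared flags that stay the same for every model (hypothetical example flag)
    parser = argparse.ArgumentParser(add_help=False)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    return parser


class ExampleModel:
    @staticmethod
    def add_model_specific_args(parent_parser):
        # inherit the parent's flags, then add this model's own
        parser = argparse.ArgumentParser(parents=[parent_parser])
        parser.add_argument('--out_features', type=int, default=20)
        return parser


parser = ExampleModel.add_model_specific_args(build_main_parser())
args = parser.parse_args(['--out_features', '64'])
print(args.learning_rate, args.out_features)  # 0.001 64
```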
#### Model Saving/Loading
| Name | Description | Input | Return |
|---|---|---|---|
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
| load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |
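
These hooks split the work: the framework decides when and where to write a checkpoint, `get_save_dict` decides what goes into it, and `load_model_specific` restores it. A framework-free sketch of that round trip, assuming a pickle file and a toy model whose state is a plain dict; `ToyModel`, `save_checkpoint`, and `load_checkpoint` are hypothetical stand-ins (a real PyTorch model would use `state_dict()` with `torch.save`/`torch.load`).

```python
import os
import pickle
import tempfile


class ToyModel:
    """Hypothetical model exposing the two saving hooks from the table."""

    def __init__(self):
        self.weights = {'w': 0.0}

    def get_save_dict(self):
        # choose what goes into the checkpoint
        return {'state_dict': dict(self.weights)}

    def load_model_specific(self, checkpoint):
        # restore from the dict produced by get_save_dict
        self.weights = dict(checkpoint['state_dict'])


def save_checkpoint(model, path):
    # "trainer" side: decides when and where, asks the model what to save
    with open(path, 'wb') as f:
        pickle.dump(model.get_save_dict(), f)


def load_checkpoint(model, path):
    with open(path, 'rb') as f:
        model.load_model_specific(pickle.load(f))
    return model


if __name__ == '__main__':
    trained = ToyModel()
    trained.weights['w'] = 1.5
    path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pkl')
    save_checkpoint(trained, path)
    restored = load_checkpoint(ToyModel(), path)
    print(restored.weights)  # {'w': 1.5}
```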
### Add new model
1. Create a new model under `/models`.
2. Add the model name to `AVAILABLE_MODELS` in `trainer_main`:
```python
AVAILABLE_MODELS = {
'model_1': ExampleModel1
}
```
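
Once a model is registered in `AVAILABLE_MODELS`, the main script can select it by name. A minimal sketch, assuming a hypothetical `--model_name` command-line flag and a placeholder model class:

```python
import argparse


class ExampleModel1:
    """Placeholder for a model defined under /models."""


AVAILABLE_MODELS = {
    'model_1': ExampleModel1,
}


def build_model(argv=None):
    parser = argparse.ArgumentParser()
    # restricting choices makes argparse reject unregistered names up front
    parser.add_argument('--model_name', choices=sorted(AVAILABLE_MODELS), default='model_1')
    args = parser.parse_args(argv)
    model_cls = AVAILABLE_MODELS[args.model_name]
    return model_cls()


model = build_model(['--model_name', 'model_1'])
print(type(model).__name__)  # ExampleModel1
```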
### Model methods that can be implemented
| Method | Purpose | Input | Output | Required |
|---|---|---|---|---|
| forward() | Forward pass | model_in tuple with your data | model_out tuple to be passed to loss | Y |
| loss() | calculate model loss | model_out tuple from forward() | A scalar | Y |
| check_performance() | run a full loop through val data to check for metrics | dataloader, nb_tests | metrics tuple to be tracked | Y |
| tng_dataloader | Property; feeds training data | - | PyTorch DataLoader subclass | Y |
| val_dataloader | Property; feeds validation data | - | PyTorch DataLoader subclass | Y |
| test_dataloader | Property; feeds test data | - | PyTorch DataLoader subclass | Y |
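
Taken together, `forward()` feeds `loss()`, and `check_performance()` loops over validation batches to compute tracked metrics. A toy sketch of that contract, assuming `nb_tests` caps the number of batches evaluated and that the dataloader is any iterable of `(x, y)` batches; `ToyModel` and its "predict 2 * x" logic are hypothetical.

```python
from itertools import islice


class ToyModel:
    """Hypothetical model: predicts 2 * x and uses squared error as its loss."""

    def forward(self, model_in):
        x, y = model_in
        return (2 * x, y)  # model_out tuple passed on to loss()

    def loss(self, model_out):
        y_hat, y = model_out
        return (y_hat - y) ** 2

    def check_performance(self, dataloader, nb_tests):
        # evaluate at most nb_tests batches and return a metrics tuple
        losses = [self.loss(self.forward(batch)) for batch in islice(dataloader, nb_tests)]
        mean_loss = sum(losses) / len(losses)
        return (mean_loss,)


val_batches = [(1.0, 2.0), (2.0, 5.0), (3.0, 9.0)]
print(ToyModel().check_performance(val_batches, nb_tests=2))  # (0.5,)
```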
### Model lifecycle hooks
Use these hooks to customize behavior at specific points in training. All of them are optional.
| Method | Purpose | Input | Output | Required |
|---|---|---|---|---|
| on_batch_start() | called right before the batch starts | - | - | N |
| on_batch_end() | called right after the batch ends | - | - | N |
| on_epoch_start() | called right before the epoch starts | - | - | N |
| on_epoch_end() | called right after the epoch ends | - | - | N |
| on_pre_performance_check() | called right before the performance check starts | - | - | N |
| on_post_performance_check() | called right after the performance check ends | - | - | N |
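
A model only overrides the hooks it needs; the rest default to doing nothing. The toy loop below sketches one plausible calling order, assuming a single performance check at the end of each epoch; it is not Lightning's trainer, and `Hooks`, `RecordingModel`, and `run_training` are hypothetical names.

```python
class Hooks:
    """Default hooks do nothing, so overriding them is optional."""

    def on_epoch_start(self):
        pass

    def on_epoch_end(self):
        pass

    def on_batch_start(self):
        pass

    def on_batch_end(self):
        pass

    def on_pre_performance_check(self):
        pass

    def on_post_performance_check(self):
        pass


class RecordingModel(Hooks):
    """Records which hooks fired; leaves the performance-check hooks as no-ops."""

    def __init__(self):
        self.calls = []

    def on_epoch_start(self):
        self.calls.append('epoch_start')

    def on_epoch_end(self):
        self.calls.append('epoch_end')

    def on_batch_start(self):
        self.calls.append('batch_start')

    def on_batch_end(self):
        self.calls.append('batch_end')


def run_training(model, nb_epochs, nb_batches):
    # hypothetical loop: one performance check at the end of every epoch
    for _ in range(nb_epochs):
        model.on_epoch_start()
        for _ in range(nb_batches):
            model.on_batch_start()
            # ... training_step would run here ...
            model.on_batch_end()
        model.on_pre_performance_check()
        # ... validation loop would run here ...
        model.on_post_performance_check()
        model.on_epoch_end()


model = RecordingModel()
run_training(model, nb_epochs=1, nb_batches=2)
print(model.calls)  # ['epoch_start', 'batch_start', 'batch_end', 'batch_start', 'batch_end', 'epoch_end']
```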