# Pytorch-lightning

The Keras for ML research in PyTorch: just as simple to use, but without abstracting away the underlying PyTorch code.

## Usage

To use lightning, define a model that implements these 10 functions:

```python
import numpy as np
import torch.nn as nn

# NOTE: RootModule's exact import path depends on your lightning version
from pytorch_lightning import RootModule
from test_tube import HyperOptArgumentParser


class ExampleModel(RootModule):

    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

        # define model (a single linear layer as a placeholder)
        self.l1 = nn.Linear(28 * 28, 10)

    # ---------------
    # TRAINING
    def training_step(self, data_batch):
        x, y = data_batch
        y_hat = self.l1(x)
        loss = some_loss(y_hat, y)  # some_loss is a placeholder for your loss fn

        return loss, {'train_loss': loss}

    def validation_step(self, data_batch):
        x, y = data_batch
        y_hat = self.l1(x)
        loss = some_loss(y_hat, y)

        return loss, {'val_loss': loss}

    def validation_end(self, outputs):
        # outputs holds the dict returned by every validation_step
        val_losses = []
        for output in outputs:
            val_losses.append(output['val_loss'].item())

        # return a dict for logging
        return {'avg_val_loss': np.mean(val_losses)}

    # ---------------
    # SAVING
    def get_save_dict(self):
        # lightning saves for you. Here's your chance to say what you want to save
        checkpoint = {'state_dict': self.state_dict()}

        return checkpoint

    def load_model_specific(self, checkpoint):
        # lightning loads for you. Here's your chance to say what you want to load
        self.load_state_dict(checkpoint['state_dict'])

    # ---------------
    # TRAINING CONFIG
    def configure_optimizers(self):
        # give lightning the list of optimizers you want to use.
        # lightning will call them automatically
        optimizer = self.choose_optimizer('adam', self.parameters(),
                                          {'lr': self.hparams.learning_rate},
                                          'optimizer')
        return [optimizer]

    @property
    def tng_dataloader(self):
        return pytorch_dataloader('train')  # pytorch_dataloader is a placeholder

    @property
    def val_dataloader(self):
        return pytorch_dataloader('val')

    @property
    def test_dataloader(self):
        return pytorch_dataloader('test')

    # ---------------
    # MODIFY YOUR COMMAND LINE ARGS
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
        parser.add_argument('--out_features', default=20, type=int)

        return parser
```
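
Once the model is defined, training is driven by lightning's trainer. A minimal sketch of the main script, assuming the early-release `Trainer` that takes a test-tube `Experiment`; import paths and constructor arguments may differ across versions:

```python
from test_tube import Experiment
from pytorch_lightning.models.trainer import Trainer  # path may differ by version

# hparams comes from your argument parser (see add_model_specific_args)
model = ExampleModel(hparams)

# test-tube experiment used for logging
exp = Experiment(save_dir='logs/')

trainer = Trainer(experiment=exp)
trainer.fit(model)
```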

### Details

#### Model definition

| Name | Description | Input | Return |
|---|---|---|---|
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
| get_save_dict | Called when your model needs to be saved (checkpoints, hpc save, etc.) | None | dict to be saved |

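The validation contract is easiest to see end to end. A minimal sketch of a `validation_step` that tracks accuracy alongside the loss, and the matching `validation_end` that collates both (`some_loss` is the same placeholder as in the usage example; both methods live on your model):

```python
import numpy as np

def validation_step(self, data_batch):
    x, y = data_batch
    y_hat = self.l1(x)
    loss = some_loss(y_hat, y)
    # fraction of correct predictions in this batch
    val_acc = (y_hat.argmax(dim=1) == y).float().mean()

    return loss, {'val_loss': loss, 'val_acc': val_acc}

def validation_end(self, outputs):
    # outputs is the list of dicts returned by every validation_step
    avg_loss = np.mean([out['val_loss'].item() for out in outputs])
    avg_acc = np.mean([out['val_acc'].item() for out in outputs])

    return {'avg_val_loss': avg_loss, 'avg_val_acc': avg_acc}
```
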
#### Model training

| Name | Description | Input | Return |
|---|---|---|---|
| configure_optimizers | Called during training setup | None | list: optimizers you want to use |
| tng_dataloader | Called during training | None | pytorch dataloader |
| val_dataloader | Called during validation | None | pytorch dataloader |
| test_dataloader | Called during testing | None | pytorch dataloader |
| add_model_specific_args | Called with the parser you defined in your main. Lets you tailor args for each model while keeping main the same | argparse | argparse |

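Each dataloader property just returns a standard PyTorch `DataLoader`. A minimal sketch using torchvision's MNIST as a stand-in dataset (`hparams.batch_size` is an assumed hyperparameter from your argument parser):

```python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

@property
def tng_dataloader(self):
    # MNIST is only a stand-in; swap in your own Dataset
    dataset = datasets.MNIST('data/', train=True, download=True,
                             transform=transforms.ToTensor())
    return DataLoader(dataset, batch_size=self.hparams.batch_size, shuffle=True)
```
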
#### Model Saving/Loading

| Name | Description | Input | Return |
|---|---|---|---|
| get_save_dict | Called when your model needs to be saved (checkpoints, hpc save, etc.) | None | dict to be saved |
| load_model_specific | Called when loading a model | checkpoint: the dict you created in get_save_dict | None: apply the checkpoint however you want |

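Anything picklable can ride along in the checkpoint. A sketch that stores the hyperparameters next to the weights (storing `hparams` here is illustrative, not required):

```python
def get_save_dict(self):
    # add any extra state you want alongside the weights
    return {
        'state_dict': self.state_dict(),
        'hparams': vars(self.hparams),
    }

def load_model_specific(self, checkpoint):
    self.load_state_dict(checkpoint['state_dict'])
    # restore anything else you stored, e.g. checkpoint['hparams']
```
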
### Add new model

1. Create a new model under /models.
2. Add the model name to AVAILABLE_MODELS in trainer_main (a dispatch sketch follows the snippet below):

```python
AVAILABLE_MODELS = {
    'model_1': ExampleModel1
}
```

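A hypothetical sketch of how trainer_main can resolve the model class from a command-line flag and let it register its own args through `add_model_specific_args` (the `--model_name` flag and the import are illustrative):

```python
from test_tube import HyperOptArgumentParser

from models.example_model import ExampleModel1  # hypothetical import from /models

AVAILABLE_MODELS = {
    'model_1': ExampleModel1,
}

# base args live in main; the chosen model appends its own
parent_parser = HyperOptArgumentParser(strategy='grid_search')
parent_parser.add_argument('--model_name', default='model_1')

temp_args, _ = parent_parser.parse_known_args()
model_class = AVAILABLE_MODELS[temp_args.model_name]

parser = model_class.add_model_specific_args(parent_parser)
hparams = parser.parse_args()

model = model_class(hparams)
```
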
### Model methods that can be implemented

| Method | Purpose | Input | Output | Required |
|---|---|---|---|---|
| forward() | Forward pass | model_in tuple with your data | model_out tuple to be passed to loss | Y |
| loss() | Calculate model loss | model_out tuple from forward() | A scalar | Y |
| check_performance() | Run a full loop through val data to check for metrics | dataloader, nb_tests | metrics tuple to be tracked | Y |
| tng_dataloader | Computed property, used to feed tng data | - | Pytorch DataLoader subclass | Y |
| val_dataloader | Computed property, used to feed val data | - | Pytorch DataLoader subclass | Y |
| test_dataloader | Computed property, used to feed test data | - | Pytorch DataLoader subclass | Y |

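Per the table above, `forward()` receives a `model_in` tuple and its return tuple is handed straight to `loss()`. A minimal sketch, with cross-entropy standing in for whatever loss you actually need:

```python
import torch.nn.functional as F

def forward(self, model_in):
    # unpack the tuple produced by your dataloader
    x, y = model_in
    y_hat = self.l1(x)

    return y_hat, y

def loss(self, model_out):
    y_hat, y = model_out
    return F.cross_entropy(y_hat, y)
```
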
### Model lifecycle hooks

Use these hooks to customize functionality.

| Method | Purpose | Input | Output | Required |
|---|---|---|---|---|
| on_batch_start() | called right before the batch starts | - | - | N |
| on_batch_end() | called right after the batch ends | - | - | N |
| on_epoch_start() | called right before the epoch starts | - | - | N |
| on_epoch_end() | called right after the epoch ends | - | - | N |
| on_pre_performance_check() | called right before the performance check starts | - | - | N |
| on_post_performance_check() | called right after the performance check ends | - | - | N |
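
None of the hooks need to return anything. A sketch of two of them used for simple epoch timing:

```python
import time

def on_epoch_start(self):
    # runs right before each epoch
    self._epoch_start = time.time()

def on_epoch_end(self):
    # runs right after each epoch
    print(f'epoch took {time.time() - self._epoch_start:.1f}s')
```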