Seed for ML research

## Usage

To use lightning, define a model that implements these 10 functions:

#### Model definition

| Name | Description | Input | Return |
|---|---|---|---|
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
| get_save_dict | Called when your model needs to be saved (checkpoints, hpc save, etc.) | None | dict to be saved |
| load_model_specific | Called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |

#### Model training

| Name | Description | Input | Return |
|---|---|---|---|
| configure_optimizers | Called during training setup | None | list: optimizers you want to use |
| tng_dataloader | Called during training | None | pytorch dataloader |
| val_dataloader | Called during validation | None | pytorch dataloader |
| test_dataloader | Called during testing | None | pytorch dataloader |
| add_model_specific_args | Called with the args you defined in your main. This lets you tailor args for each model and keep main the same | argparse | argparse |
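
For `add_model_specific_args`, a main script typically builds a base parser and lets the model class extend it. The sketch below assumes the test-tube `HyperOptArgumentParser` (the same parser used in the example further down); the flag names and the `grid_search` strategy are illustrative, not part of lightning:

```python
# hypothetical main.py: delegate model-specific flags to the model class
from test_tube import HyperOptArgumentParser  # assumption: parser comes from test-tube


def get_args(model_class):
    # base args shared by every model; add_help=False avoids a duplicate -h
    # when the model re-parents this parser
    parent_parser = HyperOptArgumentParser(strategy='grid_search', add_help=False)
    parent_parser.add_argument('--learning_rate', default=0.001, type=float)

    # the model tailors the parser while main stays the same
    parser = model_class.add_model_specific_args(parent_parser)
    return parser.parse_args()
```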

#### Model Saving/Loading

| Name | Description | Input | Return |
|---|---|---|---|
| get_save_dict | Called when your model needs to be saved (checkpoints, hpc save, etc.) | None | dict to be saved |
| load_model_specific | Called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |
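
Conceptually, the dict returned by `get_save_dict` is what ends up on disk and what is handed back to `load_model_specific`. Lightning performs the actual file I/O for you; the direct `torch.save`/`torch.load` calls and the file name below only illustrate that round trip:

```python
import torch

model = ExampleModel()  # the class defined in the example below

# saving: the framework asks the model what to persist
checkpoint = model.get_save_dict()               # e.g. {'state_dict': ...}
torch.save(checkpoint, 'example_checkpoint.pt')  # illustrative path

# loading: the framework hands the dict back to the model
checkpoint = torch.load('example_checkpoint.pt')
model.load_model_specific(checkpoint)
```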

## Example

```python
import numpy as np
import torch.nn as nn

# RootModule is provided by lightning; HyperOptArgumentParser comes from the
# test-tube package. some_loss, calculate_acc and pytorch_dataloader are
# placeholders for your own code.


class ExampleModel(RootModule):

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(100, 20)

    # TRAINING
    def training_step(self, data_batch):
        # your dataloader decides what each batch looks like
        x, y = data_batch
        y_hat = self.l1(x)
        loss = some_loss(y_hat, y)

        tqdm_dic = {'train_loss': loss}

        # must return scalar, dict for logging
        return loss, tqdm_dic

    def validation_step(self, data_batch):
        # same as training...
        x, y = data_batch
        y_hat = self.l1(x)
        loss = some_loss(y_hat, y)

        # val specific
        acc = calculate_acc(y_hat, y)

        tqdm_dic = {'val_loss': loss, 'val_acc': acc, 'whatever_you_want': 'a'}
        return loss, tqdm_dic

    def validation_end(self, outputs):
        # outputs is given to you by the framework with the output of every
        # validation step. Here is your chance to collate them.
        total_accs = []
        for output in outputs:
            total_accs.append(output['val_acc'].item())

        # return a dict for logging
        return {'total_acc': np.mean(total_accs)}

    # SAVING
    def get_save_dict(self):
        # lightning saves for you. Here's your chance to say what you want to save
        checkpoint = {'state_dict': self.state_dict()}
        return checkpoint

    def load_model_specific(self, checkpoint):
        # lightning loads for you. Here's your chance to say what you want to load
        self.load_state_dict(checkpoint['state_dict'])

    # TRAINING CONFIG
    def configure_optimizers(self):
        # give lightning the list of optimizers you want to use.
        # lightning will call them automatically.
        # self.hparams is assumed to hold your hyperparameters.
        optimizer = self.choose_optimizer(self.hparams.optimizer_name, self.parameters(),
                                          {'lr': self.hparams.learning_rate}, 'optimizer')
        self.optimizers = [optimizer]
        return self.optimizers

    # LIGHTNING WILL USE THE LOADERS YOU DEFINE HERE
    @property
    def tng_dataloader(self):
        # any pytorch dataloader works (see the sketch after this example)
        return pytorch_dataloader('train')

    @property
    def val_dataloader(self):
        return pytorch_dataloader('val')

    @property
    def test_dataloader(self):
        return pytorch_dataloader('test')

    # MODIFY YOUR COMMAND LINE ARGS
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
        parser.add_argument('--out_features', default=20)
        return parser
```
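
`pytorch_dataloader`, `some_loss`, and `calculate_acc` above stand in for your own code. A minimal sketch of what the dataloader helper could look like, using MNIST purely as an illustrative dataset (the helper name, dataset, and transforms are assumptions, not part of lightning):

```python
import torch.utils.data as data
from torchvision import datasets, transforms


def pytorch_dataloader(split):
    # illustrative only: any torch.utils.data.DataLoader works
    dataset = datasets.MNIST('data/', train=(split == 'train'),
                             download=True,
                             transform=transforms.ToTensor())
    return data.DataLoader(dataset, batch_size=32, shuffle=(split == 'train'))
```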

### Add new model

1. Create a new model under /models (a minimal skeleton is sketched below).
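
As a starting point, a new file under /models could stub out the 10 functions from the tables above. The file name and class name here are hypothetical; fill in the bodies as in the example:

```python
# models/my_model.py (hypothetical path and class name)
# RootModule is the same lightning base class used in the example above

class MyModel(RootModule):

    def __init__(self):
        super().__init__()
        # define your layers here

    # model definition
    def training_step(self, data_batch): ...
    def validation_step(self, data_batch): ...
    def validation_end(self, outputs): ...

    # saving/loading
    def get_save_dict(self): ...
    def load_model_specific(self, checkpoint): ...

    # training config
    def configure_optimizers(self): ...

    @property
    def tng_dataloader(self): ...

    @property
    def val_dataloader(self): ...

    @property
    def test_dataloader(self): ...

    @staticmethod
    def add_model_specific_args(parent_parser): ...
```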