Update README.md
This commit is contained in:
parent
649f4d5f4d
commit
24255a9eab
50
README.md
50
README.md
|
@ -4,31 +4,6 @@ Seed for ML research
|
||||||
## Usage
|
## Usage
|
||||||
To use lightning, define a model that implements these 10 functions:
|
To use lightning, define a model that implements these 10 functions:
|
||||||
|
|
||||||
#### Model definition
|
|
||||||
| Name | Description | Input | Return |
|
|
||||||
|---|---|---|---|
|
|
||||||
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
|
|
||||||
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
|
|
||||||
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
|
|
||||||
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
|
|
||||||
|
|
||||||
#### Model training
|
|
||||||
| Name | Description | Input | Return |
|
|
||||||
|---|---|---|---|
|
|
||||||
| configure_optimizers | called during training setup | None | list: optimizers you want to use |
|
|
||||||
| tng_dataloader | called during training | None | pytorch dataloader |
|
|
||||||
| val_dataloader | called during validation | None | pytorch dataloader |
|
|
||||||
| test_dataloader | called during testing | None | pytorch dataloader |
|
|
||||||
| add_model_specific_args | called with args you defined in your main. This lets you tailor args for each model and keep main the same | argparse | argparse |
|
|
||||||
|
|
||||||
#### Model Saving/Loading
|
|
||||||
| Name | Description | Input | Return |
|
|
||||||
|---|---|---|---|
|
|
||||||
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
|
|
||||||
| load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |
|
|
||||||
|
|
||||||
|
|
||||||
## Example
|
|
||||||
```python
|
```python
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
|
@ -102,6 +77,31 @@ class ExampleModel(RootModule):
|
||||||
parser.add_argument('--out_features', default=20)
|
parser.add_argument('--out_features', default=20)
|
||||||
return parser
|
return parser
|
||||||
```
|
```
|
||||||
|
### Details
|
||||||
|
|
||||||
|
#### Model definition
|
||||||
|
| Name | Description | Input | Return |
|
||||||
|
|---|---|---|---|
|
||||||
|
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
|
||||||
|
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
|
||||||
|
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
|
||||||
|
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
|
||||||
|
|
||||||
|
#### Model training
|
||||||
|
| Name | Description | Input | Return |
|
||||||
|
|---|---|---|---|
|
||||||
|
| configure_optimizers | called during training setup | None | list: optimizers you want to use |
|
||||||
|
| tng_dataloader | called during training | None | pytorch dataloader |
|
||||||
|
| val_dataloader | called during validation | None | pytorch dataloader |
|
||||||
|
| test_dataloader | called during testing | None | pytorch dataloader |
|
||||||
|
| add_model_specific_args | called with args you defined in your main. This lets you tailor args for each model and keep main the same | argparse | argparse |
|
||||||
|
|
||||||
|
#### Model Saving/Loading
|
||||||
|
| Name | Description | Input | Return |
|
||||||
|
|---|---|---|---|
|
||||||
|
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
|
||||||
|
| load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |
|
||||||
|
|
||||||
|
|
||||||
### Add new model
|
### Add new model
|
||||||
1. Create a new model under /models.
|
1. Create a new model under /models.
|
||||||
|
|
Loading…
Reference in New Issue