diff --git a/README.md b/README.md index c0fe260869..29d4d4fcb7 100644 --- a/README.md +++ b/README.md @@ -2,33 +2,8 @@ Seed for ML research ## Usage -To use lightning, define a model that implements these 10 functions: +To use lightning, define a model that implements these 10 functions: -#### Model definition -| Name | Description | Input | Return | -|---|---|---|---| -| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict | -| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict | -| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging | -| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved | - -#### Model training -| Name | Description | Input | Return | -|---|---|---|---| -| configure_optimizers | called during training setup | None | list: optimizers you want to use | -| tng_dataloader | called during training | None | pytorch dataloader | -| val_dataloader | called during validation | None | pytorch dataloader | -| test_dataloader | called during testing | None | pytorch dataloader | -| add_model_specific_args | called with args you defined in your main. This lets you tailor args for each model and keep main the same | argparse | argparse | - -#### Model Saving/Loading -| Name | Description | Input | Return | -|---|---|---|---| -| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved | -| load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want | - - -## Example ```python import torch.nn as nn @@ -102,7 +77,32 @@ class ExampleModel(RootModule): parser.add_argument('--out_features', default=20) return parser ``` +### Details +#### Model definition +| Name | Description | Input | Return | +|---|---|---|---| +| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict | +| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict | +| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging | +| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved | + +#### Model training +| Name | Description | Input | Return | +|---|---|---|---| +| configure_optimizers | called during training setup | None | list: optimizers you want to use | +| tng_dataloader | called during training | None | pytorch dataloader | +| val_dataloader | called during validation | None | pytorch dataloader | +| test_dataloader | called during testing | None | pytorch dataloader | +| add_model_specific_args | called with args you defined in your main. This lets you tailor args for each model and keep main the same | argparse | argparse | + +#### Model Saving/Loading +| Name | Description | Input | Return | +|---|---|---|---| +| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved | +| load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want | + + ### Add new model 1. Create a new model under /models. 2. Add model name to trainer_main