diff --git a/README.md b/README.md index 6a55575e92..03cee59505 100644 --- a/README.md +++ b/README.md @@ -10,8 +10,7 @@ To use lightning, define a model that implements these 10 functions: | training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict | | validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict | | validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging | -| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved | -| load_model_specific | | | | +| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved | #### Model training | Name | Description | Input | Return | @@ -37,6 +36,7 @@ class ExampleModel(RootModule): def __init__(self): self.l1 = nn.Linear(100, 20) + # --------------- # TRAINING def training_step(self, data_batch): # your dataloader decides what each batch looks like @@ -71,7 +71,8 @@ class ExampleModel(RootModule): # return a dict return {'total_acc': np.mean(total_accs)} - + + # --------------- # SAVING def get_save_dict(self): # lightning saves for you. Here's your chance to say what you want to save @@ -84,6 +85,8 @@ class ExampleModel(RootModule): self.load_state_dict(checkpoint['state_dict']) pass + + # --------------- # TRAINING CONFIG def configure_optimizers(self): # give lightning the list of optimizers you want to use. @@ -104,7 +107,8 @@ class ExampleModel(RootModule): @property def test_dataloader(self): return pytorch_dataloader('test') - + + # --------------- # MODIFY YOUR COMMAND LINE ARGS @staticmethod def add_model_specific_args(parent_parser):