Update README.md

This commit is contained in:
William Falcon 2019-03-30 21:22:38 -04:00 committed by GitHub
parent 97d730216e
commit 985af56892
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 8 additions and 4 deletions

View File

@ -11,7 +11,6 @@ To use lightning, define a model that implements these 10 functions:
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict | | validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging | | validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved | | get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
| load_model_specific | | | |
#### Model training #### Model training
| Name | Description | Input | Return | | Name | Description | Input | Return |
@ -37,6 +36,7 @@ class ExampleModel(RootModule):
def __init__(self): def __init__(self):
self.l1 = nn.Linear(100, 20) self.l1 = nn.Linear(100, 20)
# ---------------
# TRAINING # TRAINING
def training_step(self, data_batch): def training_step(self, data_batch):
# your dataloader decides what each batch looks like # your dataloader decides what each batch looks like
@ -72,6 +72,7 @@ class ExampleModel(RootModule):
# return a dict # return a dict
return {'total_acc': np.mean(total_accs)} return {'total_acc': np.mean(total_accs)}
# ---------------
# SAVING # SAVING
def get_save_dict(self): def get_save_dict(self):
# lightning saves for you. Here's your chance to say what you want to save # lightning saves for you. Here's your chance to say what you want to save
@ -84,6 +85,8 @@ class ExampleModel(RootModule):
self.load_state_dict(checkpoint['state_dict']) self.load_state_dict(checkpoint['state_dict'])
pass pass
# ---------------
# TRAINING CONFIG # TRAINING CONFIG
def configure_optimizers(self): def configure_optimizers(self):
# give lightning the list of optimizers you want to use. # give lightning the list of optimizers you want to use.
@ -105,6 +108,7 @@ class ExampleModel(RootModule):
def test_dataloader(self): def test_dataloader(self):
return pytorch_dataloader('test') return pytorch_dataloader('test')
# ---------------
# MODIFY YOUR COMMAND LINE ARGS # MODIFY YOUR COMMAND LINE ARGS
@staticmethod @staticmethod
def add_model_specific_args(parent_parser): def add_model_specific_args(parent_parser):