Update README.md
This commit is contained in:
parent
97d730216e
commit
985af56892
12
README.md
12
README.md
|
@ -10,8 +10,7 @@ To use lightning, define a model that implements these 10 functions:
|
||||||
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
|
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
|
||||||
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
|
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
|
||||||
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
|
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
|
||||||
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
|
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
|
||||||
| load_model_specific | | | |
|
|
||||||
|
|
||||||
#### Model training
|
#### Model training
|
||||||
| Name | Description | Input | Return |
|
| Name | Description | Input | Return |
|
||||||
|
@ -37,6 +36,7 @@ class ExampleModel(RootModule):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.l1 = nn.Linear(100, 20)
|
self.l1 = nn.Linear(100, 20)
|
||||||
|
|
||||||
|
# ---------------
|
||||||
# TRAINING
|
# TRAINING
|
||||||
def training_step(self, data_batch):
|
def training_step(self, data_batch):
|
||||||
# your dataloader decides what each batch looks like
|
# your dataloader decides what each batch looks like
|
||||||
|
@ -71,7 +71,8 @@ class ExampleModel(RootModule):
|
||||||
|
|
||||||
# return a dict
|
# return a dict
|
||||||
return {'total_acc': np.mean(total_accs)}
|
return {'total_acc': np.mean(total_accs)}
|
||||||
|
|
||||||
|
# ---------------
|
||||||
# SAVING
|
# SAVING
|
||||||
def get_save_dict(self):
|
def get_save_dict(self):
|
||||||
# lightning saves for you. Here's your chance to say what you want to save
|
# lightning saves for you. Here's your chance to say what you want to save
|
||||||
|
@ -84,6 +85,8 @@ class ExampleModel(RootModule):
|
||||||
self.load_state_dict(checkpoint['state_dict'])
|
self.load_state_dict(checkpoint['state_dict'])
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------
|
||||||
# TRAINING CONFIG
|
# TRAINING CONFIG
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
# give lightning the list of optimizers you want to use.
|
# give lightning the list of optimizers you want to use.
|
||||||
|
@ -104,7 +107,8 @@ class ExampleModel(RootModule):
|
||||||
@property
|
@property
|
||||||
def test_dataloader(self):
|
def test_dataloader(self):
|
||||||
return pytorch_dataloader('test')
|
return pytorch_dataloader('test')
|
||||||
|
|
||||||
|
# ---------------
|
||||||
# MODIFY YOUR COMMAND LINE ARGS
|
# MODIFY YOUR COMMAND LINE ARGS
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_model_specific_args(parent_parser):
|
def add_model_specific_args(parent_parser):
|
||||||
|
|
Loading…
Reference in New Issue