Update README.md

This commit is contained in:
William Falcon 2019-03-30 21:42:33 -04:00 committed by GitHub
parent a7406bb752
commit 49de749945
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 19 deletions

View File

@ -152,26 +152,9 @@ class ExampleModel(RootModule):
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved | | get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
| load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want | | load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |
## Optional model hooks.
Add these to the model whenever you want to configure training behavior.
### Add new model
1. Create a new model under /models.
2. Add model name to trainer_main
```python
AVAILABLE_MODELS = {
'model_1': ExampleModel1
}
```
### Model methods that can be implemented
| Method | Purpose | Input | Output | Required |
|---|---|---|---|---|
| forward() | Forward pass | model_in tuple with your data | model_out tuple to be passed to loss | Y |
| loss() | calculate model loss | model_out tuple from forward() | A scalar | Y |
| check_performance() | run a full loop through val data to check for metrics | dataloader, nb_tests | metrics tuple to be tracked | Y |
| tng_dataloader | Computed option, used to feed tng data | - | Pytorch DataLoader subclass | Y |
| val_dataloader | Computed option, used to feed tng data | - | Pytorch DataLoader subclass | Y |
| test_dataloader | Computed option, used to feed tng data | - | Pytorch DataLoader subclass | Y |
### Model lifecycle hooks ### Model lifecycle hooks
Use these hooks to customize functionality Use these hooks to customize functionality