Update README.md
This commit is contained in:
parent
a7406bb752
commit
49de749945
21
README.md
21
README.md
|
@ -152,26 +152,9 @@ class ExampleModel(RootModule):
|
|||
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved |
|
||||
| load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |
|
||||
|
||||
|
||||
### Add new model
|
||||
1. Create a new model under /models.
|
||||
2. Add model name to trainer_main
|
||||
```python
|
||||
AVAILABLE_MODELS = {
|
||||
'model_1': ExampleModel1
|
||||
}
|
||||
```
|
||||
## Optional model hooks.
|
||||
Add these to the model whenever you want to configure training behavior.
|
||||
|
||||
### Model methods that can be implemented
|
||||
|
||||
| Method | Purpose | Input | Output | Required |
|
||||
|---|---|---|---|---|
|
||||
| forward() | Forward pass | model_in tuple with your data | model_out tuple to be passed to loss | Y |
|
||||
| loss() | calculate model loss | model_out tuple from forward() | A scalar | Y |
|
||||
| check_performance() | run a full loop through val data to check for metrics | dataloader, nb_tests | metrics tuple to be tracked | Y |
|
||||
| tng_dataloader | Computed option, used to feed tng data | - | Pytorch DataLoader subclass | Y |
|
||||
| val_dataloader | Computed option, used to feed tng data | - | Pytorch DataLoader subclass | Y |
|
||||
| test_dataloader | Computed option, used to feed tng data | - | Pytorch DataLoader subclass | Y |
|
||||
|
||||
### Model lifecycle hooks
|
||||
Use these hooks to customize functionality
|
||||
|
|
Loading…
Reference in New Issue