diff --git a/README.md b/README.md index 259f63cae6..9d6e01eb97 100644 --- a/README.md +++ b/README.md @@ -152,26 +152,9 @@ class ExampleModel(RootModule): | get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc...) | None | dict to be saved | | load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want | - -### Add new model -1. Create a new model under /models. -2. Add model name to trainer_main -```python -AVAILABLE_MODELS = { - 'model_1': ExampleModel1 -} -``` +## Optional model hooks. +Add these to the model whenever you want to configure training behavior. -### Model methods that can be implemented - -| Method | Purpose | Input | Output | Required | -|---|---|---|---|---| -| forward() | Forward pass | model_in tuple with your data | model_out tuple to be passed to loss | Y | -| loss() | calculate model loss | model_out tuple from forward() | A scalar | Y | -| check_performance() | run a full loop through val data to check for metrics | dataloader, nb_tests | metrics tuple to be tracked | Y | -| tng_dataloader | Computed option, used to feed tng data | - | Pytorch DataLoader subclass | Y | -| val_dataloader | Computed option, used to feed tng data | - | Pytorch DataLoader subclass | Y | -| test_dataloader | Computed option, used to feed tng data | - | Pytorch DataLoader subclass | Y | ### Model lifecycle hooks Use these hooks to customize functionality