From 63d84283a4cd91578c413b2ad6aaaec80e3fa3bc Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Fri, 28 Jun 2019 17:14:18 -0400
Subject: [PATCH] removed checkpoint save_function option

---
 docs/Trainer/Checkpointing.md               | 17 +++++++++++++++++
 docs/Trainer/index.md                       |  6 ++----
 pytorch_lightning/callbacks/pt_callbacks.py |  3 +--
 3 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/docs/Trainer/Checkpointing.md b/docs/Trainer/Checkpointing.md
index e69de29bb2..99bf6f7ee0 100644
--- a/docs/Trainer/Checkpointing.md
+++ b/docs/Trainer/Checkpointing.md
@@ -0,0 +1,17 @@
+Lightning can automate saving and loading checkpoints.
+
+---
+### Model saving
+To enable checkpointing, define the checkpoint callback:
+
+``` {.python}
+from pytorch_lightning.callbacks.pt_callbacks import ModelCheckpoint
+
+checkpoint = ModelCheckpoint(
+    filepath='/path/to/store/weights.ckpt',
+    save_best_only=not hparams.keep_all_checkpoints,
+    verbose=True,
+    monitor=hparams.model_save_monitor_value,
+    mode=hparams.model_save_monitor_mode
+)
+```
\ No newline at end of file
diff --git a/docs/Trainer/index.md b/docs/Trainer/index.md
index 197d8cfc99..1b30da1966 100644
--- a/docs/Trainer/index.md
+++ b/docs/Trainer/index.md
@@ -24,10 +24,8 @@ But of course the fun is in all the advanced things it can do:
 
 **Computing cluster (SLURM)**
 
-- Automatic checkpointing
-- Automatic saving, loading
-- Running grid search on a cluster
-- Walltime auto-resubmit
+- [Running grid search on a cluster](SLURM%20Managed%20Cluster/#running-grid-search-on-a-cluster)
+- [Walltime auto-resubmit](SLURM%20Managed%20Cluster/#walltime-auto-resubmit)
 
 **Debugging**
 
diff --git a/pytorch_lightning/callbacks/pt_callbacks.py b/pytorch_lightning/callbacks/pt_callbacks.py
index 46e6195dfd..e560b306de 100644
--- a/pytorch_lightning/callbacks/pt_callbacks.py
+++ b/pytorch_lightning/callbacks/pt_callbacks.py
@@ -170,12 +170,11 @@ class ModelCheckpoint(Callback):
         period: Interval (number of epochs) between checkpoints.
     """
 
-    def __init__(self, filepath, save_function, monitor='val_loss', verbose=0,
+    def __init__(self, filepath, monitor='val_loss', verbose=0,
                  save_best_only=False, save_weights_only=False,
                  mode='auto', period=1, prefix=''):
        super(ModelCheckpoint, self).__init__()
         self.monitor = monitor
-        self.save_function = save_function
         self.verbose = verbose
         self.filepath = filepath
         self.save_best_only = save_best_only
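
For context, here is a minimal standalone sketch of constructing the callback after this change, with literal values in place of the `hparams` object used in the docs example above. The import path follows the `pytorch_lightning/callbacks/pt_callbacks.py` module touched by this patch; how the trainer supplies the actual save logic is not shown in this diff and is only an assumption in the final comment.

``` {.python}
from pytorch_lightning.callbacks.pt_callbacks import ModelCheckpoint

# After this change, __init__ no longer accepts save_function; passing it as a
# keyword argument would raise a TypeError.
checkpoint = ModelCheckpoint(
    filepath='/path/to/store/weights.ckpt',
    monitor='val_loss',    # metric used to decide which checkpoint is "best"
    mode='min',            # a lower monitored value counts as an improvement
    save_best_only=True,   # keep only the best weights rather than one file per epoch
    verbose=True,
)

# The callback is then handed to the trainer, which presumably wires in the save
# logic itself; that side of the change is outside this diff.
```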