From 9f7caa213178a0830cf5aae914293481fb6fad6d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Sun, 31 Mar 2019 16:30:55 -0400 Subject: [PATCH] added example and verified --- README.md | 109 +++++++++++++++++++++++++++++++++++------------------- 1 file changed, 70 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 79e8fdc656..245004c279 100644 --- a/README.md +++ b/README.md @@ -53,51 +53,82 @@ To use lightning do 2 things: ```python # trainer.py -from pytorch_lightning.models.trainer import Trainer +import os +import sys + +from test_tube import HyperOptArgumentParser, Experiment +from pytorch_lightning.models.trainer import Trainer +from pytorch_lightning.utils.arg_parse import add_default_args from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint -from my_project import My_Model -from test_tube import HyperOptArgumentParser, Experiment, SlurmCluster +from demo.example_model import ExampleModel -# -------------- -# TEST TUBE INIT -exp = Experiment( - name='my_exp', - debug=True, - save_dir='/some/path', - autosave=False, - description='my desc' -) -# -------------------- -# CALLBACKS -early_stop = EarlyStopping( - monitor='val_loss', - patience=3, - verbose=True, - mode='min' -) +def main(hparams): + """ + Main training routine specific for this project + :param hparams: + :return: + """ + # init experiment + exp = Experiment( + name=hparams.tt_name, + debug=hparams.debug, + save_dir=hparams.tt_save_path, + version=hparams.hpc_exp_number, + autosave=False, + description=hparams.tt_description + ) -model_save_path = 'PATH/TO/SAVE' -checkpoint = ModelCheckpoint( - filepath=model_save_path, - save_function=None, - save_best_only=True, - verbose=True, - monitor='val_acc', - mode='min' -) + exp.argparse(hparams) + exp.save() -# configure trainer -trainer = Trainer( - experiment=experiment, - cluster=cluster, - checkpoint_callback=checkpoint, - early_stop_callback=early_stop -) + # build model + print('loading model...') + model = ExampleModel(hparams) + print('model built') -# init model and train -model = My_Model() -trainer.fit(model) + # callbacks + early_stop = EarlyStopping( + monitor=hparams.early_stop_metric, + patience=hparams.early_stop_patience, + verbose=True, + mode=hparams.early_stop_mode + ) + + model_save_path = '{}/{}/{}'.format(hparams.model_save_path, exp.name, exp.version) + checkpoint = ModelCheckpoint( + filepath=model_save_path, + save_function=None, + save_best_only=True, + verbose=True, + monitor=hparams.model_save_monitor_value, + mode=hparams.model_save_monitor_mode + ) + + # configure trainer + trainer = Trainer( + experiment=exp, + checkpoint_callback=checkpoint, + early_stop_callback=early_stop, + ) + + # train model + trainer.fit(model) + + +if __name__ == '__main__': + + # use default args given by lightning + root_dir = os.path.split(os.path.dirname(sys.modules['__main__'].__file__))[0] + parent_parser = HyperOptArgumentParser(strategy='random_search', add_help=False) + add_default_args(parent_parser, root_dir) + + # allow model to overwrite or extend args + parser = ExampleModel.add_model_specific_args(parent_parser) + hyperparams = parser.parse_args() + + # train model + main(hyperparams) ``` #### Define the model