# Pytorch-lightning

The Keras for ML research in PyTorch. Simple to use, but not as heavily abstracted.

## Usage

To use lightning, first define a trainer function:

```python
# trainer.py

from pytorch_lightning.models.trainer import Trainer
from pytorch_lightning.utils.pt_callbacks import EarlyStopping, ModelCheckpoint
from my_project import My_Model

# --------------------
# CALLBACKS
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=True,
    mode='min'
)

model_save_path = 'PATH/TO/SAVE'
checkpoint = ModelCheckpoint(
    filepath=model_save_path,
    save_function=None,
    save_best_only=True,
    verbose=True,
    monitor='val_acc',
    mode='max'  # maximize validation accuracy
)

# configure trainer
trainer = Trainer(
    on_gpu=False,
    enable_tqdm=True,
    overfit_pct=None,
    track_grad_norm=-1,
    fast_dev_run=False,
    check_val_every_n_epoch=1,
    accumulate_grad_batches=2,
    process_position=0,
    current_gpu_name=0,
    checkpoint_callback=checkpoint,
    early_stop_callback=early_stop,
    enable_early_stop=True,
    max_nb_epochs=12,
    min_nb_epochs=2,
    train_percent_check=1.0,
    val_percent_check=0.5,
    test_percent_check=0.5,
    val_check_interval=0.95,
    log_save_interval=0.95,
    add_log_row_interval=20,
    lr_scheduler_milestones=None
)

# init model and train it
model = My_Model()
trainer.fit(model)
```

Next, define a model that implements these 10 functions:

```python
import numpy as np
import torch.nn as nn

from test_tube import HyperOptArgumentParser


# RootModule is lightning's base model class
class ExampleModel(RootModule):

    def __init__(self):
        super().__init__()
        # define your model layers here, e.g. self.l1 = nn.Linear(...)

    # ---------------
    # TRAINING
    def training_step(self, data_batch):
        x, y = data_batch
        y_hat = self.l1(x)
        loss = some_loss(y_hat, y)

        return loss, {'train_loss': loss}

    def validation_step(self, data_batch):
        x, y = data_batch
        y_hat = self.l1(x)
        loss = some_loss(y_hat, y)

        return loss, {'val_loss': loss}

    def validation_end(self, outputs):
        total_losses = []
        for output in outputs:
            total_losses.append(output['val_loss'].item())

        # return a dict of collated metrics
        return {'total_loss': np.mean(total_losses)}

    # ---------------
    # SAVING
    def get_save_dict(self):
        # lightning saves for you. Here's your chance to say what you want to save
        checkpoint = {'state_dict': self.state_dict()}
        return checkpoint

    def load_model_specific(self, checkpoint):
        # lightning loads for you. Here's your chance to say what you want to load
        self.load_state_dict(checkpoint['state_dict'])

    # ---------------
    # TRAINING CONFIG
    def configure_optimizers(self):
        # give lightning the list of optimizers you want to use.
        # lightning will call them automatically
        optimizer = self.choose_optimizer('adam', self.parameters(),
                                          {'lr': self.hparams.learning_rate},
                                          'optimizer')
        return [optimizer]

    @property
    def tng_dataloader(self):
        # return your PyTorch DataLoader for the train split
        return pytorch_dataloader('train')

    @property
    def val_dataloader(self):
        return pytorch_dataloader('val')

    @property
    def test_dataloader(self):
        return pytorch_dataloader('test')

    # ---------------
    # MODIFY YOUR COMMAND LINE ARGS
    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
        parser.add_argument('--out_features', default=20)
        return parser
```

### Details

#### Model definition

| Name | Description | Input | Return |
|---|---|---|---|
| training_step | Called with a batch of data during training | data from your dataloaders | tuple: scalar, dict |
| validation_step | Called with a batch of data during validation | data from your dataloaders | tuple: scalar, dict |
| validation_end | Collate metrics from all validation steps | outputs: array where each item is the output of a validation step | dict: for logging |
| get_save_dict | Called when your model needs to be saved (checkpoints, hpc save, etc.) | None | dict to be saved |
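
The following is a minimal sketch, not library code, of the validation_step / validation_end contract described in the table above. It also tracks a `val_acc` metric so that the `ModelCheckpoint` configured earlier (`monitor='val_acc'`) has something to monitor; `some_loss` and `self.l1` are the same placeholder names used in the example model, and the accuracy formula assumes a classification output. The two methods would be dropped into the model class above.

```python
import numpy as np


def validation_step(self, data_batch):
    x, y = data_batch
    y_hat = self.l1(x)
    loss = some_loss(y_hat, y)

    # assumed classification setup: fraction of correct argmax predictions
    val_acc = (y_hat.argmax(dim=1) == y).float().mean()

    # scalar first, then a dict of metrics for validation_end to collate
    return loss, {'val_loss': loss.item(), 'val_acc': val_acc.item()}


def validation_end(self, outputs):
    # outputs: one dict per validation_step call
    avg_loss = np.mean([out['val_loss'] for out in outputs])
    avg_acc = np.mean([out['val_acc'] for out in outputs])

    # the returned dict is what gets logged
    return {'val_loss': avg_loss, 'val_acc': avg_acc}
```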
#### Model training

| Name | Description | Input | Return |
|---|---|---|---|
| configure_optimizers | called during training setup | None | list: optimizers you want to use |
| tng_dataloader | called during training | None | pytorch dataloader |
| val_dataloader | called during validation | None | pytorch dataloader |
| test_dataloader | called during testing | None | pytorch dataloader |
| add_model_specific_args | called with the args you defined in your main. This lets you tailor args for each model and keep main the same | argparse | argparse |

#### Model Saving/Loading

| Name | Description | Input | Return |
|---|---|---|---|
| get_save_dict | called when your model needs to be saved (checkpoints, hpc save, etc.) | None | dict to be saved |
| load_model_specific | called when loading a model | checkpoint: dict you created in get_save_dict | dict: modified in whatever way you want |

### Add a new model

1. Create a new model under /models.
2. Add the model name to trainer_main:

```python
AVAILABLE_MODELS = {
    'model_1': ExampleModel1
}
```

### Model methods that can be implemented

| Method | Purpose | Input | Output | Required |
|---|---|---|---|---|
| forward() | Forward pass | model_in tuple with your data | model_out tuple to be passed to loss | Y |
| loss() | calculate model loss | model_out tuple from forward() | A scalar | Y |
| check_performance() | run a full loop through val data to check for metrics | dataloader, nb_tests | metrics tuple to be tracked | Y |
| tng_dataloader | Computed option, used to feed tng data | - | Pytorch DataLoader subclass | Y |
| val_dataloader | Computed option, used to feed val data | - | Pytorch DataLoader subclass | Y |
| test_dataloader | Computed option, used to feed test data | - | Pytorch DataLoader subclass | Y |

### Model lifecycle hooks

Use these hooks to customize functionality

| Method | Purpose | Input | Output | Required |
|---|---|---|---|---|
| on_batch_start() | called right before the batch starts | - | - | N |
| on_batch_end() | called right after the batch ends | - | - | N |
| on_epoch_start() | called right before the epoch starts | - | - | N |
| on_epoch_end() | called right after the epoch ends | - | - | N |
| on_pre_performance_check() | called right before the performance check starts | - | - | N |
| on_post_performance_check() | called right after the performance check ends | - | - | N |
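
To make the hook table concrete, here is a minimal sketch (an illustration, not library code) of a model overriding a few of these optional hooks. Lightning calls them at the points described above; the bodies here (timing and prints) are arbitrary examples, and `TimedModel` simply extends the `ExampleModel` defined earlier.

```python
import time


class TimedModel(ExampleModel):
    # reuses ExampleModel from above; only the optional hooks are added

    def on_epoch_start(self):
        # record when the epoch began
        self._epoch_start = time.time()

    def on_epoch_end(self):
        # report how long the epoch took
        print(f'epoch finished in {time.time() - self._epoch_start:.1f}s')

    def on_pre_performance_check(self):
        # about to run the validation loop
        print('running performance check...')

    def on_post_performance_check(self):
        # validation loop finished
        print('performance check done')
```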