diff --git a/README.md b/README.md index 03cee59505..c0fe260869 100644 --- a/README.md +++ b/README.md @@ -34,38 +34,27 @@ import torch.nn as nn class ExampleModel(RootModule): def __init__(self): - self.l1 = nn.Linear(100, 20) + # define model # --------------- # TRAINING def training_step(self, data_batch): - # your dataloader decides what each batch looks like x, y = data_batch y_hat = self.l1(x) loss = some_loss(y_hat) - tqdm_dic = {'train_loss': loss} - - # must return scalar, dict for logging - return loss_val, tqdm_dic + return loss_val, {'train_loss': loss} def validation_step(self, data_batch): - # same as training... x, y = data_batch y_hat = self.l1(x) loss = some_loss(y_hat) - # val specific - acc = calculate_acc(y_hat, y) - - tqdm_dic = {'train_loss': loss, 'val_acc': acc, 'whatever_you_want': 'a'} - return loss_val, tqdm_dic + return loss_val, {'val_loss': loss} def validation_end(self, outputs): total_accs = [] - # given to you by the framework with all validation outputs. - # chance to collate for output in outputs: total_accs.append(output['val_acc'].item()) @@ -83,17 +72,14 @@ class ExampleModel(RootModule): def load_model_specific(self, checkpoint): # lightning loads for you. Here's your chance to say what you want to load self.load_state_dict(checkpoint['state_dict']) - pass - # --------------- # TRAINING CONFIG def configure_optimizers(self): # give lightning the list of optimizers you want to use. # lightning will call automatically - optimizer = self.choose_optimizer(self.hparams.optimizer_name, self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer') - self.optimizers = [optimizer] - return self.optimizers + optimizer = self.choose_optimizer('adam', self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer') + return [optimizer] # LIGHTING WILL USE THE LOADERS YOU DEFINE HERE @property