Update README.md

This commit is contained in:
William Falcon 2019-03-30 21:24:46 -04:00 committed by GitHub
parent 985af56892
commit 649f4d5f4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 19 deletions

View File

@ -34,38 +34,27 @@ import torch.nn as nn
class ExampleModel(RootModule): class ExampleModel(RootModule):
def __init__(self): def __init__(self):
self.l1 = nn.Linear(100, 20) # define model
# --------------- # ---------------
# TRAINING # TRAINING
def training_step(self, data_batch): def training_step(self, data_batch):
# your dataloader decides what each batch looks like
x, y = data_batch x, y = data_batch
y_hat = self.l1(x) y_hat = self.l1(x)
loss = some_loss(y_hat) loss = some_loss(y_hat)
tqdm_dic = {'train_loss': loss} return loss_val, {'train_loss': loss}
# must return scalar, dict for logging
return loss_val, tqdm_dic
def validation_step(self, data_batch): def validation_step(self, data_batch):
# same as training...
x, y = data_batch x, y = data_batch
y_hat = self.l1(x) y_hat = self.l1(x)
loss = some_loss(y_hat) loss = some_loss(y_hat)
# val specific return loss_val, {'val_loss': loss}
acc = calculate_acc(y_hat, y)
tqdm_dic = {'train_loss': loss, 'val_acc': acc, 'whatever_you_want': 'a'}
return loss_val, tqdm_dic
def validation_end(self, outputs): def validation_end(self, outputs):
total_accs = [] total_accs = []
# given to you by the framework with all validation outputs.
# chance to collate
for output in outputs: for output in outputs:
total_accs.append(output['val_acc'].item()) total_accs.append(output['val_acc'].item())
@ -83,17 +72,14 @@ class ExampleModel(RootModule):
def load_model_specific(self, checkpoint): def load_model_specific(self, checkpoint):
# lightning loads for you. Here's your chance to say what you want to load # lightning loads for you. Here's your chance to say what you want to load
self.load_state_dict(checkpoint['state_dict']) self.load_state_dict(checkpoint['state_dict'])
pass
# --------------- # ---------------
# TRAINING CONFIG # TRAINING CONFIG
def configure_optimizers(self): def configure_optimizers(self):
# give lightning the list of optimizers you want to use. # give lightning the list of optimizers you want to use.
# lightning will call automatically # lightning will call automatically
optimizer = self.choose_optimizer(self.hparams.optimizer_name, self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer') optimizer = self.choose_optimizer('adam', self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
self.optimizers = [optimizer] return [optimizer]
return self.optimizers
# LIGHTING WILL USE THE LOADERS YOU DEFINE HERE # LIGHTING WILL USE THE LOADERS YOU DEFINE HERE
@property @property