Update README.md
This commit is contained in:
parent
985af56892
commit
649f4d5f4d
24
README.md
24
README.md
|
@ -34,38 +34,27 @@ import torch.nn as nn
|
|||
|
||||
class ExampleModel(RootModule):
|
||||
def __init__(self):
|
||||
self.l1 = nn.Linear(100, 20)
|
||||
# define model
|
||||
|
||||
# ---------------
|
||||
# TRAINING
|
||||
def training_step(self, data_batch):
|
||||
# your dataloader decides what each batch looks like
|
||||
x, y = data_batch
|
||||
y_hat = self.l1(x)
|
||||
loss = some_loss(y_hat)
|
||||
|
||||
tqdm_dic = {'train_loss': loss}
|
||||
|
||||
# must return scalar, dict for logging
|
||||
return loss_val, tqdm_dic
|
||||
return loss_val, {'train_loss': loss}
|
||||
|
||||
def validation_step(self, data_batch):
|
||||
# same as training...
|
||||
x, y = data_batch
|
||||
y_hat = self.l1(x)
|
||||
loss = some_loss(y_hat)
|
||||
|
||||
# val specific
|
||||
acc = calculate_acc(y_hat, y)
|
||||
|
||||
tqdm_dic = {'train_loss': loss, 'val_acc': acc, 'whatever_you_want': 'a'}
|
||||
return loss_val, tqdm_dic
|
||||
return loss_val, {'val_loss': loss}
|
||||
|
||||
def validation_end(self, outputs):
|
||||
total_accs = []
|
||||
|
||||
# given to you by the framework with all validation outputs.
|
||||
# chance to collate
|
||||
for output in outputs:
|
||||
total_accs.append(output['val_acc'].item())
|
||||
|
||||
|
@ -83,17 +72,14 @@ class ExampleModel(RootModule):
|
|||
def load_model_specific(self, checkpoint):
|
||||
# lightning loads for you. Here's your chance to say what you want to load
|
||||
self.load_state_dict(checkpoint['state_dict'])
|
||||
pass
|
||||
|
||||
|
||||
# ---------------
|
||||
# TRAINING CONFIG
|
||||
def configure_optimizers(self):
|
||||
# give lightning the list of optimizers you want to use.
|
||||
# lightning will call automatically
|
||||
optimizer = self.choose_optimizer(self.hparams.optimizer_name, self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
|
||||
self.optimizers = [optimizer]
|
||||
return self.optimizers
|
||||
optimizer = self.choose_optimizer('adam', self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
|
||||
return [optimizer]
|
||||
|
||||
# LIGHTING WILL USE THE LOADERS YOU DEFINE HERE
|
||||
@property
|
||||
|
|
Loading…
Reference in New Issue