Update README.md
This commit is contained in:
parent
985af56892
commit
649f4d5f4d
24
README.md
24
README.md
|
@ -34,38 +34,27 @@ import torch.nn as nn
|
||||||
|
|
||||||
class ExampleModel(RootModule):
|
class ExampleModel(RootModule):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.l1 = nn.Linear(100, 20)
|
# define model
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
# TRAINING
|
# TRAINING
|
||||||
def training_step(self, data_batch):
|
def training_step(self, data_batch):
|
||||||
# your dataloader decides what each batch looks like
|
|
||||||
x, y = data_batch
|
x, y = data_batch
|
||||||
y_hat = self.l1(x)
|
y_hat = self.l1(x)
|
||||||
loss = some_loss(y_hat)
|
loss = some_loss(y_hat)
|
||||||
|
|
||||||
tqdm_dic = {'train_loss': loss}
|
return loss_val, {'train_loss': loss}
|
||||||
|
|
||||||
# must return scalar, dict for logging
|
|
||||||
return loss_val, tqdm_dic
|
|
||||||
|
|
||||||
def validation_step(self, data_batch):
|
def validation_step(self, data_batch):
|
||||||
# same as training...
|
|
||||||
x, y = data_batch
|
x, y = data_batch
|
||||||
y_hat = self.l1(x)
|
y_hat = self.l1(x)
|
||||||
loss = some_loss(y_hat)
|
loss = some_loss(y_hat)
|
||||||
|
|
||||||
# val specific
|
return loss_val, {'val_loss': loss}
|
||||||
acc = calculate_acc(y_hat, y)
|
|
||||||
|
|
||||||
tqdm_dic = {'train_loss': loss, 'val_acc': acc, 'whatever_you_want': 'a'}
|
|
||||||
return loss_val, tqdm_dic
|
|
||||||
|
|
||||||
def validation_end(self, outputs):
|
def validation_end(self, outputs):
|
||||||
total_accs = []
|
total_accs = []
|
||||||
|
|
||||||
# given to you by the framework with all validation outputs.
|
|
||||||
# chance to collate
|
|
||||||
for output in outputs:
|
for output in outputs:
|
||||||
total_accs.append(output['val_acc'].item())
|
total_accs.append(output['val_acc'].item())
|
||||||
|
|
||||||
|
@ -83,17 +72,14 @@ class ExampleModel(RootModule):
|
||||||
def load_model_specific(self, checkpoint):
|
def load_model_specific(self, checkpoint):
|
||||||
# lightning loads for you. Here's your chance to say what you want to load
|
# lightning loads for you. Here's your chance to say what you want to load
|
||||||
self.load_state_dict(checkpoint['state_dict'])
|
self.load_state_dict(checkpoint['state_dict'])
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
# TRAINING CONFIG
|
# TRAINING CONFIG
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
# give lightning the list of optimizers you want to use.
|
# give lightning the list of optimizers you want to use.
|
||||||
# lightning will call automatically
|
# lightning will call automatically
|
||||||
optimizer = self.choose_optimizer(self.hparams.optimizer_name, self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
|
optimizer = self.choose_optimizer('adam', self.parameters(), {'lr': self.hparams.learning_rate}, 'optimizer')
|
||||||
self.optimizers = [optimizer]
|
return [optimizer]
|
||||||
return self.optimizers
|
|
||||||
|
|
||||||
# LIGHTING WILL USE THE LOADERS YOU DEFINE HERE
|
# LIGHTING WILL USE THE LOADERS YOU DEFINE HERE
|
||||||
@property
|
@property
|
||||||
|
|
Loading…
Reference in New Issue