updated docs

This commit is contained in:
William Falcon 2019-07-25 12:01:52 -04:00
parent 4562580461
commit b0d38d532d
2 changed files with 64 additions and 16 deletions

View File

@ -8,9 +8,8 @@ from pytorch_lightning.root_module.decorators import data_loader
class LightningModule(GradInformation, ModelIO, ModelHooks): class LightningModule(GradInformation, ModelIO, ModelHooks):
def __init__(self, hparams): def __init__(self):
super(LightningModule, self).__init__() super(LightningModule, self).__init__()
self.hparams = hparams
self.dtype = torch.FloatTensor self.dtype = torch.FloatTensor
self.exp_save_path = None self.exp_save_path = None
@ -64,18 +63,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
""" """
raise NotImplementedError raise NotImplementedError
def summarize(self):
model_summary = ModelSummary(self)
print(model_summary)
def freeze(self):
for param in self.parameters():
param.requires_grad = False
def unfreeze(self):
for param in self.parameters():
param.requires_grad = True
@data_loader @data_loader
def tng_dataloader(self): def tng_dataloader(self):
""" """
@ -128,5 +115,17 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
model.load_state_dict(checkpoint['state_dict'], strict=False) model.load_state_dict(checkpoint['state_dict'], strict=False)
return model return model
def summarize(self):
model_summary = ModelSummary(self)
print(model_summary)
def freeze(self):
for param in self.parameters():
param.requires_grad = False
def unfreeze(self):
for param in self.parameters():
param.requires_grad = True

View File

@ -11,6 +11,55 @@ import os
import shutil import shutil
import pdb import pdb
import pytorch_lightning as ptl
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
class CoolModel(ptl.LightningModule):
def __init(self):
super(CoolModel, self).__init__()
# not the best model...
self.l1 = torch.nn.Linear(28 * 28, 10)
def forward(self, x):
return torch.relu(self.l1(x))
def my_loss(self, y_hat, y):
return F.cross_entropy(y_hat, y)
def training_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'tng_loss': self.my_loss(y_hat, y)}
def validation_step(self, batch, batch_nb):
x, y = batch
y_hat = self.forward(x)
return {'val_loss': self.my_loss(y_hat, y)}
def validation_end(self, outputs):
avg_loss = torch.stack([x for x in outputs['val_loss']]).mean()
return avg_loss
def configure_optimizers(self):
return [torch.optim.Adam(self.parameters(), lr=0.02)]
@ptl.data_loader
def tng_dataloader(self):
return DataLoader(MNIST('path/to/save', train=True), batch_size=32)
@ptl.data_loader
def val_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
@ptl.data_loader
def test_dataloader(self):
return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
def get_model(): def get_model():
# set up model with these hyperparams # set up model with these hyperparams
@ -94,11 +143,9 @@ def run_prediction(dataloader, trained_model):
def main(): def main():
save_dir = init_save_dir() save_dir = init_save_dir()
model, hparams = get_model()
# exp file to get meta # exp file to get meta
exp = get_exp(False) exp = get_exp(False)
exp.argparse(hparams)
exp.save() exp.save()
# exp file to get weights # exp file to get weights
@ -113,6 +160,8 @@ def main():
distributed_backend='dp', distributed_backend='dp',
) )
model = CoolModel()
result = trainer.fit(model) result = trainer.fit(model)
# correct result and ok accuracy # correct result and ok accuracy