From b0d38d532d1db0ab45bc190f6b6a762531672f2c Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 25 Jul 2019 12:01:52 -0400 Subject: [PATCH] updated docs --- pytorch_lightning/root_module/root_module.py | 27 +++++----- tests/debug.py | 53 +++++++++++++++++++- 2 files changed, 64 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index 584c5aa4b4..4f55ed37c4 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -8,9 +8,8 @@ from pytorch_lightning.root_module.decorators import data_loader class LightningModule(GradInformation, ModelIO, ModelHooks): - def __init__(self, hparams): + def __init__(self): super(LightningModule, self).__init__() - self.hparams = hparams self.dtype = torch.FloatTensor self.exp_save_path = None @@ -64,18 +63,6 @@ class LightningModule(GradInformation, ModelIO, ModelHooks): """ raise NotImplementedError - def summarize(self): - model_summary = ModelSummary(self) - print(model_summary) - - def freeze(self): - for param in self.parameters(): - param.requires_grad = False - - def unfreeze(self): - for param in self.parameters(): - param.requires_grad = True - @data_loader def tng_dataloader(self): """ @@ -128,5 +115,17 @@ class LightningModule(GradInformation, ModelIO, ModelHooks): model.load_state_dict(checkpoint['state_dict'], strict=False) return model + def summarize(self): + model_summary = ModelSummary(self) + print(model_summary) + + def freeze(self): + for param in self.parameters(): + param.requires_grad = False + + def unfreeze(self): + for param in self.parameters(): + param.requires_grad = True + diff --git a/tests/debug.py b/tests/debug.py index 2833b651a5..c4c3ffd19f 100644 --- a/tests/debug.py +++ b/tests/debug.py @@ -11,6 +11,55 @@ import os import shutil import pdb +import pytorch_lightning as ptl +import torch +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torchvision.datasets import MNIST + + +class CoolModel(ptl.LightningModule): + + def __init(self): + super(CoolModel, self).__init__() + # not the best model... + self.l1 = torch.nn.Linear(28 * 28, 10) + + def forward(self, x): + return torch.relu(self.l1(x)) + + def my_loss(self, y_hat, y): + return F.cross_entropy(y_hat, y) + + def training_step(self, batch, batch_nb): + x, y = batch + y_hat = self.forward(x) + return {'tng_loss': self.my_loss(y_hat, y)} + + def validation_step(self, batch, batch_nb): + x, y = batch + y_hat = self.forward(x) + return {'val_loss': self.my_loss(y_hat, y)} + + def validation_end(self, outputs): + avg_loss = torch.stack([x for x in outputs['val_loss']]).mean() + return avg_loss + + def configure_optimizers(self): + return [torch.optim.Adam(self.parameters(), lr=0.02)] + + @ptl.data_loader + def tng_dataloader(self): + return DataLoader(MNIST('path/to/save', train=True), batch_size=32) + + @ptl.data_loader + def val_dataloader(self): + return DataLoader(MNIST('path/to/save', train=False), batch_size=32) + + @ptl.data_loader + def test_dataloader(self): + return DataLoader(MNIST('path/to/save', train=False), batch_size=32) + def get_model(): # set up model with these hyperparams @@ -94,11 +143,9 @@ def run_prediction(dataloader, trained_model): def main(): save_dir = init_save_dir() - model, hparams = get_model() # exp file to get meta exp = get_exp(False) - exp.argparse(hparams) exp.save() # exp file to get weights @@ -113,6 +160,8 @@ def main(): distributed_backend='dp', ) + model = CoolModel() + result = trainer.fit(model) # correct result and ok accuracy