val and test are optional now (#95)

* made validation step optional

* added no val model

* val_step can be implemented but not validation_end

* added no val end model

* added tests

* added tests

* remove class

* remove class

* remove class

* remove class

* remove class

* remove class

* remove class

* remove class

* remove class

* remove class

* remove class

* updated docs

* updated docs

* updated test

* updated test

* updated test

* updated test

* updated test

* updated test

* updated test

* updated test

* updated test

* fix pep8
This commit is contained in:
William Falcon 2019-08-11 10:01:57 -04:00 committed by GitHub
parent 996b1f9a6d
commit e5805bf8ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 609 additions and 53 deletions

View File

@ -81,36 +81,40 @@ class CoolModel(pl.LightningModule):
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
def my_loss(self, y_hat, y):
return F.cross_entropy(y_hat, y)
def training_step(self, batch, batch_nb):
# REQUIRED
x, y = batch
y_hat = self.forward(x)
return {'loss': self.my_loss(y_hat, y)}
return {'loss': F.cross_entropy(y_hat, y)}
def validation_step(self, batch, batch_nb):
# OPTIONAL
x, y = batch
y_hat = self.forward(x)
return {'val_loss': self.my_loss(y_hat, y)}
def validation_end(self, outputs):
# OPTIONAL
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
return {'avg_val_loss': avg_loss}
def configure_optimizers(self):
# REQUIRED
return [torch.optim.Adam(self.parameters(), lr=0.02)]
@pl.data_loader
def tng_dataloader(self):
# REQUIRED
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
@pl.data_loader
def val_dataloader(self):
# OPTIONAL
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
@pl.data_loader
def test_dataloader(self):
# OPTIONAL
return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)
```

View File

@ -10,16 +10,14 @@ Otherwise, to Define a Lightning Module, implement the following methods:
**Required**:
- [training_step](RequiredTrainerInterface.md#training_step)
- [validation_step](RequiredTrainerInterface.md#validation_step)
- [validation_end](RequiredTrainerInterface.md#validation_end)
- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader)
- [configure_optimizers](RequiredTrainerInterface.md#configure_optimizers)
- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader)
- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader)
- [test_dataloader](RequiredTrainerInterface.md#test_dataloader)
**Optional**:
- [validation_step](RequiredTrainerInterface.md#validation_step)
- [validation_end](RequiredTrainerInterface.md#validation_end)
- [val_dataloader](RequiredTrainerInterface.md#val_dataloader)
- [test_dataloader](RequiredTrainerInterface.md#test_dataloader)
- [on_save_checkpoint](RequiredTrainerInterface.md#on_save_checkpoint)
- [on_load_checkpoint](RequiredTrainerInterface.md#on_load_checkpoint)

View File

@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
import torch.distributed as dist
from pytorch_lightning.root_module.root_module import LightningModule
from pytorch_lightning.root_module.memory import get_gpu_memory_map
from pytorch_lightning.root_module.model_saving import TrainerIO
from pytorch_lightning.pt_overrides.override_data_parallel import (
@ -312,6 +313,14 @@ class Trainer(TrainerIO):
f_op = getattr(model, f_name, None)
return callable(f_op)
def __is_overriden(self, f_name):
model = self.__get_model()
super_object = super(model.__class__, model)
# when code pointers are different, it was overriden
is_overriden = getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__
return is_overriden
@property
def __tng_tqdm_dic(self):
tqdm_dic = {
@ -345,13 +354,13 @@ class Trainer(TrainerIO):
self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check)
# determine number of validation batches
self.nb_val_batches = len(self.val_dataloader)
self.nb_val_batches = len(self.val_dataloader) if self.val_dataloader is not None else 0
self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check)
self.nb_val_batches = max(1, self.nb_val_batches)
self.nb_val_batches = self.nb_val_batches
# determine number of test batches
self.nb_test_batches = len(self.test_dataloader)
self.nb_test_batches = len(self.test_dataloader) if self.test_dataloader is not None else 0
self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check)
# determine when to check validation
@ -372,6 +381,10 @@ class Trainer(TrainerIO):
:param max_batches: Scalar
:return:
"""
# skip validation if model has no validation_step defined
if not self.__is_overriden('validation_step'):
return {}
# enable eval mode
model.zero_grad()
model.eval()
@ -418,11 +431,13 @@ class Trainer(TrainerIO):
if self.progress_bar and self.prog_bar is not None:
self.prog_bar.update(1)
# give model a chance to do something with the outputs
if self.data_parallel:
val_results = model.module.validation_end(outputs)
else:
val_results = model.validation_end(outputs)
# give model a chance to do something with the outputs (and method defined)
val_results = {}
if self.__is_overriden('validation_end'):
if self.data_parallel:
val_results = model.module.validation_end(outputs)
else:
val_results = model.validation_end(outputs)
# enable train mode again
model.train()
@ -439,6 +454,7 @@ class Trainer(TrainerIO):
:return:
"""
self.tng_dataloader = model.tng_dataloader
self.test_dataloader = model.test_dataloader
self.val_dataloader = model.val_dataloader

View File

@ -36,18 +36,20 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
def validation_step(self, data_batch, batch_nb):
"""
return whatever outputs will need to be aggregated in validation_end
OPTIONAL
:param data_batch:
:return:
"""
raise NotImplementedError
pass
def validation_end(self, outputs):
"""
Outputs has the appended output after each validation step
OPTIONAL
:param outputs:
:return: dic_with_metrics for tqdm
"""
raise NotImplementedError
pass
def training_step(self, data_batch, batch_nb):
"""
@ -67,7 +69,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
@data_loader
def tng_dataloader(self):
"""
Implement a function to load an h5py of this data
Implement a PyTorch DataLoader
:return:
"""
raise NotImplementedError
@ -75,18 +77,18 @@ class LightningModule(GradInformation, ModelIO, ModelHooks):
@data_loader
def test_dataloader(self):
"""
Implement a function to load an h5py of this data
Implement a PyTorch DataLoader
:return:
"""
raise NotImplementedError
return None
@data_loader
def val_dataloader(self):
"""
Implement a function to load an h5py of this data
Implement a PyTorch DataLoader
:return:
"""
raise NotImplementedError
return None
@classmethod
def load_from_metrics(cls, weights_path, tags_csv, on_gpu, map_location=None):

View File

@ -0,0 +1,3 @@
from .lm_test_module import LightningTestModel
from .no_val_end_module import NoValEndTestModel
from .no_val_module import NoValModel

View File

@ -0,0 +1,247 @@
import os
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision import transforms
from test_tube import HyperOptArgumentParser
from pytorch_lightning.root_module.root_module import LightningModule
from pytorch_lightning import data_loader
class NoValEndTestModel(LightningModule):
"""
Sample model to show how to define a template
"""
def __init__(self, hparams, force_remove_distributed_sampler=False):
"""
Pass in parsed HyperOptArgumentParser to the model
:param hparams:
"""
# init superclass
super(NoValEndTestModel, self).__init__()
self.hparams = hparams
self.batch_size = hparams.batch_size
# if you specify an example input, the summary will show input/output for each layer
self.example_input_array = torch.rand(5, 28 * 28)
# remove to test warning for dist sampler
self.force_remove_distributed_sampler = force_remove_distributed_sampler
# build model
self.__build_model()
# ---------------------
# MODEL SETUP
# ---------------------
def __build_model(self):
"""
Layout model
:return:
"""
self.c_d1 = nn.Linear(in_features=self.hparams.in_features,
out_features=self.hparams.hidden_dim)
self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)
self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim,
out_features=self.hparams.out_features)
# ---------------------
# TRAINING
# ---------------------
def forward(self, x):
"""
No special modification required for lightning, define as you normally would
:param x:
:return:
"""
x = self.c_d1(x)
x = torch.tanh(x)
x = self.c_d1_bn(x)
x = self.c_d1_drop(x)
x = self.c_d2(x)
logits = F.log_softmax(x, dim=1)
return logits
def loss(self, labels, logits):
nll = F.nll_loss(logits, labels)
return nll
def training_step(self, data_batch, batch_i):
"""
Lightning calls this inside the training loop
:param data_batch:
:return:
"""
# forward pass
x, y = data_batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)
# calculate loss
loss_val = self.loss(y, y_hat)
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp:
loss_val = loss_val.unsqueeze(0)
# alternate possible outputs to test
if self.trainer.batch_nb % 1 == 0:
output = OrderedDict({
'loss': loss_val,
'prog': {'some_val': loss_val * loss_val}
})
return output
if self.trainer.batch_nb % 2 == 0:
return loss_val
def validation_step(self, data_batch, batch_i):
"""
Lightning calls this inside the validation loop
:param data_batch:
:return:
"""
x, y = data_batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)
loss_val = self.loss(y, y_hat)
# acc
labels_hat = torch.argmax(y_hat, dim=1)
val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
val_acc = torch.tensor(val_acc)
if self.on_gpu:
val_acc = val_acc.cuda(loss_val.device.index)
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp:
loss_val = loss_val.unsqueeze(0)
val_acc = val_acc.unsqueeze(0)
# alternate possible outputs to test
if batch_i % 1 == 0:
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
})
return output
if batch_i % 2 == 0:
return val_acc
if batch_i % 3 == 0:
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,
'test_dic': {'val_loss_a': loss_val}
})
return output
def on_tng_metrics(self, logs):
logs['some_tensor_to_test'] = torch.rand(1)
# ---------------------
# TRAINING SETUP
# ---------------------
def configure_optimizers(self):
"""
return whatever optimizers we want here
:return: list of optimizers
"""
# try no scheduler for this model (testing purposes)
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
# test returning only 1 list instead of 2
return [optimizer]
def __dataloader(self, train):
# init data generators
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root=self.hparams.data_root, train=train,
transform=transform, download=True)
# when using multi-node we need to add the datasampler
train_sampler = None
batch_size = self.hparams.batch_size
try:
if self.on_gpu and not self.force_remove_distributed_sampler:
train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
batch_size = batch_size // self.trainer.world_size # scale batch size
except Exception:
pass
should_shuffle = train_sampler is None
loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=should_shuffle,
sampler=train_sampler
)
return loader
@data_loader
def tng_dataloader(self):
return self.__dataloader(train=True)
@data_loader
def val_dataloader(self):
return self.__dataloader(train=False)
@data_loader
def test_dataloader(self):
return self.__dataloader(train=False)
@staticmethod
def add_model_specific_args(parent_parser, root_dir): # pragma: no cover
"""
Parameters you define here will be available to your model through self.hparams
:param parent_parser:
:param root_dir:
:return:
"""
parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
# param overwrites
# parser.set_defaults(gradient_clip=5.0)
# network params
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
parser.add_argument('--in_features', default=28 * 28, type=int)
parser.add_argument('--out_features', default=10, type=int)
# use 500 for CPU, 50000 for GPU to see speed difference
parser.add_argument('--hidden_dim', default=50000, type=int)
# data
parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)
# training params (opt)
parser.opt_list('--learning_rate', default=0.001 * 8, type=float,
options=[0.0001, 0.0005, 0.001, 0.005],
tunable=False)
parser.opt_list('--optimizer_name', default='adam', type=str,
options=['adam'], tunable=False)
# if using 2 nodes with 4 gpus each the batch size here
# (256) will be 256 / (2*8) = 16 per gpu
parser.opt_list('--batch_size', default=256 * 8, type=int,
options=[32, 64, 128, 256], tunable=False,
help='batch size will be divided over all gpus being used across all nodes')
return parser

View File

@ -0,0 +1,196 @@
import os
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import MNIST
from torchvision import transforms
from test_tube import HyperOptArgumentParser
from pytorch_lightning.root_module.root_module import LightningModule
from pytorch_lightning import data_loader
class NoValModel(LightningModule):
"""
Sample model to show how to define a template
"""
def __init__(self, hparams, force_remove_distributed_sampler=False):
"""
Pass in parsed HyperOptArgumentParser to the model
:param hparams:
"""
# init superclass
super(NoValModel, self).__init__()
self.hparams = hparams
self.batch_size = hparams.batch_size
# if you specify an example input, the summary will show input/output for each layer
self.example_input_array = torch.rand(5, 28 * 28)
# remove to test warning for dist sampler
self.force_remove_distributed_sampler = force_remove_distributed_sampler
# build model
self.__build_model()
# ---------------------
# MODEL SETUP
# ---------------------
def __build_model(self):
"""
Layout model
:return:
"""
self.c_d1 = nn.Linear(in_features=self.hparams.in_features,
out_features=self.hparams.hidden_dim)
self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)
self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim,
out_features=self.hparams.out_features)
# ---------------------
# TRAINING
# ---------------------
def forward(self, x):
"""
No special modification required for lightning, define as you normally would
:param x:
:return:
"""
x = self.c_d1(x)
x = torch.tanh(x)
x = self.c_d1_bn(x)
x = self.c_d1_drop(x)
x = self.c_d2(x)
logits = F.log_softmax(x, dim=1)
return logits
def loss(self, labels, logits):
nll = F.nll_loss(logits, labels)
return nll
def training_step(self, data_batch, batch_i):
"""
Lightning calls this inside the training loop
:param data_batch:
:return:
"""
# forward pass
x, y = data_batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)
# calculate loss
loss_val = self.loss(y, y_hat)
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp:
loss_val = loss_val.unsqueeze(0)
# alternate possible outputs to test
if self.trainer.batch_nb % 1 == 0:
output = OrderedDict({
'loss': loss_val,
'prog': {'some_val': loss_val * loss_val}
})
return output
if self.trainer.batch_nb % 2 == 0:
return loss_val
def on_tng_metrics(self, logs):
logs['some_tensor_to_test'] = torch.rand(1)
# ---------------------
# TRAINING SETUP
# ---------------------
def configure_optimizers(self):
"""
return whatever optimizers we want here
:return: list of optimizers
"""
# try no scheduler for this model (testing purposes)
optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
# test returning only 1 list instead of 2
return [optimizer]
def __dataloader(self, train):
# init data generators
transform = transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.5,), (1.0,))])
dataset = MNIST(root=self.hparams.data_root, train=train,
transform=transform, download=True)
# when using multi-node we need to add the datasampler
train_sampler = None
batch_size = self.hparams.batch_size
try:
if self.on_gpu and not self.force_remove_distributed_sampler:
train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank)
batch_size = batch_size // self.trainer.world_size # scale batch size
except Exception:
pass
should_shuffle = train_sampler is None
loader = DataLoader(
dataset=dataset,
batch_size=batch_size,
shuffle=should_shuffle,
sampler=train_sampler
)
return loader
@data_loader
def tng_dataloader(self):
return self.__dataloader(train=True)
@staticmethod
def add_model_specific_args(parent_parser, root_dir): # pragma: no cover
"""
Parameters you define here will be available to your model through self.hparams
:param parent_parser:
:param root_dir:
:return:
"""
parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser])
# param overwrites
# parser.set_defaults(gradient_clip=5.0)
# network params
parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False)
parser.add_argument('--in_features', default=28 * 28, type=int)
parser.add_argument('--out_features', default=10, type=int)
# use 500 for CPU, 50000 for GPU to see speed difference
parser.add_argument('--hidden_dim', default=50000, type=int)
# data
parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str)
# training params (opt)
parser.opt_list('--learning_rate', default=0.001 * 8, type=float,
options=[0.0001, 0.0005, 0.001, 0.005],
tunable=False)
parser.opt_list('--optimizer_name', default='adam', type=str,
options=['adam'], tunable=False)
# if using 2 nodes with 4 gpus each the batch size here
# (256) will be 256 / (2*8) = 16 per gpu
parser.opt_list('--batch_size', default=256 * 8, type=int,
options=[32, 64, 128, 256], tunable=False,
help='batch size will be divided over all gpus being used across all nodes')
return parser

View File

@ -10,7 +10,7 @@ from test_tube import Experiment, SlurmCluster
# sys.path += [os.path.abspath('..'), os.path.abspath('../..')]
from pytorch_lightning import Trainer
from pytorch_lightning.testing.lm_test_module import LightningTestModel
from pytorch_lightning.testing import LightningTestModel, NoValEndTestModel, NoValModel
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.root_module import memory
@ -26,6 +26,122 @@ np.random.seed(SEED)
# ------------------------------------------------------------------------
# TESTS
# ------------------------------------------------------------------------
def test_early_stopping_cpu_model():
"""
Test each of the trainer options
:return:
"""
stopping = EarlyStopping(monitor='val_loss')
trainer_options = dict(
early_stop_callback=stopping,
gradient_clip=1.0,
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,
progress_bar=False,
experiment=get_exp(),
train_percent_check=0.1,
val_percent_check=0.1
)
model, hparams = get_model()
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
# test freeze on cpu
model.freeze()
model.unfreeze()
def test_no_val_module():
"""
Tests use case where trainer saves the model, and user loads it from tags independently
:return:
"""
hparams = get_hparams()
model = NoValModel(hparams)
save_dir = init_save_dir()
# exp file to get meta
exp = get_exp(False)
exp.argparse(hparams)
exp.save()
trainer_options = dict(
max_nb_epochs=1,
cluster=SlurmCluster(),
experiment=exp,
checkpoint_callback=ModelCheckpoint(save_dir)
)
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# traning complete
assert result == 1, 'amp + ddp model failed to complete'
# save model
new_weights_path = os.path.join(save_dir, 'save_test.ckpt')
trainer.save_checkpoint(new_weights_path)
# load new model
tags_path = exp.get_data_path(exp.name, exp.version)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
tags_csv=tags_path, on_gpu=False)
model_2.eval()
# make prediction
clear_save_dir()
def test_no_val_end_module():
"""
Tests use case where trainer saves the model, and user loads it from tags independently
:return:
"""
hparams = get_hparams()
model = NoValEndTestModel(hparams)
save_dir = init_save_dir()
# exp file to get meta
exp = get_exp(False)
exp.argparse(hparams)
exp.save()
trainer_options = dict(
max_nb_epochs=1,
cluster=SlurmCluster(),
experiment=exp,
checkpoint_callback=ModelCheckpoint(save_dir)
)
# fit model
trainer = Trainer(**trainer_options)
result = trainer.fit(model)
# traning complete
assert result == 1, 'amp + ddp model failed to complete'
# save model
new_weights_path = os.path.join(save_dir, 'save_test.ckpt')
trainer.save_checkpoint(new_weights_path)
# load new model
tags_path = exp.get_data_path(exp.name, exp.version)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path,
tags_csv=tags_path, on_gpu=False)
model_2.eval()
# make prediction
clear_save_dir()
def test_simple_cpu():
"""
Verify continue training session on CPU
@ -445,33 +561,6 @@ def test_amp_gpu_ddp_slurm_managed():
clear_save_dir()
def test_early_stopping_cpu_model():
"""
Test each of the trainer options
:return:
"""
stopping = EarlyStopping()
trainer_options = dict(
early_stop_callback=stopping,
gradient_clip=1.0,
overfit_pct=0.20,
track_grad_norm=2,
print_nan_grads=True,
progress_bar=False,
experiment=get_exp(),
train_percent_check=0.1,
val_percent_check=0.1
)
model, hparams = get_model()
run_gpu_model_test(trainer_options, model, hparams, on_gpu=False)
# test freeze on cpu
model.freeze()
model.unfreeze()
def test_cpu_model_with_amp():
"""
Make sure model trains on CPU
@ -525,6 +614,7 @@ def test_all_features_cpu_model():
print_nan_grads=True,
progress_bar=False,
experiment=get_exp(),
accumulate_grad_batches=2,
max_nb_epochs=1,
train_percent_check=0.4,
val_percent_check=0.4