Parity test (#1284)
* adding test * adding test * added base parity model * added base parity model * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * added parity test * move parity to benchmark * formatting * fixed gradient acc sched * move parity to benchmark * formatting * fixed gradient acc sched * skip for CPU * call last Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
This commit is contained in:
parent
c869dd8b8f
commit
18d055a390
|
@ -22,7 +22,7 @@ steps:
|
|||
- pip install -r ./tests/requirements.txt --user
|
||||
- pip list
|
||||
- python -c "import torch ; print(' & '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]) if torch.cuda.is_available() else 'only CPU')"
|
||||
- coverage run --source pytorch_lightning -m py.test pytorch_lightning tests -v --doctest-modules # --flake8
|
||||
- coverage run --source pytorch_lightning -m py.test pytorch_lightning tests benchmarks -v --doctest-modules # --flake8
|
||||
- coverage report
|
||||
- codecov --token $CODECOV_TOKEN # --pr $DRONE_PULL_REQUEST --build $DRONE_BUILD_NUMBER --branch $DRONE_BRANCH --commit $DRONE_COMMIT --tag $DRONE_TAG
|
||||
- python tests/collect_env_details.py
|
||||
|
|
|
@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
- Added parity test between a vanilla MNIST model and lightning model ([#1284](https://github.com/PyTorchLightning/pytorch-lightning/pull/1284))
|
||||
- Added Reinforcement Learning - Deep Q-network (DQN) lightning example ([#1232](https://github.com/PyTorchLightning/pytorch-lightning/pull/1232))
|
||||
- Added support for hierarchical `dict` ([#1152](https://github.com/PyTorchLightning/pytorch-lightning/pull/1152))
|
||||
- Added `TrainsLogger` class ([#1122](https://github.com/PyTorchLightning/pytorch-lightning/pull/1122))
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
import os
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import MNIST
|
||||
|
||||
from pytorch_lightning import Trainer, LightningModule
|
||||
|
||||
|
||||
class ParityMNIST(LightningModule):
    """
    Small MNIST classifier used for the parity benchmark.

    Architecture: flatten -> Linear(784, 128) -> tanh -> BatchNorm1d(128)
    -> Dropout(0.3) -> Linear(128, 10).
    """

    def __init__(self):
        super().__init__()
        # NOTE: layers are created in a fixed order so that seeded weight
        # initialisation is reproducible between runs.
        self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
        self.c_d1_bn = nn.BatchNorm1d(128)
        self.c_d1_drop = nn.Dropout(0.3)
        self.c_d2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Flatten ``x`` per sample and return the 10 output logits."""
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)
        hidden = torch.tanh(self.c_d1(flat))
        hidden = self.c_d1_drop(self.c_d1_bn(hidden))
        return self.c_d2(hidden)

    def training_step(self, batch, batch_nb):
        """Return ``{'loss': cross-entropy}`` for one ``(inputs, targets)`` batch."""
        inputs, targets = batch
        logits = self(inputs)
        return {'loss': F.cross_entropy(logits, targets)}

    def configure_optimizers(self):
        """Adam over all parameters with a learning rate of 0.02."""
        params = self.parameters()
        return torch.optim.Adam(params, lr=0.02)

    def train_dataloader(self):
        """MNIST training split (downloaded into the current working directory), batch size 32."""
        dataset = MNIST(
            os.getcwd(),
            train=True,
            download=True,
            transform=transforms.ToTensor(),
        )
        return DataLoader(dataset, batch_size=32)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir):
    """
    Verify that the same pytorch and lightning models achieve the same results

    :param tmpdir: pytest temporary-directory fixture (not used by the test body)
    """
    num_epochs = 2
    num_runs = 3

    # both loops also return per-run timings; only the losses are compared here
    lightning_outs, _ = lightning_loop(ParityMNIST, num_runs, num_epochs)
    manual_outs, _ = vanilla_loop(ParityMNIST, num_runs, num_epochs)

    # zip() stops at the shorter sequence, so check the run counts first
    assert len(lightning_outs) == len(manual_outs)

    # make sure the losses agree to 5 decimal places
    for pl_out, pt_out in zip(lightning_outs, manual_outs):
        np.testing.assert_almost_equal(pl_out, pt_out, 5)
|
||||
|
||||
|
||||
def set_seed(seed):
    """
    Seed the numpy and torch random generators with ``seed``.

    The torch CUDA generator is seeded as well when CUDA is available.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
|
||||
|
||||
|
||||
def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
    """
    Train ``MODEL`` with a hand-written PyTorch loop.

    Every run is seeded with its run index, builds a fresh model and trains it
    for ``num_epochs`` passes over ``model.train_dataloader()``.

    :param MODEL: callable returning a model that provides ``train_dataloader``,
        ``configure_optimizers`` and ``training_step``
    :param num_runs: number of independent runs
    :param num_epochs: number of epochs per run
    :return: ``(errors, times)`` where ``errors[i]`` is the last batch loss of
        the final epoch of run ``i`` and ``times[i]`` is the wall-clock
        duration of run ``i`` in seconds
    """
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    errors = []
    times = []

    for i in range(num_runs):
        time_start = time.perf_counter()

        # set seed
        seed = i
        set_seed(seed)

        # init model parts
        model = MODEL()
        dl = model.train_dataloader()
        optimizer = model.configure_optimizers()

        # model to the selected device (GPU when available)
        model = model.to(device)

        epoch_losses = []
        for _ in range(num_epochs):

            # run through full training set
            for j, batch in enumerate(dl):
                x, y = batch
                # move the batch to the same device as the model instead of
                # hard-coding GPU 0, so the CPU fallback above actually works
                x = x.to(device)
                y = y.to(device)
                batch = (x, y)

                loss_dict = model.training_step(batch, j)
                loss = loss_dict['loss']
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

            # track the loss of the last batch of this epoch
            epoch_losses.append(loss.item())

        time_end = time.perf_counter()
        times.append(time_end - time_start)

        errors.append(epoch_losses[-1])

    return errors, times
|
||||
|
||||
|
||||
def lightning_loop(MODEL, num_runs=10, num_epochs=10):
    """
    Train ``MODEL`` with the Lightning ``Trainer``.

    Every run is seeded with its run index, builds a fresh model and fits it
    for ``num_epochs`` epochs on one GPU.

    :param MODEL: callable returning the model to fit
    :param num_runs: number of independent runs
    :param num_epochs: ``max_epochs`` passed to the trainer
    :return: ``(errors, times)`` -- per run, the last value of
        ``trainer.running_loss`` and the wall-clock duration in seconds
    """
    final_losses, durations = [], []

    for run_index in range(num_runs):
        started = time.perf_counter()

        # seed each run with its own index
        set_seed(run_index)

        # build the model first, then the trainer
        module = MODEL()
        trainer_options = dict(
            max_epochs=num_epochs,
            show_progress_bar=False,
            weights_summary=None,
            gpus=1,
            early_stop_callback=False,
            checkpoint_callback=False,
        )
        trainer = Trainer(**trainer_options)
        trainer.fit(module)

        last_loss = trainer.running_loss.last()
        final_losses.append(last_loss.item())

        durations.append(time.perf_counter() - started)

    return final_losses, durations
|
Loading…
Reference in New Issue