2020-12-17 11:03:45 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
2020-03-30 22:16:32 +00:00
|
|
|
import time
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import pytest
|
|
|
|
import torch
|
2020-12-17 11:03:45 +00:00
|
|
|
from tqdm import tqdm
|
2020-03-30 22:16:32 +00:00
|
|
|
|
2020-12-16 06:09:26 +00:00
|
|
|
from pytorch_lightning import seed_everything, Trainer
|
2020-06-27 01:38:25 +00:00
|
|
|
import tests.base.develop_utils as tutils
|
2020-11-13 21:57:46 +00:00
|
|
|
from tests.base.models import ParityModuleMNIST, ParityModuleRNN
|
2020-03-30 22:16:32 +00:00
|
|
|
|
|
|
|
|
2020-11-27 18:36:50 +00:00
|
|
|
# ParityModuleMNIST runs with num_workers=1
@pytest.mark.parametrize('cls_model,max_diff', [
    (ParityModuleRNN, 0.05),
    (ParityModuleMNIST, 0.25),  # todo: lower this thr
])
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
def test_pytorch_parity(tmpdir, cls_model, max_diff: float, num_epochs: int = 4, num_runs: int = 3):
    """
    Verify that the same pytorch and lightning models achieve the same results.

    Both training loops are executed ``num_runs`` times for ``cls_model``. The
    final loss of every run must agree to 5 decimal places. The per-run
    durations (excluding the first, warm-up run) are compared with
    ``tutils.assert_speed_parity_absolute`` using the tolerance ``max_diff``.

    ``tmpdir`` is the pytest fixture; it is accepted but not used here.
    """
    lightning = lightning_loop(cls_model, num_runs, num_epochs)
    vanilla = vanilla_loop(cls_model, num_runs, num_epochs)

    # zip() below would silently ignore unmatched runs, so check both loops
    # produced one result per run before comparing them pairwise
    assert len(lightning['losses']) == len(vanilla['losses']) == num_runs

    # the final losses must agree to 5 decimal places, i.e.
    # abs(pt_out - pl_out) < 1.5 * 10**-5 (numpy's assert_almost_equal criterion)
    for pl_out, pt_out in zip(lightning['losses'], vanilla['losses']):
        np.testing.assert_almost_equal(pl_out, pt_out, decimal=5)

    # the first run initializes the dataset (download & filter), so it is
    # excluded from the speed comparison
    tutils.assert_speed_parity_absolute(
        lightning['durations'][1:], vanilla['durations'][1:], nb_epochs=num_epochs, max_diff=max_diff
    )
|
2020-04-15 00:23:36 +00:00
|
|
|
|
2020-03-30 22:16:32 +00:00
|
|
|
|
2020-06-04 15:20:12 +00:00
|
|
|
def vanilla_loop(cls_model, num_runs=10, num_epochs=10):
    """
    Train ``cls_model`` with a hand-written PyTorch loop, ``num_runs`` times.

    Run ``i`` is seeded with ``i``, builds a fresh model, and trains it for
    ``num_epochs`` epochs. Run 0 trains for a single epoch only, because the
    caller excludes it from the speed comparison.

    Args:
        cls_model: model class instantiated with no arguments. It must provide
            ``train_dataloader()``, ``configure_optimizers()`` (returning a single
            optimizer) and ``training_step(batch, batch_idx)`` returning a dict
            with a ``'loss'`` tensor.
        num_runs: number of independent training runs.
        num_epochs: epochs per run (run 0 always uses 1).

    Returns:
        dict with ``'losses'``, the loss of the last training batch of each run
        as a Python float, and ``'durations'``, the wall-clock seconds of each run
        measured with ``time.perf_counter``.

    Raises:
        ValueError: if the model's training dataloader yields no batches.
    """
    hist_losses = []
    hist_durations = []

    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    # cudnn.deterministic is process-global; remember the previous value so the
    # setting does not leak into unrelated tests once this loop is done
    prev_cudnn_deterministic = torch.backends.cudnn.deterministic
    torch.backends.cudnn.deterministic = True
    try:
        for i in tqdm(range(num_runs), desc=f'Vanilla PT with {cls_model.__name__}'):
            time_start = time.perf_counter()

            # seed each run with its index so it matches the corresponding Lightning run
            seed_everything(i)

            # init model parts
            model = cls_model()
            dl = model.train_dataloader()
            optimizer = model.configure_optimizers()

            # model to GPU
            model = model.to(device)

            final_loss = None
            # as the first run is skipped, no need to run it long
            for _ in range(num_epochs if i > 0 else 1):
                loss = None
                # run through full training set
                for j, batch in enumerate(dl):
                    batch = [x.to(device) for x in batch]
                    loss = model.training_step(batch, j)['loss']
                    loss.backward()
                    optimizer.step()
                    optimizer.zero_grad()

                if loss is None:
                    raise ValueError(f'{cls_model.__name__}.train_dataloader() yielded no batches')
                # track last epoch loss; .item() copies the value to the host,
                # so queued GPU work has finished before the timer is read below
                final_loss = loss.item()

            time_end = time.perf_counter()
            hist_durations.append(time_end - time_start)

            hist_losses.append(final_loss)
    finally:
        torch.backends.cudnn.deterministic = prev_cudnn_deterministic

    return {
        'losses': hist_losses,
        'durations': hist_durations,
    }
|
2020-03-30 22:16:32 +00:00
|
|
|
|
|
|
|
|
2020-06-04 15:20:12 +00:00
|
|
|
def lightning_loop(cls_model, num_runs=10, num_epochs=10):
    """
    Train ``cls_model`` with the Lightning ``Trainer``, ``num_runs`` times.

    Run ``i`` is seeded with ``i``, builds a fresh model, and fits it on one GPU
    for ``num_epochs`` epochs. Run 0 fits for a single epoch only, because the
    caller excludes it from the speed comparison. Checkpointing, logging,
    the progress bar and the weights summary are all disabled.

    Returns:
        dict with ``'losses'``, the trainer's last running loss after each run
        as a Python float, and ``'durations'``, the wall-clock seconds of each run
        measured with ``time.perf_counter``.
    """
    losses, durations = [], []

    runs = tqdm(range(num_runs), desc=f'PT Lightning with {cls_model.__name__}')
    for run_idx in runs:
        started = time.perf_counter()

        # the run index doubles as the seed, so every run is reproducible
        seed_everything(run_idx)
        model = cls_model()

        # run 0 is only a short warm-up; later runs train for the full budget
        epochs = 1 if run_idx == 0 else num_epochs
        trainer_kwargs = dict(
            max_epochs=epochs,
            progress_bar_refresh_rate=0,
            weights_summary=None,
            gpus=1,
            checkpoint_callback=False,
            deterministic=True,
            logger=False,
            replace_sampler_ddp=False,
        )
        trainer = Trainer(**trainer_kwargs)
        trainer.fit(model)

        losses.append(trainer.train_loop.running_loss.last().item())

        durations.append(time.perf_counter() - started)

    return {'losses': losses, 'durations': durations}
|