2020-10-13 11:18:07 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2020-06-27 01:38:25 +00:00
|
|
|
import torch
|
2021-07-28 16:57:31 +00:00
|
|
|
from torchmetrics.functional import accuracy
|
2020-06-27 01:38:25 +00:00
|
|
|
|
2021-02-23 22:08:46 +00:00
|
|
|
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
2021-02-09 10:10:52 +00:00
|
|
|
from tests.helpers import BoringModel
|
2021-02-08 10:52:02 +00:00
|
|
|
from tests.helpers.utils import get_default_logger, load_model_from_checkpoint, reset_seed
|
2020-06-27 01:38:25 +00:00
|
|
|
|
|
|
|
|
2021-02-23 22:08:46 +00:00
|
|
|
def run_model_test_without_loggers(
    trainer_options: dict, model: LightningModule, data: LightningDataModule = None, min_acc: float = 0.50
):
    """Fit ``model``, reload its best checkpoint and check the reloaded model's test accuracy.

    Args:
        trainer_options: keyword arguments forwarded unchanged to :class:`Trainer`.
        model: the module to train.
        data: optional datamodule. When it is falsy, the reloaded model's own
            ``test_dataloader()`` is used for evaluation instead.
        min_acc: minimum accuracy forwarded to ``run_model_prediction``.
    """
    reset_seed()

    # Train from the configured options.
    trainer = Trainer(**trainer_options)
    trainer.fit(model, datamodule=data)

    # Training must have run to completion.
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # Reload the best checkpoint as a fresh instance of the same class.
    reloaded = load_model_from_checkpoint(
        trainer.logger,
        trainer.checkpoint_callback.best_model_path,
        type(model),
    )

    # Evaluate the reloaded model on every test dataloader (a single loader is wrapped in a list).
    loaders = data.test_dataloader() if data else reloaded.test_dataloader()
    if not isinstance(loaders, list):
        loaders = [loaders]

    # BoringModel instances skip the accuracy check.
    if isinstance(reloaded, BoringModel):
        return
    for loader in loaders:
        run_model_prediction(reloaded, loader, min_acc=min_acc)
|
2020-06-27 01:38:25 +00:00
|
|
|
|
|
|
|
|
2021-02-06 13:22:10 +00:00
|
|
|
def run_model_test(
    trainer_options: dict,
    model: LightningModule,
    data: LightningDataModule = None,
    on_gpu: bool = True,
    version=None,
    with_hpc: bool = True,
    min_acc: float = 0.25,
):
    """Train ``model`` with a default logger and run a series of post-training sanity checks.

    The checks are:
    - training finishes;
    - the parameters actually changed;
    - the best checkpoint can be loaded;
    - the test accuracy reaches ``min_acc`` (skipped for ``BoringModel``);
    - optionally, an HPC checkpoint can be saved and restored.

    Args:
        trainer_options: keyword arguments forwarded to :class:`Trainer`. Must contain
            ``"default_root_dir"``. NOTE: this dict is mutated in place, because its
            ``logger`` entry is overwritten with the default logger created here.
        model: the module to train.
        data: optional datamodule. When ``None``, ``model.test_dataloader()`` is used for
            evaluation.
        on_gpu: not read by this function; kept for backward compatibility with callers.
        version: forwarded to ``get_default_logger``.
        with_hpc: when true, also save an HPC checkpoint and restore from it.
        min_acc: minimum accuracy forwarded to ``run_model_prediction``.
    """
    reset_seed()
    save_dir = trainer_options["default_root_dir"]

    # logger file to get meta
    logger = get_default_logger(save_dir, version=version)
    trainer_options.update(logger=logger)
    trainer = Trainer(**trainer_options)

    # Snapshot the per-parameter sums of absolute values before and after training.
    initial_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])
    trainer.fit(model, datamodule=data)
    post_train_values = torch.tensor([torch.sum(torch.abs(x)) for x in model.parameters()])

    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # Check that the model actually changed post-training (L2 norm of the snapshot difference).
    change_norm = torch.norm(initial_values - post_train_values)
    assert change_norm > 0.03, f"Expected the model to change during training, but the change was only {change_norm}"

    # Test model loading: only checks that loading succeeds; the loaded model is discarded.
    _ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))

    # Test the trained model's accuracy on every test dataloader.
    test_loaders = model.test_dataloader() if data is None else data.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    if not isinstance(model, BoringModel):
        for dataloader in test_loaders:
            run_model_prediction(model, dataloader, min_acc=min_acc)

    if with_hpc:
        # test HPC saving
        # save logger to make sure we get all the metrics
        if logger:
            logger.finalize("finished")
        hpc_save_path = trainer.checkpoint_connector.hpc_save_path(save_dir)
        trainer.save_checkpoint(hpc_save_path)
        # test HPC loading
        # NOTE: reaches into a name-mangled private method of the checkpoint connector.
        checkpoint_path = trainer.checkpoint_connector._CheckpointConnector__get_max_ckpt_path_from_folder(save_dir)
        trainer.checkpoint_connector.restore(checkpoint_path)
|
2020-06-27 01:38:25 +00:00
|
|
|
|
|
|
|
|
2021-02-23 22:08:46 +00:00
|
|
|
@torch.no_grad()
def run_model_prediction(trained_model, dataloader, min_acc=0.50):
    """Assert that ``trained_model`` reaches ``min_acc`` top-2 accuracy on one batch of ``dataloader``.

    The model is moved to CPU and switched to eval mode for the check. Its original device
    and train/eval mode are restored afterwards, even when the accuracy assertion fails.

    Args:
        trained_model: the module to evaluate. Inputs are flattened after the first
            dimension before the forward call.
        dataloader: iterable yielding ``(x, y)`` batches; only the first batch is used.
        min_acc: minimum accepted accuracy (inclusive).

    Raises:
        AssertionError: if the measured accuracy is below ``min_acc``.
    """
    orig_device = trained_model.device
    was_training = trained_model.training
    try:
        trained_model.cpu()
        trained_model.eval()

        # run prediction on 1 batch
        x, y = next(iter(dataloader))
        # flatten every dimension after the first before the forward pass
        x = x.flatten(1)

        y_hat = trained_model(x)
        acc = accuracy(y_hat.cpu(), y.cpu(), top_k=2).item()
        assert acc >= min_acc, f"This model is expected to get >= {min_acc} in test set (it got {acc})"
    finally:
        # Restore the caller's device and train/eval mode regardless of the outcome.
        trained_model.to(orig_device)
        trained_model.train(was_training)
|