2020-06-27 01:38:25 +00:00
|
|
|
import torch
|
|
|
|
|
|
|
|
from pytorch_lightning import Trainer
|
2020-07-07 18:54:07 +00:00
|
|
|
from tests.base.develop_utils import load_model_from_checkpoint, get_default_logger, \
|
2020-06-27 01:38:25 +00:00
|
|
|
reset_seed
|
|
|
|
|
|
|
|
|
|
|
|
def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50):
    """Fit ``model``, reload its best checkpoint and check the reloaded model's accuracy.

    Unlike ``run_model_test``, no logger is injected: ``trainer_options`` is passed
    to ``Trainer`` unchanged, and whatever ``trainer.logger`` ends up being is handed
    to ``load_model_from_checkpoint``.

    Args:
        trainer_options: keyword arguments forwarded to ``Trainer`` (not modified).
        model: the module to fit; its ``test_dataloader()`` supplies evaluation data.
        min_acc: minimum accuracy (inclusive) the first batch of every test
            dataloader must reach, checked by ``run_prediction``.

    Raises:
        AssertionError: if ``trainer.fit`` does not return 1, or a test batch falls
            below ``min_acc``.
    """
    reset_seed()

    # fit model
    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'trainer failed'

    pretrained_model = load_model_from_checkpoint(
        trainer.logger,
        trainer.checkpoint_callback.best_model_path,
    )

    # test new model accuracy
    test_loaders = model.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    for dataloader in test_loaders:
        run_prediction(dataloader, pretrained_model, min_acc=min_acc)

    if trainer.use_ddp or trainer.use_ddp2:
        # on hpc this would work fine... but need to hack it for the purpose of the test
        trainer.model = pretrained_model
        # Same call run_model_test uses. Unpacking configure_optimizers() directly
        # into two names only works when it returns exactly two items; presumably
        # init_optimizers handles its other return forms — TODO confirm.
        trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \
            trainer.init_optimizers(pretrained_model)
|
|
|
|
|
|
|
|
|
|
|
|
def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, with_hpc: bool = True):
    """Fit ``model`` with a default logger, reload the best checkpoint and verify it.

    Steps: fit, assert success, reload from ``trainer.checkpoint_callback.best_model_path``,
    check accuracy on the first batch of every test dataloader, and optionally
    round-trip the trainer through ``hpc_save`` / ``hpc_load``.

    Args:
        trainer_options: keyword arguments for ``Trainer``; must contain
            ``'default_root_dir'``. A shallow copy is used, so the ``logger`` and
            ``checkpoint_callback`` entries added here do not leak into the
            caller's dict.
        model: the module to fit; its ``test_dataloader()`` supplies evaluation data.
        on_gpu: forwarded to ``trainer.hpc_load``.
        version: forwarded to ``get_default_logger``.
        with_hpc: when True, also exercise ``hpc_save`` / ``hpc_load``.

    Raises:
        AssertionError: if ``trainer.fit`` does not return 1 or accuracy is too low.
    """
    reset_seed()

    # Work on a copy: the options injected below are local to this run.
    trainer_options = dict(trainer_options)
    save_dir = trainer_options['default_root_dir']

    # logger file to get meta
    logger = get_default_logger(save_dir, version=version)
    trainer_options.update(logger=logger)
    trainer_options.setdefault('checkpoint_callback', True)

    trainer = Trainer(**trainer_options)
    result = trainer.fit(model)

    # correct result and ok accuracy
    assert result == 1, 'trainer failed'

    # test model loading
    pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path)

    # test new model accuracy
    test_loaders = model.test_dataloader()
    if not isinstance(test_loaders, list):
        test_loaders = [test_loaders]

    for dataloader in test_loaders:
        run_prediction(dataloader, pretrained_model)

    if with_hpc:
        if trainer.use_ddp or trainer.use_ddp2:
            # on hpc this would work fine... but need to hack it for the purpose of the test
            trainer.model = pretrained_model
            trainer.optimizers, trainer.lr_schedulers, trainer.optimizer_frequencies = \
                trainer.init_optimizers(pretrained_model)

        # test HPC loading / saving
        trainer.hpc_save(save_dir, logger)
        trainer.hpc_load(save_dir, on_gpu=on_gpu)
|
|
|
|
|
|
|
|
|
|
|
|
def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
    """Assert that ``trained_model`` reaches ``min_acc`` on the first batch of ``dataloader``.

    Args:
        dataloader: any iterable of batches; only the first batch is consumed.
            Outside ``dp`` mode each batch must unpack as ``(x, y)``.
        trained_model: when ``dp`` is False, a callable taking ``x`` flattened to
            ``(batch_size, -1)`` and returning scores whose argmax over dim 1 is the
            predicted label; when ``dp`` is True, it is called as
            ``trained_model(batch, 0)`` and must return a mapping with a
            ``'val_acc'`` tensor, whose mean is used as the accuracy.
        dp: select the ``trained_model(batch, 0)`` calling convention.
        min_acc: minimum accepted accuracy (inclusive).

    Returns:
        float: the measured accuracy (callers may ignore it).

    Raises:
        AssertionError: if the accuracy is below ``min_acc``.
    """
    # run prediction on 1 batch
    batch = next(iter(dataloader))

    if dp:
        with torch.no_grad():
            output = trained_model(batch, 0)
        acc = torch.mean(output['val_acc']).item()
    else:
        x, y = batch
        # flatten every dimension except the batch dimension
        x = x.view(x.size(0), -1)
        with torch.no_grad():
            y_hat = trained_model(x)

        # acc: compare on CPU so model outputs and labels share a device
        labels_hat = torch.argmax(y_hat.cpu(), dim=1)
        acc = torch.sum(y.cpu() == labels_hat).item() / len(y)

    assert acc >= min_acc, f"This model is expected to get >= {min_acc} in test set (it got {acc})"
    return acc
|