Tests: fix deprecated TM mape (#8830)
This commit is contained in:
parent
3ef8cd654d
commit
3096ab88eb
|
@ -16,7 +16,7 @@ from typing import Callable, Union
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from torchmetrics.functional import mean_relative_error
|
from torchmetrics.functional import mean_absolute_percentage_error as mape
|
||||||
|
|
||||||
from pytorch_lightning import seed_everything, Trainer
|
from pytorch_lightning import seed_everything, Trainer
|
||||||
from pytorch_lightning.callbacks import QuantizationAwareTraining
|
from pytorch_lightning.callbacks import QuantizationAwareTraining
|
||||||
|
@ -42,7 +42,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
|
||||||
trainer = Trainer(**trainer_args)
|
trainer = Trainer(**trainer_args)
|
||||||
trainer.fit(model, datamodule=dm)
|
trainer.fit(model, datamodule=dm)
|
||||||
org_size = get_model_size_mb(model)
|
org_size = get_model_size_mb(model)
|
||||||
org_score = torch.mean(torch.tensor([mean_relative_error(model(x), y) for x, y in dm.test_dataloader()]))
|
org_score = torch.mean(torch.tensor([mape(model(x), y) for x, y in dm.test_dataloader()]))
|
||||||
|
|
||||||
fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None
|
fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None
|
||||||
qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert)
|
qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert)
|
||||||
|
@ -51,7 +51,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
|
||||||
|
|
||||||
quant_calls = qcb._forward_calls
|
quant_calls = qcb._forward_calls
|
||||||
assert quant_calls == qcb._forward_calls
|
assert quant_calls == qcb._forward_calls
|
||||||
quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
|
quant_score = torch.mean(torch.tensor([mape(qmodel(x), y) for x, y in dm.test_dataloader()]))
|
||||||
# test that the test score is almost the same as with pure training
|
# test that the test score is almost the same as with pure training
|
||||||
assert torch.allclose(org_score, quant_score, atol=0.45)
|
assert torch.allclose(org_score, quant_score, atol=0.45)
|
||||||
model_path = trainer.checkpoint_callback.best_model_path
|
model_path = trainer.checkpoint_callback.best_model_path
|
||||||
|
@ -70,7 +70,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
|
||||||
|
|
||||||
# todo: make it work also with strict loading
|
# todo: make it work also with strict loading
|
||||||
qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False)
|
qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False)
|
||||||
quant2_score = torch.mean(torch.tensor([mean_relative_error(qmodel2(x), y) for x, y in dm.test_dataloader()]))
|
quant2_score = torch.mean(torch.tensor([mape(qmodel2(x), y) for x, y in dm.test_dataloader()]))
|
||||||
assert torch.allclose(org_score, quant2_score, atol=0.45)
|
assert torch.allclose(org_score, quant2_score, atol=0.45)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue