diff --git a/tests/callbacks/test_quantization.py b/tests/callbacks/test_quantization.py index 092eb108cb..23d0cbc9d5 100644 --- a/tests/callbacks/test_quantization.py +++ b/tests/callbacks/test_quantization.py @@ -16,7 +16,7 @@ from typing import Callable, Union import pytest import torch -from torchmetrics.functional import mean_relative_error +from torchmetrics.functional import mean_absolute_percentage_error as mape from pytorch_lightning import seed_everything, Trainer from pytorch_lightning.callbacks import QuantizationAwareTraining @@ -42,7 +42,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool): trainer = Trainer(**trainer_args) trainer.fit(model, datamodule=dm) org_size = get_model_size_mb(model) - org_score = torch.mean(torch.tensor([mean_relative_error(model(x), y) for x, y in dm.test_dataloader()])) + org_score = torch.mean(torch.tensor([mape(model(x), y) for x, y in dm.test_dataloader()])) fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert) @@ -51,7 +51,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool): quant_calls = qcb._forward_calls assert quant_calls == qcb._forward_calls - quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()])) + quant_score = torch.mean(torch.tensor([mape(qmodel(x), y) for x, y in dm.test_dataloader()])) # test that the test score is almost the same as with pure training assert torch.allclose(org_score, quant_score, atol=0.45) model_path = trainer.checkpoint_callback.best_model_path @@ -70,7 +70,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool): # todo: make it work also with strict loading qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False) - quant2_score = torch.mean(torch.tensor([mean_relative_error(qmodel2(x), y) for x, y in dm.test_dataloader()])) + quant2_score = torch.mean(torch.tensor([mape(qmodel2(x), y) for x, y in dm.test_dataloader()])) assert torch.allclose(org_score, quant2_score, atol=0.45)