From 049006a59cf62d3539aaf19037aca1caac98abfd Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Thu, 18 Feb 2021 11:47:29 +0100
Subject: [PATCH] fix/test quant (#6040)

* fix/test quant

* ...

---
 pytorch_lightning/callbacks/quantization.py | 7 +++++--
 pytorch_lightning/utilities/data.py         | 4 +---
 pytorch_lightning/utilities/imports.py      | 2 +-
 tests/__init__.py                           | 5 ++++-
 tests/callbacks/test_quantization.py        | 7 ++++++-
 5 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/callbacks/quantization.py b/pytorch_lightning/callbacks/quantization.py
index f0458ff3b1..c5a1282f7b 100644
--- a/pytorch_lightning/callbacks/quantization.py
+++ b/pytorch_lightning/callbacks/quantization.py
@@ -128,8 +128,11 @@ class QuantizationAwareTraining(Callback):
             input_compatible: preserve quant/dequant layers. This allows feeding the same inputs as the original
                 model, but breaks compatibility with TorchScript
         """
-        if not isinstance(qconfig, (str, QConfig)):
-            raise MisconfigurationException(f"Unsupported qconfig: f{qconfig}.")
+        _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
+        if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
+            raise MisconfigurationException(
+                f"Unsupported qconfig: {qconfig}.\nTry one of defaults: {torch.backends.quantized.supported_engines}"
+            )
         self._qconfig = qconfig
 
         if observer_type not in self.OBSERVER_TYPES:
diff --git a/pytorch_lightning/utilities/data.py b/pytorch_lightning/utilities/data.py
index 6b887b8526..a73299e2af 100644
--- a/pytorch_lightning/utilities/data.py
+++ b/pytorch_lightning/utilities/data.py
@@ -30,9 +30,7 @@ def has_len(dataloader: DataLoader) -> bool:
     try:
         # try getting the length
         if len(dataloader) == 0:
-            raise ValueError(
-                '`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch'
-            )
+            raise ValueError('`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch')
         has_len = True
     except TypeError:
         has_len = False
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index b4c30097fa..8024997382 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -52,7 +52,7 @@ _IS_WINDOWS = platform.system() == "Windows"
 _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
 _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
 _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
-_TORCH_QUANTIZE_AVAILABLE = _module_available('torch.ops.quantized')
+_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BOLTS_AVAILABLE = _module_available('pl_bolts')
 _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
diff --git a/tests/__init__.py b/tests/__init__.py
index c9642003d6..a833da7cbd 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -35,9 +35,12 @@ RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
 if not os.path.isdir(_TEMP_PATH):
     os.mkdir(_TEMP_PATH)
 
+_MISS_QUANT_DEFAULT = 'fbgemm' not in torch.backends.quantized.supported_engines
+
 _SKIPIF_ARGS_PT_LE_1_4 = dict(condition=_TORCH_LOWER_EQUAL_1_4, reason="test pytorch > 1.4")
 _SKIPIF_ARGS_NO_GPU = dict(condition=not torch.cuda.is_available(), reason="test requires single-GPU machine")
 _SKIPIF_ARGS_NO_GPUS = dict(condition=torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 _SKIPIF_ARGS_NO_PT_QUANT = dict(
-    condition=not _TORCH_QUANTIZE_AVAILABLE, reason="PyTorch quantization is needed for this test"
+    condition=not _TORCH_QUANTIZE_AVAILABLE or _MISS_QUANT_DEFAULT,
+    reason="PyTorch quantization is needed for this test"
 )
diff --git a/tests/callbacks/test_quantization.py b/tests/callbacks/test_quantization.py
index 620346c0bd..7b51b81e1b 100644
--- a/tests/callbacks/test_quantization.py
+++ b/tests/callbacks/test_quantization.py
@@ -73,9 +73,14 @@ def test_quantize_torchscript(tmpdir):
     trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir, max_epochs=1)
     trainer.fit(qmodel, datamodule=dm)
 
-    qmodel.to_torchscript()
+    batch = next(iter(dm.test_dataloader()))
+    qmodel(qmodel.quant(batch[0]))
+
+    tsmodel = qmodel.to_torchscript()
+    tsmodel(tsmodel.quant(batch[0]))
 
 
+@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
 def test_quantization_exceptions(tmpdir):
     """Test wrong fuse layers"""
     with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):
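
Note (not part of the patch): a minimal sketch of how the stricter qconfig validation above behaves, assuming a PyTorch build whose torch.backends.quantized.supported_engines includes 'fbgemm' and a Lightning version that exports QuantizationAwareTraining from pytorch_lightning.callbacks:

    import torch
    from pytorch_lightning.callbacks import QuantizationAwareTraining
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    # e.g. ['none', 'fbgemm'] on a typical x86 build
    print(torch.backends.quantized.supported_engines)

    # a supported engine name (or a torch.quantization.QConfig) is accepted
    QuantizationAwareTraining(qconfig='fbgemm')

    # any other string is now rejected, with the supported engines listed
    try:
        QuantizationAwareTraining(qconfig='does-not-exist')
    except MisconfigurationException as err:
        print(err)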