fix/test quant (#6040)

* fix/test quant

* ...

* ---
Jirka Borovec 2021-02-18 11:47:29 +01:00 committed by GitHub
parent 38ad9e0764
commit 049006a59c
5 changed files with 17 additions and 8 deletions


@@ -128,8 +128,11 @@ class QuantizationAwareTraining(Callback):
             input_compatible: preserve quant/dequant layers. This allows feeding any input as to the original
                 model, but breaks compatibility with TorchScript.
         """
-        if not isinstance(qconfig, (str, QConfig)):
-            raise MisconfigurationException(f"Unsupported qconfig: f{qconfig}.")
+        _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
+        if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
+            raise MisconfigurationException(
+                f"Unsupported qconfig: {qconfig}.\nTry one of defaults: {torch.backends.quantized.supported_engines}"
+            )
         self._qconfig = qconfig
         if observer_type not in self.OBSERVER_TYPES:
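
Note: the tightened check means a string qconfig must name an engine compiled into the running PyTorch build. A standalone sketch of the accepted values (not Lightning code; engine lists vary by platform):

import torch
from torch.quantization import QConfig

# The engine list depends on the build; x86 wheels typically report ['none', 'fbgemm'].
print(torch.backends.quantized.supported_engines)

qconfig = 'fbgemm'  # hypothetical user input
_valid = isinstance(qconfig, QConfig) or (
    isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
)
print(_valid)  # a typo such as 'fbgem' now fails validation up front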


@@ -30,9 +30,7 @@ def has_len(dataloader: DataLoader) -> bool:
     try:
         # try getting the length
         if len(dataloader) == 0:
-            raise ValueError(
-                '`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch'
-            )
+            raise ValueError('`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch')
         has_len = True
     except TypeError:
         has_len = False
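
For reference, `has_len` leans on `len(dataloader)` raising `TypeError` for iterable-style datasets; a minimal runnable sketch of that behaviour:

from torch.utils.data import DataLoader, IterableDataset

class Stream(IterableDataset):
    def __iter__(self):
        yield from range(3)

loader = DataLoader(Stream())
try:
    _ = len(loader)   # an IterableDataset has no __len__, so this raises
    print(True)
except TypeError:
    print(False)      # -> has_len() reports False for such loaders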


@@ -52,7 +52,7 @@ _IS_WINDOWS = platform.system() == "Windows"
 _TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
 _TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
 _TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
-_TORCH_QUANTIZE_AVAILABLE = _module_available('torch.ops.quantized')
+_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BOLTS_AVAILABLE = _module_available('pl_bolts')
 _DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
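
The old probe passed even on builds compiled without any quantization backend, since the `torch.ops.quantized` module always imports; the new flag inspects what the binary can actually execute. A quick sketch (engine names are build-dependent):

import torch

# 'none' is always listed; any other entry ('fbgemm' on x86, 'qnnpack' on ARM)
# means a real quantization backend was compiled in.
engines = [eg for eg in torch.backends.quantized.supported_engines if eg != 'none']
print(engines, bool(engines))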


@@ -35,9 +35,12 @@ RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
 if not os.path.isdir(_TEMP_PATH):
     os.mkdir(_TEMP_PATH)
+_MISS_QUANT_DEFAULT = 'fbgemm' not in torch.backends.quantized.supported_engines
 _SKIPIF_ARGS_PT_LE_1_4 = dict(condition=_TORCH_LOWER_EQUAL_1_4, reason="test pytorch > 1.4")
 _SKIPIF_ARGS_NO_GPU = dict(condition=not torch.cuda.is_available(), reason="test requires single-GPU machine")
 _SKIPIF_ARGS_NO_GPUS = dict(condition=torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 _SKIPIF_ARGS_NO_PT_QUANT = dict(
-    condition=not _TORCH_QUANTIZE_AVAILABLE, reason="PyTorch quantization is needed for this test"
+    condition=not _TORCH_QUANTIZE_AVAILABLE or _MISS_QUANT_DEFAULT,
+    reason="PyTorch quantization is needed for this test"
 )
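
These helper dicts are unpacked directly into the pytest marker, so one definition gates every quantization test; a hypothetical usage sketch with a stand-in condition:

import pytest

_SKIPIF_ARGS_NO_PT_QUANT = dict(
    condition=False,  # stand-in for: not _TORCH_QUANTIZE_AVAILABLE or _MISS_QUANT_DEFAULT
    reason="PyTorch quantization is needed for this test",
)

@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
def test_needs_quantization():
    assert True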


@@ -73,9 +73,14 @@ def test_quantize_torchscript(tmpdir):
     trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir, max_epochs=1)
     trainer.fit(qmodel, datamodule=dm)
-    qmodel.to_torchscript()
+    batch = next(iter(dm.test_dataloader()))
+    qmodel(qmodel.quant(batch[0]))
+
+    tsmodel = qmodel.to_torchscript()
+    tsmodel(tsmodel.quant(batch[0]))
+
+
+@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
+def test_quantization_exceptions(tmpdir):
+    """Test wrong fuse layers"""
+    with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):
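
The reworked test now pushes a real batch through both the fitted model and its TorchScript export, quantizing the input via the model's quant stub first. A self-contained pure-torch sketch of the mechanics being exercised (illustrative module, assumes an 'fbgemm' build; this is not Lightning's callback flow):

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized
        self.fc = nn.Linear(16, 4)
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float

    def forward(self, x):
        return self.dequant(self.fc(self.quant(x)))

model = TinyModel().eval()
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)
model(torch.randn(8, 16))                      # one calibration pass
torch.quantization.convert(model, inplace=True)
ts = torch.jit.script(model)                   # the converted model still scripts
print(ts(torch.randn(2, 16)).shape)            # torch.Size([2, 4])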