parent
38ad9e0764
commit
049006a59c
|
@ -128,8 +128,11 @@ class QuantizationAwareTraining(Callback):
|
|||
input_compatible: preserve quant/dequant layers. This allows to feat any input as to the original model,
|
||||
but break compatibility to torchscript
|
||||
"""
|
||||
if not isinstance(qconfig, (str, QConfig)):
|
||||
raise MisconfigurationException(f"Unsupported qconfig: f{qconfig}.")
|
||||
_valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
|
||||
if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
|
||||
raise MisconfigurationException(
|
||||
f"Unsupported qconfig: f{qconfig}.\nTry one of defaults: {torch.backends.quantized.supported_engines}"
|
||||
)
|
||||
self._qconfig = qconfig
|
||||
|
||||
if observer_type not in self.OBSERVER_TYPES:
|
||||
|
|
|
@ -30,9 +30,7 @@ def has_len(dataloader: DataLoader) -> bool:
|
|||
try:
|
||||
# try getting the length
|
||||
if len(dataloader) == 0:
|
||||
raise ValueError(
|
||||
'`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch'
|
||||
)
|
||||
raise ValueError('`Dataloader` returned 0 length. Please make sure that it returns at least 1 batch')
|
||||
has_len = True
|
||||
except TypeError:
|
||||
has_len = False
|
||||
|
|
|
@ -52,7 +52,7 @@ _IS_WINDOWS = platform.system() == "Windows"
|
|||
_TORCH_LOWER_EQUAL_1_4 = _compare_version("torch", operator.le, "1.5.0")
|
||||
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
|
||||
_TORCH_GREATER_EQUAL_1_7 = _compare_version("torch", operator.ge, "1.7.0")
|
||||
_TORCH_QUANTIZE_AVAILABLE = _module_available('torch.ops.quantized')
|
||||
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != 'none'])
|
||||
_APEX_AVAILABLE = _module_available("apex.amp")
|
||||
_BOLTS_AVAILABLE = _module_available('pl_bolts')
|
||||
_DEEPSPEED_AVAILABLE = not _IS_WINDOWS and _module_available('deepspeed')
|
||||
|
|
|
@ -35,9 +35,12 @@ RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
|
|||
if not os.path.isdir(_TEMP_PATH):
|
||||
os.mkdir(_TEMP_PATH)
|
||||
|
||||
_MISS_QUANT_DEFAULT = 'fbgemm' not in torch.backends.quantized.supported_engines
|
||||
|
||||
_SKIPIF_ARGS_PT_LE_1_4 = dict(condition=_TORCH_LOWER_EQUAL_1_4, reason="test pytorch > 1.4")
|
||||
_SKIPIF_ARGS_NO_GPU = dict(condition=not torch.cuda.is_available(), reason="test requires single-GPU machine")
|
||||
_SKIPIF_ARGS_NO_GPUS = dict(condition=torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
_SKIPIF_ARGS_NO_PT_QUANT = dict(
|
||||
condition=not _TORCH_QUANTIZE_AVAILABLE, reason="PyTorch quantization is needed for this test"
|
||||
condition=not _TORCH_QUANTIZE_AVAILABLE or _MISS_QUANT_DEFAULT,
|
||||
reason="PyTorch quantization is needed for this test"
|
||||
)
|
||||
|
|
|
@ -73,9 +73,14 @@ def test_quantize_torchscript(tmpdir):
|
|||
trainer = Trainer(callbacks=[qcb], default_root_dir=tmpdir, max_epochs=1)
|
||||
trainer.fit(qmodel, datamodule=dm)
|
||||
|
||||
qmodel.to_torchscript()
|
||||
batch = iter(dm.test_dataloader()).next()
|
||||
qmodel(qmodel.quant(batch[0]))
|
||||
|
||||
tsmodel = qmodel.to_torchscript()
|
||||
tsmodel(tsmodel.quant(batch[0]))
|
||||
|
||||
|
||||
@pytest.mark.skipif(**_SKIPIF_ARGS_NO_PT_QUANT)
|
||||
def test_quantization_exceptions(tmpdir):
|
||||
"""Test wrong fuse layers"""
|
||||
with pytest.raises(MisconfigurationException, match='Unsupported qconfig'):
|
||||
|
|
Loading…
Reference in New Issue