refactor - check E501 (#5200)
This commit is contained in:
parent
6d2c564bc6
commit
35fd6e93c7
|
@ -84,7 +84,8 @@ class ExternalSourcePipeline(Pipeline):
|
|||
|
||||
class DALIClassificationLoader(DALIClassificationIterator):
|
||||
"""
|
||||
This class extends DALI's original DALIClassificationIterator with the __len__() function so that we can call len() on it
|
||||
This class extends DALI's original `DALIClassificationIterator` with the `__len__()` function
|
||||
so that we can call `len()` on it
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -1188,7 +1188,8 @@ class LightningModule(
|
|||
By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
|
||||
once per optimizer.
|
||||
|
||||
.. tip:: With `Trainer(enable_pl_optimizer=True)`, you can user `optimizer.step()` directly and it will handle zero_grad, accumulated gradients, AMP, TPU and more automatically for you.
|
||||
.. tip:: With `Trainer(enable_pl_optimizer=True)`, you can user `optimizer.step()` directly
|
||||
and it will handle zero_grad, accumulated gradients, AMP, TPU and more automatically for you.
|
||||
|
||||
Warning:
|
||||
If you are overriding this method, make sure that you pass the ``optimizer_closure`` parameter
|
||||
|
|
|
@ -51,7 +51,8 @@ class LightningOptimizer:
|
|||
|
||||
# For Horovod
|
||||
if hasattr(optimizer, "skip_synchronize"):
|
||||
self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__.__bases__[0]), {})
|
||||
self.__class__ = type("Lightning" + optimizer.__class__.__name__,
|
||||
(self.__class__, optimizer.__class__.__bases__[0]), {})
|
||||
self.skip_synchronize = optimizer.skip_synchronize
|
||||
self.synchronize = optimizer.synchronize
|
||||
else:
|
||||
|
|
|
@ -174,7 +174,8 @@ class ModelIO(object):
|
|||
cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key))
|
||||
|
||||
# 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
|
||||
cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
|
||||
cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded,
|
||||
checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
|
||||
|
||||
# 4. Update cls_kwargs_new with cls_kwargs_old, such that new has higher priority
|
||||
args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
|
||||
|
|
|
@ -127,7 +127,8 @@ def _check_num_classes_binary(num_classes: int, is_multiclass: bool):
|
|||
if num_classes == 1 and is_multiclass:
|
||||
raise ValueError(
|
||||
"You have binary data and have set `is_multiclass=True`, but `num_classes` is 1."
|
||||
" Either set `is_multiclass=None`(default) or set `num_classes=2` to transform binary data to multi-class format."
|
||||
" Either set `is_multiclass=None`(default) or set `num_classes=2`"
|
||||
" to transform binary data to multi-class format."
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -73,11 +73,12 @@ class PrecisionRecallCurve(Metric):
|
|||
>>> target = torch.tensor([0, 1, 3, 2])
|
||||
>>> pr_curve = PrecisionRecallCurve(num_classes=5)
|
||||
>>> precision, recall, thresholds = pr_curve(pred, target)
|
||||
>>> precision
|
||||
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
|
||||
>>> precision # doctest: +NORMALIZE_WHITESPACE
|
||||
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
|
||||
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
|
||||
>>> recall
|
||||
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
|
||||
>>> thresholds # doctest: +NORMALIZE_WHITESPACE
|
||||
>>> thresholds
|
||||
[tensor([0.7500]), tensor([0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500, 0.7500]), tensor([0.0500])]
|
||||
|
||||
"""
|
||||
|
|
|
@ -18,7 +18,10 @@ import torch
|
|||
|
||||
from pytorch_lightning.metrics.functional.average_precision import average_precision as __ap
|
||||
from pytorch_lightning.metrics.functional.f_beta import fbeta as __fb, f1 as __f1
|
||||
from pytorch_lightning.metrics.functional.precision_recall_curve import _binary_clf_curve, precision_recall_curve as __prc
|
||||
from pytorch_lightning.metrics.functional.precision_recall_curve import (
|
||||
_binary_clf_curve,
|
||||
precision_recall_curve as __prc
|
||||
)
|
||||
from pytorch_lightning.metrics.functional.roc import roc as __roc
|
||||
from pytorch_lightning.metrics.utils import (
|
||||
to_categorical as __tc,
|
||||
|
@ -821,7 +824,8 @@ def precision_recall_curve(
|
|||
"""
|
||||
Computes precision-recall pairs for different thresholds.
|
||||
|
||||
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
|
||||
.. warning :: Deprecated in favor of
|
||||
:func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
|
||||
"""
|
||||
rank_zero_warn(
|
||||
"This `precision_recall_curve` was deprecated in v1.1.0 in favor of"
|
||||
|
@ -841,7 +845,8 @@ def multiclass_precision_recall_curve(
|
|||
"""
|
||||
Computes precision-recall pairs for different thresholds given a multiclass scores.
|
||||
|
||||
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
|
||||
.. warning :: Deprecated in favor of
|
||||
:func:`~pytorch_lightning.metrics.functional.precision_recall_curve.precision_recall_curve`
|
||||
"""
|
||||
rank_zero_warn(
|
||||
"This `multiclass_precision_recall_curve` was deprecated in v1.1.0 in favor of"
|
||||
|
@ -863,7 +868,8 @@ def average_precision(
|
|||
"""
|
||||
Compute average precision from prediction scores.
|
||||
|
||||
.. warning :: Deprecated in favor of :func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
|
||||
.. warning :: Deprecated in favor of
|
||||
:func:`~pytorch_lightning.metrics.functional.average_precision.average_precision`
|
||||
"""
|
||||
rank_zero_warn(
|
||||
"This `average_precision` was deprecated in v1.1.0 in favor of"
|
||||
|
|
|
@ -208,8 +208,9 @@ def precision_recall_curve(
|
|||
... [0.05, 0.05, 0.05, 0.75, 0.05]])
|
||||
>>> target = torch.tensor([0, 1, 3, 2])
|
||||
>>> precision, recall, thresholds = precision_recall_curve(pred, target, num_classes=5)
|
||||
>>> precision
|
||||
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]), tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
|
||||
>>> precision # doctest: +NORMALIZE_WHITESPACE
|
||||
[tensor([1., 1.]), tensor([1., 1.]), tensor([0.2500, 0.0000, 1.0000]),
|
||||
tensor([0.2500, 0.0000, 1.0000]), tensor([0., 1.])]
|
||||
>>> recall
|
||||
[tensor([1., 0.]), tensor([1., 0.]), tensor([1., 0., 0.]), tensor([1., 0., 0.]), tensor([nan, 0.])]
|
||||
>>> thresholds
|
||||
|
|
|
@ -172,7 +172,8 @@ def _load_long_description(path_dir: str) -> str:
|
|||
|
||||
# readthedocs badge
|
||||
text = text.replace('badge/?version=stable', f'badge/?version={__version__}')
|
||||
text = text.replace('pytorch-lightning.readthedocs.io/en/stable/', f'pytorch-lightning.readthedocs.io/en/{__version__}')
|
||||
text = text.replace('pytorch-lightning.readthedocs.io/en/stable/',
|
||||
f'pytorch-lightning.readthedocs.io/en/{__version__}')
|
||||
# codecov badge
|
||||
text = text.replace('/branch/master/graph/badge.svg', f'/release/{__version__}/graph/badge.svg')
|
||||
# replace github badges for release ones
|
||||
|
|
|
@ -282,7 +282,9 @@ class CheckpointConnector:
|
|||
checkpoint['lr_schedulers'] = lr_schedulers
|
||||
|
||||
# dump amp scaling
|
||||
if self.trainer.amp_backend == AMPType.NATIVE and not self.trainer.use_tpu and self.trainer.scaler is not None:
|
||||
if (self.trainer.amp_backend == AMPType.NATIVE
|
||||
and not self.trainer.use_tpu
|
||||
and self.trainer.scaler is not None):
|
||||
checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
|
||||
elif self.trainer.amp_backend == AMPType.APEX:
|
||||
checkpoint['amp_scaling_state'] = amp.state_dict()
|
||||
|
|
|
@ -168,8 +168,9 @@ class LoggerConnector:
|
|||
metrics (dict): Metric values
|
||||
grad_norm_dic (dict): Gradient norms
|
||||
step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
|
||||
log_train_step_metrics (bool): Used to track if log_metrics function is being called in during training steps.
|
||||
In training steps, we will log metrics on step: total_nb_idx (for accumulated gradients) and global_step for the rest.
|
||||
log_train_step_metrics (bool): Used to track if `log_metrics` function is being called in during training
|
||||
steps. In training steps, we will log metrics on step: `total_nb_idx` (for accumulated gradients)
|
||||
and global_step for the rest.
|
||||
"""
|
||||
# add gpu memory
|
||||
if self.trainer.on_gpu and self.trainer.log_gpu_memory:
|
||||
|
|
|
@ -125,7 +125,8 @@ class TrainerOptimizersMixin(ABC):
|
|||
if monitor is None:
|
||||
raise MisconfigurationException(
|
||||
'`configure_optimizers` must include a monitor when a `ReduceLROnPlateau` scheduler is used.'
|
||||
' For example: {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
|
||||
' For example:'
|
||||
' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
|
||||
)
|
||||
lr_schedulers.append(
|
||||
{**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor}
|
||||
|
|
13
setup.cfg
13
setup.cfg
|
@ -65,13 +65,12 @@ verbose = 2
|
|||
# https://pep8.readthedocs.io/en/latest/intro.html#error-codes
|
||||
format = pylint
|
||||
ignore =
|
||||
E731
|
||||
W504
|
||||
F401
|
||||
E203 # E203 - whitespace before ':'. Opposite convention enforced by black
|
||||
E231 # E231: missing whitespace after ',', ';', or ':'; for black
|
||||
E501 # E501 - line too long. Handled by black, we have longer lines
|
||||
W503 # W503 - line break before binary operator, need for black
|
||||
E731 # do not assign a lambda expression, use a def
|
||||
W504 # line break occurred after a binary operator
|
||||
F401 # module imported but unused
|
||||
E203 # whitespace before ':'. Opposite convention enforced by black
|
||||
E231 # missing whitespace after ',', ';', or ':'; for black
|
||||
W503 # line break before binary operator, need for black
|
||||
|
||||
# setup.cfg or tox.ini
|
||||
[check-manifest]
|
||||
|
|
|
@ -73,7 +73,8 @@ class Discriminator(nn.Module):
|
|||
class BasicGAN(LightningModule):
|
||||
"""Implements a basic GAN for the purpose of illustrating multiple optimizers."""
|
||||
|
||||
def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001, b1: float = 0.5, b2: float = 0.999, **kwargs):
|
||||
def __init__(self, hidden_dim: int = 128, learning_rate: float = 0.001,
|
||||
b1: float = 0.5, b2: float = 0.999, **kwargs):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.learning_rate = learning_rate
|
||||
|
|
|
@ -213,11 +213,14 @@ def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze, mi
|
|||
IF `min_steps` was set to a higher value than the `trainer.global_step` when `early_stopping` is being triggered,
|
||||
THEN the trainer should continue until reaching `trainer.global_step` == `min_steps`, and stop.
|
||||
|
||||
IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step` when `early_stopping` is being triggered,
|
||||
THEN the trainer should continue until reaching `trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
|
||||
IF `min_epochs` resulted in a higher number of steps than the `trainer.global_step`
|
||||
when `early_stopping` is being triggered,
|
||||
THEN the trainer should continue until reaching
|
||||
`trainer.global_step` == `min_epochs * len(train_dataloader)`, and stop.
|
||||
This test validate this expected behaviour
|
||||
|
||||
IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step` when `early_stopping` is being triggered,
|
||||
IF both `min_epochs` and `min_steps` are provided and higher than the `trainer.global_step`
|
||||
when `early_stopping` is being triggered,
|
||||
THEN the highest between `min_epochs * len(train_dataloader)` and `min_steps` would be reached.
|
||||
|
||||
Caviat: IF min_steps is divisible by len(train_dataloader), then it will do min_steps + len(train_dataloader)
|
||||
|
|
|
@ -269,7 +269,12 @@ def test_model_checkpoint_file_extension(tmpdir):
|
|||
"""
|
||||
|
||||
model = LogInTwoMethods()
|
||||
model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True)
|
||||
model_checkpoint = ModelCheckpointExtensionTest(
|
||||
monitor='early_stop_on',
|
||||
dirpath=tmpdir,
|
||||
save_top_k=1,
|
||||
save_last=True,
|
||||
)
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
callbacks=[model_checkpoint],
|
||||
|
|
|
@ -87,7 +87,8 @@ def test_result_reduce_ddp(result_cls):
|
|||
7, True, 0, id='write_dict_predictions'
|
||||
),
|
||||
pytest.param(
|
||||
0, True, 1, id='full_loop_single_gpu', marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires single-GPU machine")
|
||||
0, True, 1, id='full_loop_single_gpu',
|
||||
marks=pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires single-GPU machine")
|
||||
)
|
||||
]
|
||||
)
|
||||
|
|
|
@ -86,23 +86,14 @@ def _sk_fbeta_multidim_multiclass(preds, target, average='micro', beta=1.0):
|
|||
(_binary_inputs.preds, _binary_inputs.target, _sk_fbeta_binary, 1, False),
|
||||
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_fbeta_multilabel_prob, NUM_CLASSES, True),
|
||||
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
|
||||
(_multilabel_inputs_no_match.preds, _multilabel_inputs_no_match.target, _sk_fbeta_multilabel, NUM_CLASSES, True),
|
||||
(_multilabel_inputs_no_match.preds, _multilabel_inputs_no_match.target,
|
||||
_sk_fbeta_multilabel, NUM_CLASSES, True),
|
||||
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_fbeta_multiclass_prob, NUM_CLASSES, False),
|
||||
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_fbeta_multiclass, NUM_CLASSES, False),
|
||||
(
|
||||
_multidim_multiclass_prob_inputs.preds,
|
||||
_multidim_multiclass_prob_inputs.target,
|
||||
_sk_fbeta_multidim_multiclass_prob,
|
||||
NUM_CLASSES,
|
||||
False,
|
||||
),
|
||||
(
|
||||
_multidim_multiclass_inputs.preds,
|
||||
_multidim_multiclass_inputs.target,
|
||||
_sk_fbeta_multidim_multiclass,
|
||||
NUM_CLASSES,
|
||||
False,
|
||||
),
|
||||
(_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target,
|
||||
_sk_fbeta_multidim_multiclass_prob, NUM_CLASSES, False),
|
||||
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target,
|
||||
_sk_fbeta_multidim_multiclass, NUM_CLASSES, False),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("average", ['micro', 'macro', 'weighted', None])
|
||||
|
|
|
@ -85,24 +85,16 @@ def _sk_prec_recall_multidim_multiclass(preds, target, sk_fn=precision_score, av
|
|||
[
|
||||
(_binary_prob_inputs.preds, _binary_prob_inputs.target, _sk_prec_recall_binary_prob, 1, False),
|
||||
(_binary_inputs.preds, _binary_inputs.target, _sk_prec_recall_binary, 1, False),
|
||||
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target, _sk_prec_recall_multilabel_prob, NUM_CLASSES, True),
|
||||
(_multilabel_prob_inputs.preds, _multilabel_prob_inputs.target,
|
||||
_sk_prec_recall_multilabel_prob, NUM_CLASSES, True),
|
||||
(_multilabel_inputs.preds, _multilabel_inputs.target, _sk_prec_recall_multilabel, NUM_CLASSES, True),
|
||||
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target, _sk_prec_recall_multiclass_prob, NUM_CLASSES, False),
|
||||
(_multiclass_prob_inputs.preds, _multiclass_prob_inputs.target,
|
||||
_sk_prec_recall_multiclass_prob, NUM_CLASSES, False),
|
||||
(_multiclass_inputs.preds, _multiclass_inputs.target, _sk_prec_recall_multiclass, NUM_CLASSES, False),
|
||||
(
|
||||
_multidim_multiclass_prob_inputs.preds,
|
||||
_multidim_multiclass_prob_inputs.target,
|
||||
_sk_prec_recall_multidim_multiclass_prob,
|
||||
NUM_CLASSES,
|
||||
False,
|
||||
),
|
||||
(
|
||||
_multidim_multiclass_inputs.preds,
|
||||
_multidim_multiclass_inputs.target,
|
||||
_sk_prec_recall_multidim_multiclass,
|
||||
NUM_CLASSES,
|
||||
False,
|
||||
),
|
||||
(_multidim_multiclass_prob_inputs.preds, _multidim_multiclass_prob_inputs.target,
|
||||
_sk_prec_recall_multidim_multiclass_prob, NUM_CLASSES, False),
|
||||
(_multidim_multiclass_inputs.preds, _multidim_multiclass_inputs.target,
|
||||
_sk_prec_recall_multidim_multiclass, NUM_CLASSES, False),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
@ -153,7 +153,8 @@ def test_stat_scores(pred, target, expected_tp, expected_fp, expected_tn, expect
|
|||
pytest.param(to_onehot(torch.tensor([0., 2., 4., 4.])), torch.tensor([0., 4., 3., 4.]), 'elementwise_mean',
|
||||
torch.tensor(0.4), torch.tensor(0.4), torch.tensor(2.8), torch.tensor(0.4), torch.tensor(0.8))
|
||||
])
|
||||
def test_stat_scores_multiclass(pred, target, reduction, expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
|
||||
def test_stat_scores_multiclass(pred, target, reduction,
|
||||
expected_tp, expected_fp, expected_tn, expected_fn, expected_support):
|
||||
tp, fp, tn, fn, sup = stat_scores_multiple_classes(pred, target, reduction=reduction)
|
||||
|
||||
assert torch.allclose(torch.tensor(expected_tp).to(tp), tp)
|
||||
|
|
|
@ -520,7 +520,8 @@ def test_log_works_in_train_callback(tmpdir):
|
|||
def make_logging(self, pl_module: pl.LightningModule, func_name, func_idx,
|
||||
on_steps=[], on_epochs=[], prob_bars=[]):
|
||||
self.funcs_called_count[func_name] += 1
|
||||
for idx, (on_step, on_epoch, prog_bar) in enumerate(list(itertools.product(*[on_steps, on_epochs, prob_bars]))):
|
||||
iterate = list(itertools.product(*[on_steps, on_epochs, prob_bars]))
|
||||
for idx, (on_step, on_epoch, prog_bar) in enumerate(iterate):
|
||||
# run logging
|
||||
custom_func_name = f"{func_idx}_{idx}_{func_name}"
|
||||
pl_module.log(custom_func_name, self.count * func_idx, on_step=on_step,
|
||||
|
|
|
@ -942,7 +942,8 @@ def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, m
|
|||
@patch("torch.optim.Adam.step")
|
||||
@patch("torch.optim.SGD.step")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest")
|
||||
@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
|
||||
reason="test should be run outside of pytest")
|
||||
def test_step_with_optimizer_closure_with_different_frequencies_ddp(mock_sgd_step, mock_adam_step, tmpdir):
|
||||
"""
|
||||
Tests that `step` works with optimizer_closure and different accumulated_gradient frequency
|
||||
|
|
|
@ -316,9 +316,11 @@ def test_gradient_accumulation_scheduling_last_batch(tmpdir, accumulate_grad_bat
|
|||
self.on_train_batch_start_end_dict = self.state_dict()
|
||||
for key in self.on_train_batch_start_end_dict.keys():
|
||||
if (batch_idx + 1) == self.trainer.num_training_batches:
|
||||
assert torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
|
||||
assert torch.equal(self.on_train_batch_start_state_dict[key],
|
||||
self.on_train_batch_start_end_dict[key])
|
||||
else:
|
||||
assert not torch.equal(self.on_train_batch_start_state_dict[key], self.on_train_batch_start_end_dict[key])
|
||||
assert not torch.equal(self.on_train_batch_start_state_dict[key],
|
||||
self.on_train_batch_start_end_dict[key])
|
||||
|
||||
model = CurrentModel()
|
||||
|
||||
|
|
|
@ -23,7 +23,8 @@ if _XLA_AVAILABLE:
|
|||
import torch_xla.core.xla_model as xm
|
||||
|
||||
|
||||
# lets hope that in or env we have installed XLA only for TPU devices, otherwise, it is testing in the cycle "if I am true test that I am true :D"
|
||||
# lets hope that in or env we have installed XLA only for TPU devices, otherwise,
|
||||
# it is testing in the cycle "if I am true test that I am true :D"
|
||||
@pytest.mark.skipif(_XLA_AVAILABLE, reason="test requires torch_xla to be absent")
|
||||
def test_tpu_device_absence():
|
||||
"""Check tpu_device_exists returns None when torch_xla is not available"""
|
||||
|
|
Loading…
Reference in New Issue