Docs/fixes (#5914)

* wip

* ..

* ...

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

commit 9475c845cb (parent e8190e8848)
Author: Jirka Borovec
Date:   2021-02-11 11:22:07 +01:00 (committed by GitHub)
6 changed files with 22 additions and 22 deletions

File 1 of 6

@@ -3,7 +3,7 @@
 # to imitate SLURM set only single node
 export SLURM_LOCALID=0
 # assume you have installed need packages
-export SPHINX_MOCK_REQUIREMENTS=0
+export SPHINX_MOCK_REQUIREMENTS=1
 
 clean:
 	# clean all temp runs
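
Context note: with SPHINX_MOCK_REQUIREMENTS=1, the docs build mocks heavy optional
dependencies instead of importing them for real. A minimal sketch of how a Sphinx
conf.py can honor such a flag; the package list below is illustrative, not the
project's actual configuration:

    import os

    # read the flag exported by the Makefile above; default to mocking
    MOCK_REQUIREMENTS = int(os.environ.get("SPHINX_MOCK_REQUIREMENTS", 1))

    # autodoc_mock_imports is a standard sphinx.ext.autodoc option: listed
    # modules are replaced by mocks, so they need not be installed to build docs
    autodoc_mock_imports = ["torchvision", "horovod.torch"] if MOCK_REQUIREMENTS else []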

File 2 of 6 (file deleted)

@@ -1,3 +0,0 @@
-rm -rf source/generated
-make clean
-make html --debug --jobs 2 SPHINXOPTS="-W"

File 3 of 6

@@ -23,10 +23,10 @@ from torch import nn
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6_0, rank_zero_warn
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
-if _TORCH_GREATER_EQUAL_1_6_0:
+if _TORCH_GREATER_EQUAL_1_6:
     from torch.optim.swa_utils import SWALR
 
 _AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
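
Side note on the guarded import above: torch.optim.swa_utils (SWALR, AveragedModel)
first shipped in PyTorch 1.6, hence the _TORCH_GREATER_EQUAL_1_6 gate. The _AVG_FN
alias describes the averaging callback that AveragedModel accepts; a sketch of the
default running-mean behaviour (reconstructed here, not part of this diff):

    import torch

    def avg_fn(
        averaged_param: torch.Tensor,
        model_param: torch.Tensor,
        num_averaged: torch.LongTensor,
    ) -> torch.FloatTensor:
        # incremental mean: new_avg = avg + (x - avg) / (n + 1)
        return averaged_param + (model_param - averaged_param) / (num_averaged + 1)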

File 4 of 6

@@ -36,7 +36,7 @@ from pytorch_lightning.utilities.imports import (  # noqa: F401
     _NATIVE_AMP_AVAILABLE,
     _OMEGACONF_AVAILABLE,
     _RPC_AVAILABLE,
-    _TORCH_GREATER_EQUAL_1_6_0,
+    _TORCH_GREATER_EQUAL_1_6,
     _TORCHTEXT_AVAILABLE,
     _TORCHVISION_AVAILABLE,
     _XLA_AVAILABLE,

File 5 of 6

@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """General utilities"""
+import operator
 import platform
 from distutils.version import LooseVersion
 from importlib.util import find_spec
 
-import pkg_resources
 import torch
+from pkg_resources import DistributionNotFound, get_distribution
 
 
 def _module_available(module_path: str) -> bool:
@@ -39,19 +40,21 @@ def _module_available(module_path: str) -> bool:
         return False
 
 
-def _get_version(package: str) -> LooseVersion:
-    return LooseVersion(pkg_resources.get_distribution(package).version)
+def _compare_version(package: str, op, version) -> bool:
+    try:
+        pkg_version = LooseVersion(get_distribution(package).version)
+        return op(pkg_version, LooseVersion(version))
+    except DistributionNotFound:
+        return False
 
 
 _IS_WINDOWS = platform.system() == "Windows"
-_TORCH_GREATER_EQUAL_1_6_0 = _get_version("torch") >= LooseVersion("1.6.0")
+_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
 _APEX_AVAILABLE = _module_available("apex.amp")
 _BOLTS_AVAILABLE = _module_available('pl_bolts')
 _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel')
-_FAIRSCALE_PIPE_AVAILABLE = (
-    _FAIRSCALE_AVAILABLE and _TORCH_GREATER_EQUAL_1_6_0 and _get_version('fairscale') <= LooseVersion("0.1.3")
-)
+_FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3")
 _GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')
 _HOROVOD_AVAILABLE = _module_available("horovod.torch")
 _HYDRA_AVAILABLE = _module_available("hydra")
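
The refactor above is behavioural as well as cosmetic: the old _get_version helper let
pkg_resources raise DistributionNotFound when a package was absent, while _compare_version
catches it and returns False, so the availability flags degrade gracefully. A quick
sketch of the resulting behaviour (assuming the definitions above):

    import operator

    _compare_version("torch", operator.ge, "1.6.0")      # True on torch >= 1.6.0
    _compare_version("fairscale", operator.le, "0.1.3")  # False if fairscale is missing,
                                                         # rather than raising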

File 6 of 6

@@ -21,11 +21,11 @@ from torch import nn
 from torch.utils.data import DataLoader
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6_0
+from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel, RandomDataset
 
-if _TORCH_GREATER_EQUAL_1_6_0:
+if _TORCH_GREATER_EQUAL_1_6:
     from pytorch_lightning.callbacks import StochasticWeightAveraging
 
 
 class SwaTestModel(BoringModel):
@@ -114,7 +114,7 @@ def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_proc
     assert trainer.get_model() == model
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(
     not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
@@ -123,31 +123,31 @@ def test_swa_callback_ddp(tmpdir):
     train_with_swa(tmpdir, accelerator="ddp", gpus=2)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_swa_callback_ddp_spawn(tmpdir):
     train_with_swa(tmpdir, accelerator="ddp_spawn", gpus=2)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
 @pytest.mark.skipif(platform.system() == "Windows", reason="ddp_cpu is not available on Windows")
 def test_swa_callback_ddp_cpu(tmpdir):
     train_with_swa(tmpdir, accelerator="ddp_cpu", num_processes=2)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU machine")
 def test_swa_callback_1_gpu(tmpdir):
     train_with_swa(tmpdir, gpus=1)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
 @pytest.mark.parametrize("batchnorm", (True, False))
 def test_swa_callback(tmpdir, batchnorm):
     train_with_swa(tmpdir, batchnorm=batchnorm)
 
 
-@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
+@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
 def test_swa_raises():
     with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
         StochasticWeightAveraging(swa_epoch_start=0, swa_lrs=0.1)
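
The same torch-version skipif is repeated on every test above. A possible follow-up
(a sketch, not part of this commit) is to hoist it into a shared pytest marker, here
under the hypothetical name needs_torch_1_6:

    import pytest

    from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6

    needs_torch_1_6 = pytest.mark.skipif(
        not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0"
    )

    @needs_torch_1_6
    def test_swa_import():
        # the callback only imports cleanly once the torch-version gate passes
        from pytorch_lightning.callbacks import StochasticWeightAveraging
        assert StochasticWeightAveraging is not None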