Docs/fixes (#5914)
* wip * .. * ... * Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
This commit is contained in:
parent
e8190e8848
commit
9475c845cb
2
Makefile
2
Makefile
|
@ -3,7 +3,7 @@
|
|||
# to imitate SLURM set only single node
|
||||
export SLURM_LOCALID=0
|
||||
# assume you have installed need packages
|
||||
export SPHINX_MOCK_REQUIREMENTS=0
|
||||
export SPHINX_MOCK_REQUIREMENTS=1
|
||||
|
||||
clean:
|
||||
# clean all temp runs
|
||||
|
|
|
@ -1,3 +0,0 @@
|
|||
rm -rf source/generated
|
||||
make clean
|
||||
make html --debug --jobs 2 SPHINXOPTS="-W"
|
|
@ -23,10 +23,10 @@ from torch import nn
|
|||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6_0, rank_zero_warn
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6, rank_zero_warn
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
if _TORCH_GREATER_EQUAL_1_6_0:
|
||||
if _TORCH_GREATER_EQUAL_1_6:
|
||||
from torch.optim.swa_utils import SWALR
|
||||
|
||||
_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
|
||||
|
|
|
@ -36,7 +36,7 @@ from pytorch_lightning.utilities.imports import ( # noqa: F401
|
|||
_NATIVE_AMP_AVAILABLE,
|
||||
_OMEGACONF_AVAILABLE,
|
||||
_RPC_AVAILABLE,
|
||||
_TORCH_GREATER_EQUAL_1_6_0,
|
||||
_TORCH_GREATER_EQUAL_1_6,
|
||||
_TORCHTEXT_AVAILABLE,
|
||||
_TORCHVISION_AVAILABLE,
|
||||
_XLA_AVAILABLE,
|
||||
|
|
|
@ -12,12 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""General utilities"""
|
||||
import operator
|
||||
import platform
|
||||
from distutils.version import LooseVersion
|
||||
from importlib.util import find_spec
|
||||
|
||||
import pkg_resources
|
||||
import torch
|
||||
from pkg_resources import DistributionNotFound, get_distribution
|
||||
|
||||
|
||||
def _module_available(module_path: str) -> bool:
|
||||
|
@ -39,19 +40,21 @@ def _module_available(module_path: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _get_version(package: str) -> LooseVersion:
|
||||
return LooseVersion(pkg_resources.get_distribution(package).version)
|
||||
def _compare_version(package: str, op, version) -> bool:
|
||||
try:
|
||||
pkg_version = LooseVersion(get_distribution(package).version)
|
||||
return op(pkg_version, LooseVersion(version))
|
||||
except DistributionNotFound:
|
||||
return False
|
||||
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
_TORCH_GREATER_EQUAL_1_6_0 = _get_version("torch") >= LooseVersion("1.6.0")
|
||||
_TORCH_GREATER_EQUAL_1_6 = _compare_version("torch", operator.ge, "1.6.0")
|
||||
|
||||
_APEX_AVAILABLE = _module_available("apex.amp")
|
||||
_BOLTS_AVAILABLE = _module_available('pl_bolts')
|
||||
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel')
|
||||
_FAIRSCALE_PIPE_AVAILABLE = (
|
||||
_FAIRSCALE_AVAILABLE and _TORCH_GREATER_EQUAL_1_6_0 and _get_version('fairscale') <= LooseVersion("0.1.3")
|
||||
)
|
||||
_FAIRSCALE_PIPE_AVAILABLE = _TORCH_GREATER_EQUAL_1_6 and _compare_version("fairscale", operator.le, "0.1.3")
|
||||
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')
|
||||
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
|
||||
_HYDRA_AVAILABLE = _module_available("hydra")
|
||||
|
|
|
@ -21,11 +21,11 @@ from torch import nn
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6_0
|
||||
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_6
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers import BoringModel, RandomDataset
|
||||
|
||||
if _TORCH_GREATER_EQUAL_1_6_0:
|
||||
if _TORCH_GREATER_EQUAL_1_6:
|
||||
from pytorch_lightning.callbacks import StochasticWeightAveraging
|
||||
|
||||
class SwaTestModel(BoringModel):
|
||||
|
@ -114,7 +114,7 @@ def train_with_swa(tmpdir, batchnorm=True, accelerator=None, gpus=None, num_proc
|
|||
assert trainer.get_model() == model
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
@pytest.mark.skipif(
|
||||
not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1', reason="test should be run outside of pytest"
|
||||
|
@ -123,31 +123,31 @@ def test_swa_callback_ddp(tmpdir):
|
|||
train_with_swa(tmpdir, accelerator="ddp", gpus=2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
def test_swa_callback_ddp_spawn(tmpdir):
|
||||
train_with_swa(tmpdir, accelerator="ddp_spawn", gpus=2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="ddp_cpu is not available on Windows")
|
||||
def test_swa_callback_ddp_cpu(tmpdir):
|
||||
train_with_swa(tmpdir, accelerator="ddp_cpu", num_processes=2)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires a GPU machine")
|
||||
def test_swa_callback_1_gpu(tmpdir):
|
||||
train_with_swa(tmpdir, gpus=1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.parametrize("batchnorm", (True, False))
|
||||
def test_swa_callback(tmpdir, batchnorm):
|
||||
train_with_swa(tmpdir, batchnorm=batchnorm)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6_0, reason="SWA available from PyTorch 1.6.0")
|
||||
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_6, reason="SWA available from PyTorch 1.6.0")
|
||||
def test_swa_raises():
|
||||
with pytest.raises(MisconfigurationException, match=">0 integer or a float between 0 and 1"):
|
||||
StochasticWeightAveraging(swa_epoch_start=0, swa_lrs=0.1)
|
||||
|
|
Loading…
Reference in New Issue