From 7e9e441843d345d0adf0dd172e760b62bf4631cd Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 9 Sep 2022 21:04:57 +0100 Subject: [PATCH] Use TorchVision's Multi-weight Support and Model Registration API on Lightning (#14567) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos Mocholí --- .../advanced/transfer_learning.rst | 2 +- examples/pl_basics/profiler_example.py | 4 ++-- .../computer_vision_fine_tuning.py | 6 +++--- examples/pl_domain_templates/imagenet.py | 8 ++++---- examples/pl_servable_module/production.py | 4 ++-- src/pytorch_lightning/utilities/imports.py | 4 ++-- .../utilities/model_helpers.py | 20 ++++++++++++++++++- .../tests_pytorch/helpers/advanced_models.py | 7 ++++--- tests/tests_pytorch/test_cli.py | 2 +- 9 files changed, 38 insertions(+), 19 deletions(-) diff --git a/docs/source-pytorch/advanced/transfer_learning.rst b/docs/source-pytorch/advanced/transfer_learning.rst index caa739bdfc..2d221cf6f7 100644 --- a/docs/source-pytorch/advanced/transfer_learning.rst +++ b/docs/source-pytorch/advanced/transfer_learning.rst @@ -62,7 +62,7 @@ Example: Imagenet (Computer Vision) super().__init__() # init a pretrained resnet - backbone = models.resnet50(pretrained=True) + backbone = models.resnet50(weights="DEFAULT") num_filters = backbone.fc.in_features layers = list(backbone.children())[:-1] self.feature_extractor = nn.Sequential(*layers) diff --git a/examples/pl_basics/profiler_example.py b/examples/pl_basics/profiler_example.py index 6df8f76997..39c147c938 100644 --- a/examples/pl_basics/profiler_example.py +++ b/examples/pl_basics/profiler_example.py @@ -27,12 +27,12 @@ from os import path import torch import torchvision -import torchvision.models as models import torchvision.transforms as T from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule from pytorch_lightning.cli import LightningCLI from pytorch_lightning.profilers.pytorch import 
PyTorchProfiler +from pytorch_lightning.utilities.model_helpers import get_torchvision_model DEFAULT_CMD_LINE = ( "fit", @@ -49,7 +49,7 @@ DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") class ModelToProfile(LightningModule): def __init__(self, name: str = "resnet18", automatic_optimization: bool = True): super().__init__() - self.model = getattr(models, name)(pretrained=True) + self.model = get_torchvision_model(name, weights="DEFAULT") self.criterion = torch.nn.CrossEntropyLoss() self.automatic_optimization = automatic_optimization self.training_step = ( diff --git a/examples/pl_domain_templates/computer_vision_fine_tuning.py b/examples/pl_domain_templates/computer_vision_fine_tuning.py index 7a81df9839..afcfa8f900 100644 --- a/examples/pl_domain_templates/computer_vision_fine_tuning.py +++ b/examples/pl_domain_templates/computer_vision_fine_tuning.py @@ -50,7 +50,7 @@ from torch.optim.lr_scheduler import MultiStepLR from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from torchmetrics import Accuracy -from torchvision import models, transforms +from torchvision import transforms from torchvision.datasets import ImageFolder from torchvision.datasets.utils import download_and_extract_archive @@ -58,6 +58,7 @@ from pytorch_lightning import cli_lightning_logo, LightningDataModule, Lightning from pytorch_lightning.callbacks.finetuning import BaseFinetuning from pytorch_lightning.cli import LightningCLI from pytorch_lightning.utilities import rank_zero_info +from pytorch_lightning.utilities.model_helpers import get_torchvision_model log = logging.getLogger(__name__) DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip" @@ -193,8 +194,7 @@ class TransferLearningModel(LightningModule): """Define model layers & loss.""" # 1. 
Load pre-trained network: - model_func = getattr(models, self.backbone) - backbone = model_func(pretrained=True) + backbone = get_torchvision_model(self.backbone, weights="DEFAULT") _layers = list(backbone.children())[:-1] self.feature_extractor = nn.Sequential(*_layers) diff --git a/examples/pl_domain_templates/imagenet.py b/examples/pl_domain_templates/imagenet.py index efb9c40eea..0a3b55d2a6 100644 --- a/examples/pl_domain_templates/imagenet.py +++ b/examples/pl_domain_templates/imagenet.py @@ -40,7 +40,6 @@ import torch.optim.lr_scheduler as lr_scheduler import torch.utils.data import torch.utils.data.distributed import torchvision.datasets as datasets -import torchvision.models as models import torchvision.transforms as transforms from torch.utils.data import Dataset from torchmetrics import Accuracy @@ -49,6 +48,7 @@ from pytorch_lightning import LightningModule from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar from pytorch_lightning.cli import LightningCLI from pytorch_lightning.strategies import ParallelStrategy +from pytorch_lightning.utilities.model_helpers import get_torchvision_model class ImageNetLightningModel(LightningModule): @@ -63,7 +63,7 @@ class ImageNetLightningModel(LightningModule): self, data_path: str, arch: str = "resnet18", - pretrained: bool = False, + weights: Optional[str] = None, lr: float = 0.1, momentum: float = 0.9, weight_decay: float = 1e-4, @@ -72,14 +72,14 @@ class ImageNetLightningModel(LightningModule): ): super().__init__() self.arch = arch - self.pretrained = pretrained + self.weights = weights self.lr = lr self.momentum = momentum self.weight_decay = weight_decay self.data_path = data_path self.batch_size = batch_size self.workers = workers - self.model = models.__dict__[self.arch](pretrained=self.pretrained) + self.model = get_torchvision_model(self.arch, weights=self.weights) self.train_dataset: Optional[Dataset] = None self.eval_dataset: Optional[Dataset] = None self.train_acc1 = Accuracy(top_k=1) 
diff --git a/examples/pl_servable_module/production.py b/examples/pl_servable_module/production.py index 3ecd723764..4005fecb73 100644 --- a/examples/pl_servable_module/production.py +++ b/examples/pl_servable_module/production.py @@ -7,13 +7,13 @@ from typing import Dict, Optional import numpy as np import torch import torchvision -import torchvision.models as models import torchvision.transforms as T from PIL import Image as PILImage from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule from pytorch_lightning.cli import LightningCLI from pytorch_lightning.serve import ServableModule, ServableModuleValidator +from pytorch_lightning.utilities.model_helpers import get_torchvision_model DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") @@ -21,7 +21,7 @@ DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") class LitModule(LightningModule): def __init__(self, name: str = "resnet18"): super().__init__() - self.model = getattr(models, name)(pretrained=True) + self.model = get_torchvision_model(name, weights="DEFAULT") self.model.fc = torch.nn.Linear(self.model.fc.in_features, 10) self.criterion = torch.nn.CrossEntropyLoss() diff --git a/src/pytorch_lightning/utilities/imports.py b/src/pytorch_lightning/utilities/imports.py index cbbfcc21dd..d870d0faab 100644 --- a/src/pytorch_lightning/utilities/imports.py +++ b/src/pytorch_lightning/utilities/imports.py @@ -17,7 +17,7 @@ import platform import sys import torch -from lightning_utilities.core.imports import compare_version, module_available, package_available +from lightning_utilities.core.imports import compare_version, module_available, package_available, RequirementCache _IS_WINDOWS = platform.system() == "Windows" _IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765 @@ -41,7 +41,7 @@ _POPTORCH_AVAILABLE = package_available("poptorch") _PSUTIL_AVAILABLE = package_available("psutil") _RICH_AVAILABLE = 
package_available("rich") and compare_version("rich", operator.ge, "10.2.2") _TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"]) -_TORCHVISION_AVAILABLE = package_available("torchvision") +_TORCHVISION_AVAILABLE = RequirementCache("torchvision") _XLA_AVAILABLE: bool = package_available("torch_xla") diff --git a/src/pytorch_lightning/utilities/model_helpers.py b/src/pytorch_lightning/utilities/model_helpers.py index 66ad264355..b72e9320b3 100644 --- a/src/pytorch_lightning/utilities/model_helpers.py +++ b/src/pytorch_lightning/utilities/model_helpers.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Optional, Type +from typing import Any, Optional, Type from unittest.mock import Mock +from lightning_utilities.core.imports import RequirementCache +from torch import nn + import pytorch_lightning as pl @@ -54,3 +57,18 @@ def is_overridden(method_name: str, instance: Optional[object] = None, parent: O raise ValueError("The parent should define the method") return instance_attr.__code__ != parent_attr.__code__ + + +def get_torchvision_model(model_name: str, **kwargs: Any) -> nn.Module: + from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE + + if not _TORCHVISION_AVAILABLE: + raise ModuleNotFoundError(str(_TORCHVISION_AVAILABLE)) + + from torchvision import models + + torchvision_greater_equal_0_14 = RequirementCache("torchvision>=0.14.0") + # TODO: deprecate this function when 0.14 is the minimum supported torchvision + if torchvision_greater_equal_0_14: + return models.get_model(model_name, **kwargs) + return getattr(models, model_name)(**kwargs) diff --git a/tests/tests_pytorch/helpers/advanced_models.py b/tests/tests_pytorch/helpers/advanced_models.py index a305fe04e6..4b8ce9c60e 100644 --- a/tests/tests_pytorch/helpers/advanced_models.py +++ 
b/tests/tests_pytorch/helpers/advanced_models.py @@ -20,11 +20,12 @@ from torch.utils.data import DataLoader from pytorch_lightning.core.module import LightningModule from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE +from pytorch_lightning.utilities.model_helpers import get_torchvision_model from tests_pytorch import _PATH_DATASETS from tests_pytorch.helpers.datasets import AverageDataset, MNIST, TrialMNIST if _TORCHVISION_AVAILABLE: - from torchvision import models, transforms + from torchvision import transforms from torchvision.datasets import CIFAR10 @@ -217,13 +218,13 @@ class ParityModuleMNIST(LightningModule): class ParityModuleCIFAR(LightningModule): - def __init__(self, backbone="resnet101", hidden_dim=1024, learning_rate=1e-3, pretrained=True): + def __init__(self, backbone="resnet101", hidden_dim=1024, learning_rate=1e-3, weights="DEFAULT"): super().__init__() self.save_hyperparameters() self.learning_rate = learning_rate self.num_classes = 10 - self.backbone = getattr(models, backbone)(pretrained=pretrained) + self.backbone = get_torchvision_model(backbone, weights=weights) self.classifier = torch.nn.Sequential( torch.nn.Linear(1000, hidden_dim), torch.nn.Linear(hidden_dim, self.num_classes) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 46fc7e9b62..4d6a609a00 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -520,7 +520,7 @@ def test_lightning_cli_submodules(tmpdir): assert isinstance(cli.model.submodule2, BoringModel) -@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="Tests a bug with torchvision, but it's not available") +@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE)) def test_lightning_cli_torch_modules(tmpdir): class TestModule(BoringModel): def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):