Use TorchVision's Multi-weight Support and Model Registration API on Lightning (#14567)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Vasilis Vryniotis 2022-09-09 21:04:57 +01:00 committed by GitHub
parent d5b32c3087
commit 7e9e441843
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 38 additions and 19 deletions

View File

@ -62,7 +62,7 @@ Example: Imagenet (Computer Vision)
super().__init__()
# init a pretrained resnet
backbone = models.resnet50(pretrained=True)
backbone = models.resnet50(weights="DEFAULT")
num_filters = backbone.fc.in_features
layers = list(backbone.children())[:-1]
self.feature_extractor = nn.Sequential(*layers)

View File

@ -27,12 +27,12 @@ from os import path
import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as T
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.profilers.pytorch import PyTorchProfiler
from pytorch_lightning.utilities.model_helpers import get_torchvision_model
DEFAULT_CMD_LINE = (
"fit",
@ -49,7 +49,7 @@ DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
class ModelToProfile(LightningModule):
def __init__(self, name: str = "resnet18", automatic_optimization: bool = True):
super().__init__()
self.model = getattr(models, name)(pretrained=True)
self.model = get_torchvision_model(name, weights="DEFAULT")
self.criterion = torch.nn.CrossEntropyLoss()
self.automatic_optimization = automatic_optimization
self.training_step = (

View File

@ -50,7 +50,7 @@ from torch.optim.lr_scheduler import MultiStepLR
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import models, transforms
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_and_extract_archive
@ -58,6 +58,7 @@ from pytorch_lightning import cli_lightning_logo, LightningDataModule, Lightning
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.utilities import rank_zero_info
from pytorch_lightning.utilities.model_helpers import get_torchvision_model
log = logging.getLogger(__name__)
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
@ -193,8 +194,7 @@ class TransferLearningModel(LightningModule):
"""Define model layers & loss."""
# 1. Load pre-trained network:
model_func = getattr(models, self.backbone)
backbone = model_func(pretrained=True)
backbone = get_torchvision_model(self.backbone, weights="DEFAULT")
_layers = list(backbone.children())[:-1]
self.feature_extractor = nn.Sequential(*_layers)

View File

@ -40,7 +40,6 @@ import torch.optim.lr_scheduler as lr_scheduler
import torch.utils.data
import torch.utils.data.distributed
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchmetrics import Accuracy
@ -49,6 +48,7 @@ from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.strategies import ParallelStrategy
from pytorch_lightning.utilities.model_helpers import get_torchvision_model
class ImageNetLightningModel(LightningModule):
@ -63,7 +63,7 @@ class ImageNetLightningModel(LightningModule):
self,
data_path: str,
arch: str = "resnet18",
pretrained: bool = False,
weights: Optional[str] = None,
lr: float = 0.1,
momentum: float = 0.9,
weight_decay: float = 1e-4,
@ -72,14 +72,14 @@ class ImageNetLightningModel(LightningModule):
):
super().__init__()
self.arch = arch
self.pretrained = pretrained
self.weights = weights
self.lr = lr
self.momentum = momentum
self.weight_decay = weight_decay
self.data_path = data_path
self.batch_size = batch_size
self.workers = workers
self.model = models.__dict__[self.arch](pretrained=self.pretrained)
self.model = get_torchvision_model(self.arch, weights=self.weights)
self.train_dataset: Optional[Dataset] = None
self.eval_dataset: Optional[Dataset] = None
self.train_acc1 = Accuracy(top_k=1)

View File

@ -7,13 +7,13 @@ from typing import Dict, Optional
import numpy as np
import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image as PILImage
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.serve import ServableModule, ServableModuleValidator
from pytorch_lightning.utilities.model_helpers import get_torchvision_model
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
@ -21,7 +21,7 @@ DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
class LitModule(LightningModule):
def __init__(self, name: str = "resnet18"):
super().__init__()
self.model = getattr(models, name)(pretrained=True)
self.model = get_torchvision_model(name, weights="DEFAULT")
self.model.fc = torch.nn.Linear(self.model.fc.in_features, 10)
self.criterion = torch.nn.CrossEntropyLoss()

View File

@ -17,7 +17,7 @@ import platform
import sys
import torch
from lightning_utilities.core.imports import compare_version, module_available, package_available
from lightning_utilities.core.imports import compare_version, module_available, package_available, RequirementCache
_IS_WINDOWS = platform.system() == "Windows"
_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
@ -41,7 +41,7 @@ _POPTORCH_AVAILABLE = package_available("poptorch")
_PSUTIL_AVAILABLE = package_available("psutil")
_RICH_AVAILABLE = package_available("rich") and compare_version("rich", operator.ge, "10.2.2")
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
_TORCHVISION_AVAILABLE = package_available("torchvision")
_TORCHVISION_AVAILABLE = RequirementCache("torchvision")
_XLA_AVAILABLE: bool = package_available("torch_xla")

View File

@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional, Type
from typing import Any, Optional, Type
from unittest.mock import Mock
from lightning_utilities.core.imports import RequirementCache
from torch import nn
import pytorch_lightning as pl
@ -54,3 +57,18 @@ def is_overridden(method_name: str, instance: Optional[object] = None, parent: O
raise ValueError("The parent should define the method")
return instance_attr.__code__ != parent_attr.__code__
def get_torchvision_model(model_name: str, **kwargs: Any) -> nn.Module:
    """Instantiate a torchvision model by name, forwarding ``kwargs`` to its builder.

    When the ``torchvision>=0.14.0`` requirement check is truthy, the model is
    built via ``torchvision.models.get_model``. Otherwise the builder is looked
    up as an attribute of ``torchvision.models`` and called directly.

    Args:
        model_name: Name of the model builder, e.g. ``"resnet18"``.
        **kwargs: Passed through unchanged to the builder (for example
            ``weights="DEFAULT"``).

    Returns:
        Whatever the selected torchvision builder returns.

    Raises:
        ModuleNotFoundError: If ``_TORCHVISION_AVAILABLE`` is falsy; the message
            is ``str(_TORCHVISION_AVAILABLE)``.
    """
    # Imported lazily inside the function, as in the original implementation.
    from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE

    if not _TORCHVISION_AVAILABLE:
        raise ModuleNotFoundError(str(_TORCHVISION_AVAILABLE))

    from torchvision import models

    # TODO: deprecate this function when 0.14 is the minimum supported torchvision
    has_get_model_api = RequirementCache("torchvision>=0.14.0")
    if not has_get_model_api:
        # Requirement not satisfied: resolve the builder by attribute name.
        builder = getattr(models, model_name)
        return builder(**kwargs)
    return models.get_model(model_name, **kwargs)

View File

@ -20,11 +20,12 @@ from torch.utils.data import DataLoader
from pytorch_lightning.core.module import LightningModule
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
from pytorch_lightning.utilities.model_helpers import get_torchvision_model
from tests_pytorch import _PATH_DATASETS
from tests_pytorch.helpers.datasets import AverageDataset, MNIST, TrialMNIST
if _TORCHVISION_AVAILABLE:
from torchvision import models, transforms
from torchvision import transforms
from torchvision.datasets import CIFAR10
@ -217,13 +218,13 @@ class ParityModuleMNIST(LightningModule):
class ParityModuleCIFAR(LightningModule):
def __init__(self, backbone="resnet101", hidden_dim=1024, learning_rate=1e-3, pretrained=True):
def __init__(self, backbone="resnet101", hidden_dim=1024, learning_rate=1e-3, weights="DEFAULT"):
super().__init__()
self.save_hyperparameters()
self.learning_rate = learning_rate
self.num_classes = 10
self.backbone = getattr(models, backbone)(pretrained=pretrained)
self.backbone = get_torchvision_model(backbone, weights=weights)
self.classifier = torch.nn.Sequential(
torch.nn.Linear(1000, hidden_dim), torch.nn.Linear(hidden_dim, self.num_classes)

View File

@ -520,7 +520,7 @@ def test_lightning_cli_submodules(tmpdir):
assert isinstance(cli.model.submodule2, BoringModel)
@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="Tests a bug with torchvision, but it's not available")
@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE))
def test_lightning_cli_torch_modules(tmpdir):
class TestModule(BoringModel):
def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):