Use TorchVision's Multi-weight Support and Model Registration API on Lightning (#14567)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
d5b32c3087
commit
7e9e441843
|
@ -62,7 +62,7 @@ Example: Imagenet (Computer Vision)
|
|||
super().__init__()
|
||||
|
||||
# init a pretrained resnet
|
||||
backbone = models.resnet50(pretrained=True)
|
||||
backbone = models.resnet50(weights="DEFAULT")
|
||||
num_filters = backbone.fc.in_features
|
||||
layers = list(backbone.children())[:-1]
|
||||
self.feature_extractor = nn.Sequential(*layers)
|
||||
|
|
|
@ -27,12 +27,12 @@ from os import path
|
|||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.models as models
|
||||
import torchvision.transforms as T
|
||||
|
||||
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.profilers.pytorch import PyTorchProfiler
|
||||
from pytorch_lightning.utilities.model_helpers import get_torchvision_model
|
||||
|
||||
DEFAULT_CMD_LINE = (
|
||||
"fit",
|
||||
|
@ -49,7 +49,7 @@ DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
|
|||
class ModelToProfile(LightningModule):
|
||||
def __init__(self, name: str = "resnet18", automatic_optimization: bool = True):
|
||||
super().__init__()
|
||||
self.model = getattr(models, name)(pretrained=True)
|
||||
self.model = get_torchvision_model(name, weights="DEFAULT")
|
||||
self.criterion = torch.nn.CrossEntropyLoss()
|
||||
self.automatic_optimization = automatic_optimization
|
||||
self.training_step = (
|
||||
|
|
|
@ -50,7 +50,7 @@ from torch.optim.lr_scheduler import MultiStepLR
|
|||
from torch.optim.optimizer import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
from torchmetrics import Accuracy
|
||||
from torchvision import models, transforms
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import ImageFolder
|
||||
from torchvision.datasets.utils import download_and_extract_archive
|
||||
|
||||
|
@ -58,6 +58,7 @@ from pytorch_lightning import cli_lightning_logo, LightningDataModule, Lightning
|
|||
from pytorch_lightning.callbacks.finetuning import BaseFinetuning
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.utilities import rank_zero_info
|
||||
from pytorch_lightning.utilities.model_helpers import get_torchvision_model
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip"
|
||||
|
@ -193,8 +194,7 @@ class TransferLearningModel(LightningModule):
|
|||
"""Define model layers & loss."""
|
||||
|
||||
# 1. Load pre-trained network:
|
||||
model_func = getattr(models, self.backbone)
|
||||
backbone = model_func(pretrained=True)
|
||||
backbone = get_torchvision_model(self.backbone, weights="DEFAULT")
|
||||
|
||||
_layers = list(backbone.children())[:-1]
|
||||
self.feature_extractor = nn.Sequential(*_layers)
|
||||
|
|
|
@ -40,7 +40,6 @@ import torch.optim.lr_scheduler as lr_scheduler
|
|||
import torch.utils.data
|
||||
import torch.utils.data.distributed
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.models as models
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import Dataset
|
||||
from torchmetrics import Accuracy
|
||||
|
@ -49,6 +48,7 @@ from pytorch_lightning import LightningModule
|
|||
from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.strategies import ParallelStrategy
|
||||
from pytorch_lightning.utilities.model_helpers import get_torchvision_model
|
||||
|
||||
|
||||
class ImageNetLightningModel(LightningModule):
|
||||
|
@ -63,7 +63,7 @@ class ImageNetLightningModel(LightningModule):
|
|||
self,
|
||||
data_path: str,
|
||||
arch: str = "resnet18",
|
||||
pretrained: bool = False,
|
||||
weights: Optional[str] = None,
|
||||
lr: float = 0.1,
|
||||
momentum: float = 0.9,
|
||||
weight_decay: float = 1e-4,
|
||||
|
@ -72,14 +72,14 @@ class ImageNetLightningModel(LightningModule):
|
|||
):
|
||||
super().__init__()
|
||||
self.arch = arch
|
||||
self.pretrained = pretrained
|
||||
self.weights = weights
|
||||
self.lr = lr
|
||||
self.momentum = momentum
|
||||
self.weight_decay = weight_decay
|
||||
self.data_path = data_path
|
||||
self.batch_size = batch_size
|
||||
self.workers = workers
|
||||
self.model = models.__dict__[self.arch](pretrained=self.pretrained)
|
||||
self.model = get_torchvision_model(self.arch, weights=self.weights)
|
||||
self.train_dataset: Optional[Dataset] = None
|
||||
self.eval_dataset: Optional[Dataset] = None
|
||||
self.train_acc1 = Accuracy(top_k=1)
|
||||
|
|
|
@ -7,13 +7,13 @@ from typing import Dict, Optional
|
|||
import numpy as np
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.models as models
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image as PILImage
|
||||
|
||||
from pytorch_lightning import cli_lightning_logo, LightningDataModule, LightningModule
|
||||
from pytorch_lightning.cli import LightningCLI
|
||||
from pytorch_lightning.serve import ServableModule, ServableModuleValidator
|
||||
from pytorch_lightning.utilities.model_helpers import get_torchvision_model
|
||||
|
||||
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
|
||||
|
||||
|
@ -21,7 +21,7 @@ DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
|
|||
class LitModule(LightningModule):
|
||||
def __init__(self, name: str = "resnet18"):
|
||||
super().__init__()
|
||||
self.model = getattr(models, name)(pretrained=True)
|
||||
self.model = get_torchvision_model(name, weights="DEFAULT")
|
||||
self.model.fc = torch.nn.Linear(self.model.fc.in_features, 10)
|
||||
self.criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ import platform
|
|||
import sys
|
||||
|
||||
import torch
|
||||
from lightning_utilities.core.imports import compare_version, module_available, package_available
|
||||
from lightning_utilities.core.imports import compare_version, module_available, package_available, RequirementCache
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
_IS_INTERACTIVE = hasattr(sys, "ps1") # https://stackoverflow.com/a/64523765
|
||||
|
@ -41,7 +41,7 @@ _POPTORCH_AVAILABLE = package_available("poptorch")
|
|||
_PSUTIL_AVAILABLE = package_available("psutil")
|
||||
_RICH_AVAILABLE = package_available("rich") and compare_version("rich", operator.ge, "10.2.2")
|
||||
_TORCH_QUANTIZE_AVAILABLE = bool([eg for eg in torch.backends.quantized.supported_engines if eg != "none"])
|
||||
_TORCHVISION_AVAILABLE = package_available("torchvision")
|
||||
_TORCHVISION_AVAILABLE = RequirementCache("torchvision")
|
||||
_XLA_AVAILABLE: bool = package_available("torch_xla")
|
||||
|
||||
|
||||
|
|
|
@ -12,9 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
from typing import Optional, Type
|
||||
from typing import Any, Optional, Type
|
||||
from unittest.mock import Mock
|
||||
|
||||
from lightning_utilities.core.imports import RequirementCache
|
||||
from torch import nn
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
|
||||
|
@ -54,3 +57,18 @@ def is_overridden(method_name: str, instance: Optional[object] = None, parent: O
|
|||
raise ValueError("The parent should define the method")
|
||||
|
||||
return instance_attr.__code__ != parent_attr.__code__
|
||||
|
||||
|
||||
def get_torchvision_model(model_name: str, **kwargs: Any) -> nn.Module:
|
||||
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
||||
|
||||
if not _TORCHVISION_AVAILABLE:
|
||||
raise ModuleNotFoundError(str(_TORCHVISION_AVAILABLE))
|
||||
|
||||
from torchvision import models
|
||||
|
||||
torchvision_greater_equal_0_14 = RequirementCache("torchvision>=0.14.0")
|
||||
# TODO: deprecate this function when 0.14 is the minimum supported torchvision
|
||||
if torchvision_greater_equal_0_14:
|
||||
return models.get_model(model_name, **kwargs)
|
||||
return getattr(models, model_name)(**kwargs)
|
||||
|
|
|
@ -20,11 +20,12 @@ from torch.utils.data import DataLoader
|
|||
|
||||
from pytorch_lightning.core.module import LightningModule
|
||||
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
|
||||
from pytorch_lightning.utilities.model_helpers import get_torchvision_model
|
||||
from tests_pytorch import _PATH_DATASETS
|
||||
from tests_pytorch.helpers.datasets import AverageDataset, MNIST, TrialMNIST
|
||||
|
||||
if _TORCHVISION_AVAILABLE:
|
||||
from torchvision import models, transforms
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import CIFAR10
|
||||
|
||||
|
||||
|
@ -217,13 +218,13 @@ class ParityModuleMNIST(LightningModule):
|
|||
|
||||
|
||||
class ParityModuleCIFAR(LightningModule):
|
||||
def __init__(self, backbone="resnet101", hidden_dim=1024, learning_rate=1e-3, pretrained=True):
|
||||
def __init__(self, backbone="resnet101", hidden_dim=1024, learning_rate=1e-3, weights="DEFAULT"):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
||||
self.learning_rate = learning_rate
|
||||
self.num_classes = 10
|
||||
self.backbone = getattr(models, backbone)(pretrained=pretrained)
|
||||
self.backbone = get_torchvision_model(backbone, weights=weights)
|
||||
|
||||
self.classifier = torch.nn.Sequential(
|
||||
torch.nn.Linear(1000, hidden_dim), torch.nn.Linear(hidden_dim, self.num_classes)
|
||||
|
|
|
@ -520,7 +520,7 @@ def test_lightning_cli_submodules(tmpdir):
|
|||
assert isinstance(cli.model.submodule2, BoringModel)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason="Tests a bug with torchvision, but it's not available")
|
||||
@pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE))
|
||||
def test_lightning_cli_torch_modules(tmpdir):
|
||||
class TestModule(BoringModel):
|
||||
def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):
|
||||
|
|
Loading…
Reference in New Issue