From f6d892ac213ce3562b7f01779d86c7ddbde47fba Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 21 May 2021 04:23:15 -0700 Subject: [PATCH] [feat] Support custom filesystems in LightningModule.to_torchscript (#7617) * [feat] Support custom filesystems in LightningModule.to_torchscript * Update CHANGELOG.md * Update test_torchscript.py * Update test_torchscript.py * Update CHANGELOG.md * Update test_torchscript.py --- CHANGELOG.md | 3 +++ pytorch_lightning/core/lightning.py | 5 ++++- tests/models/test_torchscript.py | 33 +++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8696f739d4..67e6e43c7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617)) + + - Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 79bdd38080..b974f57741 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -40,6 +40,7 @@ from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITI from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors +from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters @@ -1860,7 +1861,9 @@ class LightningModule( self.train(mode) if file_path is not None: - torch.jit.save(torchscript_module, file_path) + fs = get_filesystem(file_path) + with fs.open(file_path, "wb") as f: + torch.jit.save(torchscript_module, f) return torchscript_module diff --git a/tests/models/test_torchscript.py b/tests/models/test_torchscript.py index b03ed0806d..fab433688f 100644 --- a/tests/models/test_torchscript.py +++ b/tests/models/test_torchscript.py @@ -12,9 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + +import fsspec import pytest import torch +from fsspec.implementations.local import LocalFileSystem +from pytorch_lightning.utilities.cloud_io import get_filesystem from tests.helpers import BoringModel from tests.helpers.advanced_models import BasicGAN, ParityModuleRNN from tests.helpers.datamodules import MNISTDataModule @@ -139,6 +144,34 @@ def test_torchscript_save_load(tmpdir, modelclass): assert torch.allclose(next(script.parameters()), next(loaded_script.parameters())) +@pytest.mark.parametrize("modelclass", [ + BoringModel, + ParityModuleRNN, + BasicGAN, +]) +@RunIf(min_torch="1.5.0") +def test_torchscript_save_load_custom_filesystem(tmpdir, modelclass): + """ Test that scripted LightningModule is correctly saved and can be loaded with custom filesystems. """ + + _DUMMY_PRFEIX = "dummy" + _PREFIX_SEPARATOR = "://" + + class DummyFileSystem(LocalFileSystem): + ... + + fsspec.register_implementation(_DUMMY_PRFEIX, DummyFileSystem, clobber=True) + + model = modelclass() + output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmpdir, "model.pt") + script = model.to_torchscript(file_path=output_file) + + fs = get_filesystem(output_file) + with fs.open(output_file, "rb") as f: + loaded_script = torch.jit.load(f) + + assert torch.allclose(next(script.parameters()), next(loaded_script.parameters())) + + def test_torchcript_invalid_method(tmpdir): """Test that an error is thrown with invalid torchscript method""" model = BoringModel()