[feat] Support custom filesystems in LightningModule.to_torchscript (#7617)

* [feat] Support custom filesystems in LightningModule.to_torchscript

* Update CHANGELOG.md

* Update test_torchscript.py

* Update test_torchscript.py

* Update CHANGELOG.md

* Update test_torchscript.py
This commit is contained in:
ananthsub 2021-05-21 04:23:15 -07:00 committed by GitHub
parent e8a46bee15
commit f6d892ac21
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 1 deletions

View File

@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added
- Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617))
- Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow

View File

@ -40,6 +40,7 @@ from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITI
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
@ -1860,7 +1861,9 @@ class LightningModule(
self.train(mode)
if file_path is not None:
torch.jit.save(torchscript_module, file_path)
fs = get_filesystem(file_path)
with fs.open(file_path, "wb") as f:
torch.jit.save(torchscript_module, f)
return torchscript_module

View File

@ -12,9 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import fsspec
import pytest
import torch
from fsspec.implementations.local import LocalFileSystem
from pytorch_lightning.utilities.cloud_io import get_filesystem
from tests.helpers import BoringModel
from tests.helpers.advanced_models import BasicGAN, ParityModuleRNN
from tests.helpers.datamodules import MNISTDataModule
@ -139,6 +144,34 @@ def test_torchscript_save_load(tmpdir, modelclass):
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
@pytest.mark.parametrize("modelclass", [
BoringModel,
ParityModuleRNN,
BasicGAN,
])
@RunIf(min_torch="1.5.0")
def test_torchscript_save_load_custom_filesystem(tmpdir, modelclass):
""" Test that scripted LightningModule is correctly saved and can be loaded with custom filesystems. """
_DUMMY_PRFEIX = "dummy"
_PREFIX_SEPARATOR = "://"
class DummyFileSystem(LocalFileSystem):
...
fsspec.register_implementation(_DUMMY_PRFEIX, DummyFileSystem, clobber=True)
model = modelclass()
output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmpdir, "model.pt")
script = model.to_torchscript(file_path=output_file)
fs = get_filesystem(output_file)
with fs.open(output_file, "rb") as f:
loaded_script = torch.jit.load(f)
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
def test_torchcript_invalid_method(tmpdir):
"""Test that an error is thrown with invalid torchscript method"""
model = BoringModel()