[feat] Support custom filesystems in LightningModule.to_torchscript (#7617)
* [feat] Support custom filesystems in LightningModule.to_torchscript * Update CHANGELOG.md * Update test_torchscript.py * Update test_torchscript.py * Update CHANGELOG.md * Update test_torchscript.py
This commit is contained in:
parent
e8a46bee15
commit
f6d892ac21
|
@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Added
|
||||
|
||||
- Added support to `LightningModule.to_torchscript` for saving to custom filesystems with fsspec ([#7617](https://github.com/PyTorchLightning/pytorch-lightning/pull/7617))
|
||||
|
||||
|
||||
- Added `KubeflowEnvironment` for use with the `PyTorchJob` operator in Kubeflow
|
||||
|
||||
|
||||
|
|
|
@ -40,6 +40,7 @@ from pytorch_lightning.core.saving import ALLOWED_CONFIG_TYPES, ModelIO, PRIMITI
|
|||
from pytorch_lightning.core.step_result import Result
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
|
||||
|
@ -1860,7 +1861,9 @@ class LightningModule(
|
|||
self.train(mode)
|
||||
|
||||
if file_path is not None:
|
||||
torch.jit.save(torchscript_module, file_path)
|
||||
fs = get_filesystem(file_path)
|
||||
with fs.open(file_path, "wb") as f:
|
||||
torch.jit.save(torchscript_module, f)
|
||||
|
||||
return torchscript_module
|
||||
|
||||
|
|
|
@ -12,9 +12,14 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
|
||||
import fsspec
|
||||
import pytest
|
||||
import torch
|
||||
from fsspec.implementations.local import LocalFileSystem
|
||||
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.advanced_models import BasicGAN, ParityModuleRNN
|
||||
from tests.helpers.datamodules import MNISTDataModule
|
||||
|
@ -139,6 +144,34 @@ def test_torchscript_save_load(tmpdir, modelclass):
|
|||
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("modelclass", [
|
||||
BoringModel,
|
||||
ParityModuleRNN,
|
||||
BasicGAN,
|
||||
])
|
||||
@RunIf(min_torch="1.5.0")
|
||||
def test_torchscript_save_load_custom_filesystem(tmpdir, modelclass):
|
||||
""" Test that scripted LightningModule is correctly saved and can be loaded with custom filesystems. """
|
||||
|
||||
_DUMMY_PRFEIX = "dummy"
|
||||
_PREFIX_SEPARATOR = "://"
|
||||
|
||||
class DummyFileSystem(LocalFileSystem):
|
||||
...
|
||||
|
||||
fsspec.register_implementation(_DUMMY_PRFEIX, DummyFileSystem, clobber=True)
|
||||
|
||||
model = modelclass()
|
||||
output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmpdir, "model.pt")
|
||||
script = model.to_torchscript(file_path=output_file)
|
||||
|
||||
fs = get_filesystem(output_file)
|
||||
with fs.open(output_file, "rb") as f:
|
||||
loaded_script = torch.jit.load(f)
|
||||
|
||||
assert torch.allclose(next(script.parameters()), next(loaded_script.parameters()))
|
||||
|
||||
|
||||
def test_torchcript_invalid_method(tmpdir):
|
||||
"""Test that an error is thrown with invalid torchscript method"""
|
||||
model = BoringModel()
|
||||
|
|
Loading…
Reference in New Issue