Support pathlib.Path file paths when saving ONNX models (#19727)
Co-authored-by: dominicgkerr <dominicgkerr1@gmail.co>
This commit is contained in:
parent
ce88483c6f
commit
76b691d80c
|
@ -1395,7 +1395,7 @@ class LightningModule(
|
|||
input_sample = self._on_before_batch_transfer(input_sample)
|
||||
input_sample = self._apply_batch_transfer_handler(input_sample)
|
||||
|
||||
torch.onnx.export(self, input_sample, file_path, **kwargs)
|
||||
torch.onnx.export(self, input_sample, str(file_path), **kwargs)
|
||||
self.train(mode)
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import operator
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
|
@ -32,11 +33,14 @@ from tests_pytorch.utilities.test_model_summary import UnorderedModel
|
|||
def test_model_saves_with_input_sample(tmp_path):
|
||||
"""Test that ONNX model saves with input sample and size is greater than 3 MB."""
|
||||
model = BoringModel()
|
||||
trainer = Trainer(fast_dev_run=True)
|
||||
trainer.fit(model)
|
||||
|
||||
file_path = os.path.join(tmp_path, "model.onnx")
|
||||
input_sample = torch.randn((1, 32))
|
||||
|
||||
file_path = os.path.join(tmp_path, "os.path.onnx")
|
||||
model.to_onnx(file_path, input_sample)
|
||||
assert os.path.isfile(file_path)
|
||||
assert os.path.getsize(file_path) > 4e2
|
||||
|
||||
file_path = Path(tmp_path) / "pathlib.onnx"
|
||||
model.to_onnx(file_path, input_sample)
|
||||
assert os.path.isfile(file_path)
|
||||
assert os.path.getsize(file_path) > 4e2
|
||||
|
|
Loading…
Reference in New Issue