Support pathlib.Path file paths when saving ONNX models (#19727)

Co-authored-by: dominicgkerr <dominicgkerr1@gmail.co>
This commit is contained in:
Dominic Kerr 2024-04-04 01:42:25 +01:00 committed by GitHub
parent ce88483c6f
commit 76b691d80c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 5 deletions

View File

@ -1395,7 +1395,7 @@ class LightningModule(
input_sample = self._on_before_batch_transfer(input_sample)
input_sample = self._apply_batch_transfer_handler(input_sample)
torch.onnx.export(self, input_sample, file_path, **kwargs)
torch.onnx.export(self, input_sample, str(file_path), **kwargs)
self.train(mode)
@torch.no_grad()

View File

@ -13,6 +13,7 @@
# limitations under the License.
import operator
import os
from pathlib import Path
from unittest.mock import patch
import numpy as np
@ -32,11 +33,14 @@ from tests_pytorch.utilities.test_model_summary import UnorderedModel
def test_model_saves_with_input_sample(tmp_path):
"""Test that ONNX model saves with input sample and size is greater than 3 MB."""
model = BoringModel()
trainer = Trainer(fast_dev_run=True)
trainer.fit(model)
file_path = os.path.join(tmp_path, "model.onnx")
input_sample = torch.randn((1, 32))
file_path = os.path.join(tmp_path, "os.path.onnx")
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 4e2
file_path = Path(tmp_path) / "pathlib.onnx"
model.to_onnx(file_path, input_sample)
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 4e2