Re-enable passing BytesIO as path in `.to_onnx()` (#20172)

This commit is contained in:
GdoongMathew 2024-08-07 23:07:02 +08:00 committed by GitHub
parent be0ae06596
commit 828fd99896
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 2 deletions

View File

@ -17,6 +17,7 @@ import logging
import numbers
import weakref
from contextlib import contextmanager
from io import BytesIO
from pathlib import Path
from typing import (
IO,
@ -1364,7 +1365,7 @@ class LightningModule(
)
@torch.no_grad()
def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
"""Saves the model in ONNX format.
Args:
@ -1403,7 +1404,8 @@ class LightningModule(
input_sample = self._on_before_batch_transfer(input_sample)
input_sample = self._apply_batch_transfer_handler(input_sample)
torch.onnx.export(self, input_sample, str(file_path), **kwargs)
file_path = str(file_path) if isinstance(file_path, Path) else file_path
torch.onnx.export(self, input_sample, file_path, **kwargs)
self.train(mode)
@torch.no_grad()

View File

@ -13,6 +13,7 @@
# limitations under the License.
import operator
import os
from io import BytesIO
from pathlib import Path
from unittest.mock import patch
@ -45,6 +46,10 @@ def test_model_saves_with_input_sample(tmp_path):
assert os.path.isfile(file_path)
assert os.path.getsize(file_path) > 4e2
file_path = BytesIO()
model.to_onnx(file_path=file_path, input_sample=input_sample)
assert len(file_path.getvalue()) > 4e2
@pytest.mark.parametrize(
"accelerator", [pytest.param("mps", marks=RunIf(mps=True)), pytest.param("gpu", marks=RunIf(min_cuda_gpus=True))]