From 828fd998961f6a60f92c35254bb94d6e049ad069 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 7 Aug 2024 23:07:02 +0800 Subject: [PATCH] Re-enable passing BytesIO as path in `.to_onnx()` (#20172) --- src/lightning/pytorch/core/module.py | 6 ++++-- tests/tests_pytorch/models/test_onnx.py | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 647f6e6e41..782fc40d92 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -17,6 +17,7 @@ import logging import numbers import weakref from contextlib import contextmanager +from io import BytesIO from pathlib import Path from typing import ( IO, @@ -1364,7 +1365,7 @@ class LightningModule( ) @torch.no_grad() - def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None: + def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None: """Saves the model in ONNX format. Args: @@ -1403,7 +1404,8 @@ class LightningModule( input_sample = self._on_before_batch_transfer(input_sample) input_sample = self._apply_batch_transfer_handler(input_sample) - torch.onnx.export(self, input_sample, str(file_path), **kwargs) + file_path = str(file_path) if isinstance(file_path, Path) else file_path + torch.onnx.export(self, input_sample, file_path, **kwargs) self.train(mode) @torch.no_grad() diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 15d0635594..ee670cd66e 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -13,6 +13,7 @@ # limitations under the License. import operator import os +from io import BytesIO from pathlib import Path from unittest.mock import patch @@ -45,6 +46,10 @@ def test_model_saves_with_input_sample(tmp_path): assert os.path.isfile(file_path) assert os.path.getsize(file_path) > 4e2 + file_path = BytesIO() + model.to_onnx(file_path=file_path, input_sample=input_sample) + assert len(file_path.getvalue()) > 4e2 + @pytest.mark.parametrize( "accelerator", [pytest.param("mps", marks=RunIf(mps=True)), pytest.param("gpu", marks=RunIf(min_cuda_gpus=True))]