Fixed encoding issues on terminals that do not support unicode characters (#12828)

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
alvitawa 2022-04-25 14:24:30 +02:00 committed by lexierule
parent 194c7c7f51
commit eea6b44e5a
3 changed files with 35 additions and 1 deletions

View File

@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- When using custom DataLoaders in LightningDataModule, multiple inheritance is resolved properly ([#12716](https://github.com/PyTorchLightning/pytorch-lightning/pull/12716))
- Fixed encoding issues on terminals that do not support unicode characters ([#12828](https://github.com/PyTorchLightning/pytorch-lightning/pull/12828))
- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))

View File

@ -13,6 +13,7 @@
# limitations under the License.
import os
import shutil
import sys
from collections import ChainMap, OrderedDict
from functools import partial
from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union
@ -336,6 +337,10 @@ class EvaluationLoop(DataLoaderLoop):
@staticmethod
def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] = None) -> None:
# print to stdout by default
if file is None:
file = sys.stdout
# remove the dl idx suffix
results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
metrics = sorted({k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys})
@ -384,7 +389,16 @@ class EvaluationLoop(DataLoaderLoop):
row_format = f"{{:^{max_length}}}" * len(table_headers)
half_term_size = int(term_size / 2)
bar = "" * term_size
try:
# some terminals do not support this character
if hasattr(file, "encoding") and file.encoding is not None:
"".encode(file.encoding)
except UnicodeEncodeError:
bar_character = "-"
else:
bar_character = ""
bar = bar_character * term_size
lines = [bar, row_format.format(*table_headers).rstrip(), bar]
for metric, row in zip(metrics, table_rows):
# deal with column overflow

View File

@ -870,6 +870,23 @@ def test_native_print_results(monkeypatch, inputs, expected):
assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip()
@pytest.mark.parametrize("encoding", ["latin-1", "utf-8"])
def test_native_print_results_encodings(monkeypatch, encoding):
import pytorch_lightning.loops.dataloader.evaluation_loop as imports
monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)
out = mock.Mock()
out.encoding = encoding
EvaluationLoop._print_results(*inputs0, file=out)
# Attempt to encode everything the file is told to write with the given encoding
for call_ in out.method_calls:
name, args, kwargs = call_
if name == "write":
args[0].encode(encoding)
expected0 = """
Test metric DataLoader 0 DataLoader 1