Fixed encoding issues on terminals that do not support unicode characters (#12828)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
194c7c7f51
commit
eea6b44e5a
|
@ -16,6 +16,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- When using custom DataLoaders in LightningDataModule, multiple inheritance is resolved properly ([#12716](https://github.com/PyTorchLightning/pytorch-lightning/pull/12716))
|
||||
|
||||
|
||||
- Fixed encoding issues on terminals that do not support unicode characters ([#12828](https://github.com/PyTorchLightning/pytorch-lightning/pull/12828))
|
||||
|
||||
|
||||
- Fixed support for `ModelCheckpoint` monitors with dots ([#12783](https://github.com/PyTorchLightning/pytorch-lightning/pull/12783))
|
||||
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from collections import ChainMap, OrderedDict
|
||||
from functools import partial
|
||||
from typing import Any, IO, Iterable, List, Optional, Sequence, Type, Union
|
||||
|
@ -336,6 +337,10 @@ class EvaluationLoop(DataLoaderLoop):
|
|||
|
||||
@staticmethod
|
||||
def _print_results(results: List[_OUT_DICT], stage: str, file: Optional[IO[str]] = None) -> None:
|
||||
# print to stdout by default
|
||||
if file is None:
|
||||
file = sys.stdout
|
||||
|
||||
# remove the dl idx suffix
|
||||
results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
|
||||
metrics = sorted({k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys})
|
||||
|
@ -384,7 +389,16 @@ class EvaluationLoop(DataLoaderLoop):
|
|||
row_format = f"{{:^{max_length}}}" * len(table_headers)
|
||||
half_term_size = int(term_size / 2)
|
||||
|
||||
bar = "─" * term_size
|
||||
try:
|
||||
# some terminals do not support this character
|
||||
if hasattr(file, "encoding") and file.encoding is not None:
|
||||
"─".encode(file.encoding)
|
||||
except UnicodeEncodeError:
|
||||
bar_character = "-"
|
||||
else:
|
||||
bar_character = "─"
|
||||
bar = bar_character * term_size
|
||||
|
||||
lines = [bar, row_format.format(*table_headers).rstrip(), bar]
|
||||
for metric, row in zip(metrics, table_rows):
|
||||
# deal with column overflow
|
||||
|
|
|
@ -870,6 +870,23 @@ def test_native_print_results(monkeypatch, inputs, expected):
|
|||
assert out.getvalue().replace(os.linesep, "\n") == expected.lstrip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("encoding", ["latin-1", "utf-8"])
|
||||
def test_native_print_results_encodings(monkeypatch, encoding):
|
||||
import pytorch_lightning.loops.dataloader.evaluation_loop as imports
|
||||
|
||||
monkeypatch.setattr(imports, "_RICH_AVAILABLE", False)
|
||||
|
||||
out = mock.Mock()
|
||||
out.encoding = encoding
|
||||
EvaluationLoop._print_results(*inputs0, file=out)
|
||||
|
||||
# Attempt to encode everything the file is told to write with the given encoding
|
||||
for call_ in out.method_calls:
|
||||
name, args, kwargs = call_
|
||||
if name == "write":
|
||||
args[0].encode(encoding)
|
||||
|
||||
|
||||
expected0 = """
|
||||
┏━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
|
||||
┃ Test metric ┃ DataLoader 0 ┃ DataLoader 1 ┃
|
||||
|
|
Loading…
Reference in New Issue