186 lines
6.7 KiB
Python
186 lines
6.7 KiB
Python
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
import os
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
import torch
|
|
from lightning.fabric.loggers import CSVLogger
|
|
from lightning.fabric.loggers.csv_logs import _ExperimentWriter
|
|
|
|
|
|
def test_automatic_versioning(tmp_path):
    """Check that the logger picks the next integer version when none is given."""
    experiment_dir = tmp_path / "exp"
    experiment_dir.mkdir()
    # two numbered version folders plus two folders that do not match ``version_<int>``
    for folder_name in ("version_0", "version_1", "version_nonumber", "other"):
        (experiment_dir / folder_name).mkdir()

    csv_logger = CSVLogger(root_dir=tmp_path, name="exp")
    assert csv_logger.version == 2
|
|
|
|
|
|
def test_automatic_versioning_relative_root_dir(tmp_path, monkeypatch):
    """Check that automatic versioning also works when ``root_dir`` is a relative path."""
    logs_dir = tmp_path / "exp" / "logs"
    for index in range(2):
        (logs_dir / f"version_{index}").mkdir(parents=True)

    # make ``tmp_path`` the working directory so that "exp" resolves inside it
    monkeypatch.chdir(tmp_path)
    csv_logger = CSVLogger(root_dir="exp", name="logs")
    assert csv_logger.version == 2
|
|
|
|
|
|
def test_manual_versioning(tmp_path):
    """Check that a version passed explicitly is reported back unchanged."""
    exp_root = tmp_path / "exp"
    exp_root.mkdir()
    # pre-existing version folders inside the root directory
    for index in range(3):
        (exp_root / f"version_{index}").mkdir()

    csv_logger = CSVLogger(root_dir=exp_root, name="exp", version=1)
    assert csv_logger.version == 1
|
|
|
|
|
|
def test_named_version(tmp_path):
    """Check that a string version, e.g. '2020-02-05-162402', is accepted and used as the folder name."""
    exp_name = "exp"
    expected_version = "2020-02-05-162402"
    exp_dir = tmp_path / exp_name
    exp_dir.mkdir()

    csv_logger = CSVLogger(root_dir=tmp_path, name=exp_name, version=expected_version)
    csv_logger.log_metrics({"a": 1, "b": 2})
    csv_logger.save()

    assert csv_logger.version == expected_version
    # the experiment folder holds exactly one entry, named after the version
    assert os.listdir(exp_dir) == [expected_version]
    # saving left at least one entry inside that version folder
    assert os.listdir(exp_dir / expected_version)
|
|
|
|
|
|
@pytest.mark.parametrize("name", ["", None])
def test_no_name(tmp_path, name):
    """Check that the logger works when the experiment name is empty or ``None``."""
    csv_logger = CSVLogger(root_dir=tmp_path, name=name)
    csv_logger.log_metrics({"a": 1})
    csv_logger.save()

    # os.path.normpath removes a possible trailing separator before the comparison
    normalized_root = os.path.normpath(csv_logger._root_dir)
    assert normalized_root == str(tmp_path)

    # with no name, the version folder sits directly under the root directory
    version_dir = tmp_path / "version_0"
    assert os.listdir(version_dir)
|
|
|
|
|
|
@pytest.mark.parametrize("step_idx", [10, None])
def test_log_metrics(tmp_path, step_idx):
    """Check that one ``log_metrics`` call followed by ``save`` yields a two-line CSV naming every metric."""
    csv_logger = CSVLogger(tmp_path)
    metrics = {
        "float": 0.3,
        "int": 1,
        "FloatTensor": torch.tensor(0.1),
        "IntTensor": torch.tensor(1),
    }
    csv_logger.log_metrics(metrics, step_idx)
    csv_logger.save()

    metrics_path = os.path.join(csv_logger.log_dir, _ExperimentWriter.NAME_METRICS_FILE)
    with open(metrics_path) as csv_file:
        file_lines = csv_file.readlines()

    assert len(file_lines) == 2
    # every metric name has to show up in the first line of the file
    first_line = file_lines[0]
    for metric_name in metrics:
        assert metric_name in first_line
|
|
|
|
|
|
def test_log_hyperparams(tmp_path):
    """Check that ``log_hyperparams`` raises ``NotImplementedError``."""
    csv_logger = CSVLogger(tmp_path)
    empty_params = {}
    with pytest.raises(NotImplementedError):
        csv_logger.log_hyperparams(empty_params)
|
|
|
|
|
|
def test_flush_n_steps(tmp_path):
    """Check that ``save`` is first called on the second logged step when ``flush_logs_every_n_steps=2``."""
    csv_logger = CSVLogger(tmp_path, flush_logs_every_n_steps=2)
    # swap ``save`` for a mock so that its calls can be inspected
    save_mock = MagicMock()
    csv_logger.save = save_mock
    metrics = {
        "float": 0.3,
        "int": 1,
        "FloatTensor": torch.tensor(0.1),
        "IntTensor": torch.tensor(1),
    }

    csv_logger.log_metrics(metrics, step=0)
    save_mock.assert_not_called()

    csv_logger.log_metrics(metrics, step=1)
    save_mock.assert_called_once()
|
|
|
|
|
|
def test_metrics_reset_after_save(tmp_path):
    """Check that the experiment's buffered metrics are empty again after a flush."""
    csv_logger = CSVLogger(tmp_path, flush_logs_every_n_steps=2)
    sample = {"test": 1}

    csv_logger.log_metrics(sample, step=0)
    buffered = csv_logger.experiment.metrics
    assert buffered

    # the second step triggers the flush
    csv_logger.log_metrics(sample, step=1)
    remaining = csv_logger.experiment.metrics
    assert not remaining
|
|
|
|
|
|
def test_automatic_step_tracking(tmp_path):
    """Check that the logger counts steps by itself when ``step=None`` is passed."""
    flush_every = 3
    csv_logger = CSVLogger(tmp_path, flush_logs_every_n_steps=flush_every)
    # mocked ``save`` lets the test observe when a flush is requested
    save_mock = MagicMock()
    csv_logger.save = save_mock
    sample = {"test": 0.1}

    for expected_step in range(flush_every):
        csv_logger.log_metrics(sample, step=None)
        if expected_step + 1 < flush_every:
            save_mock.assert_not_called()
        else:
            # only the last of the three calls requests a save
            save_mock.assert_called_once()
        assert csv_logger.experiment.metrics[expected_step]["step"] == expected_step
|
|
|
|
|
|
def test_append_metrics_file(tmp_path):
    """Check that a second logger on the same name/version appends to the existing metrics file."""

    def make_logger():
        # both loggers share name and version, and flush after every step
        return CSVLogger(tmp_path, name="test", version=0, flush_logs_every_n_steps=1)

    # initial metrics
    logger = make_logger()
    for row in ({"a": 1, "b": 2}, {"a": 3, "b": 4}):
        logger.log_metrics(row)

    # a fresh logger must add to the file instead of starting it over
    logger = make_logger()
    logger.log_metrics({"a": 100, "b": 200})

    with open(logger.experiment.metrics_file_path) as metrics_file:
        file_lines = metrics_file.readlines()
    assert len(file_lines) == 4  # 1 header + 3 lines of metrics
|
|
|
|
|
|
def test_append_columns(tmp_path):
    """Check that the CSV header gains new metric keys and keeps keys that stop being logged."""
    csv_logger = CSVLogger(tmp_path, flush_logs_every_n_steps=1)
    all_columns = {"step", "a", "b", "c"}

    def read_header_columns():
        # first line of the metrics file, split into its column names
        with open(csv_logger.experiment.metrics_file_path) as metrics_file:
            first_line = metrics_file.readline().strip()
        return set(first_line.split(","))

    # initial metrics
    csv_logger.log_metrics({"a": 1, "b": 2})

    # new key appears
    csv_logger.log_metrics({"a": 1, "b": 2, "c": 3})
    assert read_header_columns() == all_columns

    # key disappears
    csv_logger.log_metrics({"a": 1, "c": 3})
    assert read_header_columns() == all_columns
|
|
|
|
|
|
def test_rewrite_with_new_header(tmp_path):
    """Check that ``_rewrite_with_new_header`` writes the new header and leaves the added column empty."""
    csv_path = tmp_path / "metrics.csv"

    # write a csv file manually
    with open(csv_path, "w") as csv_file:
        csv_file.writelines(["step,metric1,metric2\n", "0,1,22\n"])

    new_columns = ["step", "metric1", "metric2", "metric3"]
    experiment_writer = _ExperimentWriter(log_dir=str(tmp_path))
    experiment_writer._rewrite_with_new_header(new_columns)

    with open(csv_path) as csv_file:
        header_line = csv_file.readline()
        data_line = csv_file.readline()

    # the rewritten file should have the new columns
    assert header_line.strip().split(",") == new_columns
    # the existing row keeps its values and has an empty cell for the new column
    assert data_line.strip().split(",") == ["0", "1", "22", ""]
|