lightning/tests/core/test_metric_result_integrat...

560 lines
21 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
from contextlib import suppress
from copy import deepcopy
from unittest import mock
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import ModuleDict, ModuleList
from torchmetrics import Metric, MetricCollection
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled, _TORCH_GREATER_EQUAL_1_7
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
class DummyMetric(Metric):
    """Minimal summing metric: every update is accumulated into the single ``x`` state."""

    def __init__(self):
        super().__init__()
        # scalar state; combined with "sum" when states are reduced across processes
        self.add_state("x", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        # accumulate into the registered state
        self.x += x

    def compute(self):
        # the running total is the reported value
        total = self.x
        return total
def _setup_ddp(rank, worldsize):
    """Join a ``gloo`` process group on localhost as process ``rank`` of ``worldsize``.

    NOTE(review): ``MASTER_PORT`` is presumably inherited from the parent process, which calls
    ``tutils.set_random_master_port()`` before spawning.
    """
    os.environ.update(MASTER_ADDR="localhost")
    dist.init_process_group("gloo", world_size=worldsize, rank=rank)
def _ddp_test_fn(rank, worldsize):
    """Per-process body of ``test_result_reduce_ddp``.

    Three ``DummyMetric`` instances are logged with different ``on_step``/``on_epoch`` settings for three
    epochs of five batches. Step logs are checked on every batch. After each epoch, the test checks that the
    epoch values are the sum over all ``worldsize`` processes and that every metric state was reset.
    """
    _setup_ddp(rank, worldsize)
    torch.tensor([1.0])

    device = f"cuda:{rank}"
    metric_a, metric_b, metric_c = (DummyMetric().to(device) for _ in range(3))
    all_metrics = (metric_a, metric_b, metric_c)
    # (name, metric, on_step, on_epoch)
    log_specs = (
        ("a", metric_a, True, True),
        ("b", metric_b, False, True),
        ("c", metric_c, True, False),
    )

    result = ResultCollection(True, torch.device(device))

    for _epoch in range(3):
        cumulative_sum = 0
        for i in range(5):
            for metric in all_metrics:
                metric(i)
            cumulative_sum += i

            for name, metric, on_step, on_epoch in log_specs:
                result.log("h", name, metric, on_step=on_step, on_epoch=on_epoch)

            batch_log = result.metrics(True)[MetricSource.LOG]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)[MetricSource.LOG]
        result.reset()

        # every metric state must be back at its default value after the reset
        assert metric_a.x == metric_a._defaults["x"], (metric_a.x, metric_a._defaults["x"])
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        # epoch values are reduced (summed) over all processes
        expected_total = cumulative_sum * worldsize
        assert epoch_log == {"b": expected_total, "a_epoch": expected_total}
@RunIf(skip_windows=True, min_gpus=2)
def test_result_reduce_ddp():
    """Spawn two processes running ``_ddp_test_fn`` to check ``ResultCollection`` logging under DDP."""
    tutils.set_random_master_port()
    world_size = 2
    mp.spawn(_ddp_test_fn, nprocs=world_size, args=(world_size,))
def test_result_metric_integration():
    """Log three ``DummyMetric`` instances on CPU; check step/epoch logs, resets and ``str``/``repr``."""
    metric_a, metric_b, metric_c = (DummyMetric() for _ in range(3))
    all_metrics = (metric_a, metric_b, metric_c)
    # (name, metric, on_step, on_epoch)
    log_specs = (
        ("a", metric_a, True, True),
        ("b", metric_b, False, True),
        ("c", metric_c, True, False),
    )

    result = ResultCollection(True, torch.device("cpu"))

    for _epoch in range(3):
        cumulative_sum = 0
        for i in range(5):
            for metric in all_metrics:
                metric(i)
            cumulative_sum += i

            for name, metric, on_step, on_epoch in log_specs:
                result.log("h", name, metric, on_step=on_step, on_epoch=on_epoch)

            batch_log = result.metrics(True)[MetricSource.LOG]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)[MetricSource.LOG]
        result.reset()

        # every metric state must be back at its default value after the reset
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}

    result.minimize = torch.tensor(1.0)
    result.extra = {}

    expected_str = (
        "ResultCollection("
        "minimize=1.0, "
        "{"
        "'h.a': ResultMetric('a', value=DummyMetric()), "
        "'h.b': ResultMetric('b', value=DummyMetric()), "
        "'h.c': ResultMetric('c', value=DummyMetric())"
        "})"
    )
    assert str(result) == expected_str

    expected_repr = (
        "{"
        "True, "
        "device(type='cpu'), "
        "minimize=tensor(1.), "
        "{'h.a': ResultMetric('a', value=DummyMetric()), "
        "'h.b': ResultMetric('b', value=DummyMetric()), "
        "'h.c': ResultMetric('c', value=DummyMetric()), "
        "'_extra': {}}"
        "}"
    )
    assert repr(result) == expected_repr
def test_result_collection_simple_loop():
    """Simulate a small loop of hooks and check the accumulated values and batch sizes per hook.

    ``lightning_log`` resets the values of hook ``fx`` when logging switches to a different hook while
    ``batch_idx`` is ``None`` or ``0``. Hooks logged again on later batches of the same epoch therefore
    keep accumulating.
    """
    result = ResultCollection(True, torch.device("cpu"))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        switched_hook = current_fx_name != fx
        if switched_hook and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        current_fx_name = fx

    for hook in ("a0", "a1"):
        lightning_log(hook, "a", torch.tensor(0.0), on_step=True, on_epoch=True)

    for epoch in range(2):
        for hook in ("b0", "b1"):
            lightning_log(hook, "a", torch.tensor(1.0) + epoch, on_step=True, on_epoch=True)
        # ``batch_idx`` is read by ``lightning_log`` through its closure
        for batch_idx in range(2):
            for hook in ("c0", "c1", "c2"):
                lightning_log(hook, "a", torch.tensor(2.0) + epoch, on_step=True, on_epoch=True)
        batch_idx = None
        for hook in ("d0", "d1"):
            lightning_log(hook, "a", torch.tensor(3.0) + epoch, on_step=False, on_epoch=True)

        # (keys, expected accumulated value, expected cumulated batch size)
        expectations = (
            (("a0.a", "a1.a"), torch.tensor(0.0), torch.tensor(1.0)),
            (("b0.a", "b1.a"), torch.tensor(1.0) + epoch, torch.tensor(1.0)),
            (("c0.a", "c1.a", "c2.a"), torch.tensor(4.0) + epoch * 2, torch.tensor(2.0)),
            (("d0.a", "d1.a"), torch.tensor(3.0) + epoch, torch.tensor(1.0)),
        )
        for keys, expected_value, expected_batch_size in expectations:
            for k in keys:
                assert result[k].value == expected_value, k
                assert result[k].cumulated_batch_size == expected_batch_size, k
def my_sync_dist(x, *_, **__):
    """Stand-in ``sync_dist_fn``: ignores every extra argument and returns ``x`` unchanged."""
    synced = x
    return synced
def test_result_collection_restoration(tmpdir):
    """
    This test makes sure metrics are properly reloaded on failure.

    For two epochs of three batches, a mix of ``Metric`` objects and plain tensors is logged. After every
    batch, the collection is round-tripped through ``state_dict``/``load_state_dict`` and compared with a
    deep copy. After every epoch, the test checks the epoch-level values, pickling,
    ``torch.save``/``torch.load`` and ``reset``.
    """
    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # reset the hook's values when logging switches to a new hook (outside of mid-epoch batches)
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for epoch in range(2):
        cumulative_sum = 0

        for i in range(3):
            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            # the metric logged as "a" switches from metric_a to metric_d after the first batch
            metric = metric_a if i < 1 else metric_d
            lightning_log("training_step", "a", metric, on_step=True, on_epoch=True, metric_attribute="metric")
            lightning_log("training_step", "b", metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
            lightning_log("training_step", "c", metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
            lightning_log("training_step", "a_1", a, on_step=True, on_epoch=True)
            lightning_log("training_step", "b_1", b, on_step=False, on_epoch=True)
            lightning_log("training_step", "c_1", {"1": c, "2": c}, on_step=True, on_epoch=False)

            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log["c_1"]) == {"1", "2"}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert "fn" not in state_dict["items"]["training_step.a"]["meta"]["_sync"]

            assert not new_result.result_metrics
            # 7 entries come from the "training_step" logs (the ``c_1`` dict presumably contributes two).
            # One more entry, "train_epoch_end.a", exists once the first epoch has ended.
            # The previous `== 7 + epoch > 0` was a chained comparison whose `> 0` part was always true.
            assert len(result.result_metrics) == 7 + (epoch > 0)

            new_result.load_state_dict(
                state_dict, metrics={"metric": metric, "metric_b": metric_b, "metric_c": metric_c}
            )
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy["training_step.a"].meta.sync.fn == new_result["training_step.a"].meta.sync.fn

        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == epoch_log_copy

        lightning_log("train_epoch_end", "a", metric_a, on_step=False, on_epoch=True)
        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == {
            "a_1_epoch": 1,
            "a_epoch": cumulative_sum,
            "a": cumulative_sum,
            "b": cumulative_sum,
            "b_1": 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / "result")
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        batch_idx = None
@pytest.mark.parametrize("device", ("cpu", pytest.param("cuda", marks=RunIf(min_gpus=1))))
def test_lightning_module_logging_result_collection(tmpdir, device):
    """Round-trip the trainer's ``ResultCollection`` through ``state_dict`` inside ``on_save_checkpoint``.

    Checks value devices, sync-fn handling and ``map_location``.
    """

    class LoggingModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.metric = DummyMetric()

        def validation_step(self, batch, batch_idx):
            logged = {"v": self.metric(batch_idx), "m": self.metric}
            self.log_dict(logged)
            return super().validation_step(batch, batch_idx)

        def on_save_checkpoint(self, checkpoint) -> None:
            key = "validation_step.v"
            results = self.trainer._results
            # keep the values in the state dict so their device can be inspected
            state_dict = results.state_dict(drop_value=False)
            saved_item = state_dict["items"][key]

            # values live on the requested device, in the collection and in its state dict
            assert results[key].value.device.type == device
            assert saved_item["value"].device.type == device

            # the live collection keeps the training type plugin's reduce as its sync fn ...
            reduce_fn = self.trainer.training_type_plugin.reduce
            assert results[key].meta.sync.fn == reduce_fn
            # ... but the fn itself is not serialized
            assert "fn" not in saved_item["meta"]["_sync"]

            results.load_state_dict(state_dict)
            # reloading in place keeps both the device and the original sync fn
            assert results[key].value.device.type == device
            assert results[key].meta.sync.fn == reduce_fn

            # a fresh collection falls back to the no-op sync fn and honours map_location
            new_results = ResultCollection(False, device)
            new_results.load_state_dict(state_dict, map_location="cpu")
            assert new_results[key].meta.sync.fn == _Sync.no_op
            assert new_results[key].value.device.type == "cpu"

    model = LoggingModel()
    ckpt = ModelCheckpoint(dirpath=tmpdir, save_on_train_epoch_end=False)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[ckpt],
        gpus=int(device == "cuda"),
    )
    trainer.fit(model)
def test_result_collection_extra_reference():
    """The ``extra`` attribute must be the very object stored under the ``"_extra"`` key."""
    collection = ResultCollection(True)
    # access ``extra`` first, then the key, exactly as callers would
    assert collection.extra is collection["_extra"]
class DummyMeanMetric(Metric):
    """Running mean metric: tracks ``sum`` and ``count`` states and reports ``sum // count``."""

    def __init__(self):
        super().__init__()
        # both states start at 0 and are combined with ``torch.sum`` when reduced across processes
        for state_name in ("sum", "count"):
            self.add_state(state_name, torch.tensor(0), dist_reduce_fx=torch.sum)

    def update(self, increment):
        self.sum += increment
        self.count += 1

    def compute(self):
        # integer division of the accumulated total by the number of updates
        return self.sum // self.count

    def __repr__(self) -> str:
        return f"{type(self).__name__}(sum={self.sum}, count={self.count})"
def result_collection_reload(**kwargs):
    """
    This test is going to validate ResultCollection is properly being reloaded
    and final accumulation with Fault Tolerant Training is correct.

    A first ``Trainer.fit`` raises ``CustomException`` at ``batch_idx == 3``. A second ``fit`` resumes from
    ``<default_root_dir>/.pl_auto_save.ckpt``. Its final epoch-level values must account for the batches
    seen by the interrupted run.

    Args:
        **kwargs: extra ``Trainer`` arguments. ``default_root_dir`` is required. ``gpus`` (default ``1``)
            is also used as the expected number of processes.
    """
    if not _fault_tolerant_enabled():
        pytest.skip("Fault tolerant not available")

    num_processes = kwargs.get("gpus", 1)

    class CustomException(Exception):
        pass

    class ExtendedBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.breaking_batch_idx = 3
            self.has_validated_sum = False
            self.dummy_metric = DummyMeanMetric()

        @property
        def results(self):
            return self.trainer.fit_loop._results

        def training_step(self, batch, batch_idx):
            # In the training step, we will accumulate metrics using batch_idx from 0 to 4
            # Without failure, we would expect to get `total=10 * world_size` and `num_batches=5 * world_size`
            # Therefore, compute on `epoch_end` should provide 2 as `10 / 5`.
            # However, below we will simulate a failure on `batch_idx=3`.

            if self.trainer.fit_loop.restarting:
                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True)

                value = self.results["training_step.tracking_metric"].value
                value_2 = self.results["training_step.tracking"].value

                # On failure, the Metric states are being accumulated on rank 0 and zeroed-out on other ranks.
                # The shift indicates we failed while the state was `shift=sign(is_global_zero > 0) * [0..3]`
                shift = 0
                if num_processes == 2:
                    shift = 3 if self.trainer.is_global_zero else -3
                expected = sum(range(batch_idx + 1)) + shift
                assert expected == value == value_2
            else:
                if batch_idx == self.breaking_batch_idx:
                    # simulate failure mid epoch
                    raise CustomException

                self.log("tracking", batch_idx, on_step=True, on_epoch=True)
                self.log("tracking_2", batch_idx, on_step=True, on_epoch=True, sync_dist=True)

                self.dummy_metric(batch_idx)
                self.log("tracking_metric", self.dummy_metric, on_step=True, on_epoch=True)

                value = self.results["training_step.tracking"].value
                assert value == sum(range(batch_idx + 1))

                # compare the accumulated value, not the ``ResultMetric`` object itself
                value = self.results["training_step.tracking_2"].value
                assert value == sum(range(batch_idx + 1))

            return super().training_step(batch, batch_idx)

        def on_epoch_end(self) -> None:
            if self.trainer.fit_loop.restarting:
                total = sum(range(5)) * num_processes
                metrics = self.results.metrics(on_step=False)
                assert self.results["training_step.tracking"].value == total
                assert metrics[MetricSource.CALLBACK]["tracking"] == self.dummy_metric.compute() == 2
                assert self.results["training_step.tracking_2"].value == total
                assert metrics[MetricSource.CALLBACK]["tracking_2"] == self.dummy_metric.compute() == 2
                self.has_validated_sum = True

    model = ExtendedBoringModel()
    trainer_kwargs = {"max_epochs": 1, "limit_train_batches": 5, "limit_val_batches": 0}
    trainer_kwargs.update(kwargs)
    trainer = Trainer(**trainer_kwargs)

    # first run: interrupted mid epoch by the simulated failure
    with suppress(CustomException):
        trainer.fit(model)
    assert not model.has_validated_sum

    # every process must resume from the same directory, so rank 0's path is broadcast under multi-process
    tmpdir = (
        trainer.training_type_plugin.broadcast(trainer_kwargs["default_root_dir"], 0)
        if num_processes >= 2
        else trainer_kwargs["default_root_dir"]
    )
    ckpt_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")

    # second run: resume and validate the final accumulation
    trainer_kwargs["resume_from_checkpoint"] = ckpt_path
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    assert model.has_validated_sum
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7")
def test_result_collection_reload(tmpdir):
    """Fault-tolerant reload scenario with no accelerator or ``gpus`` override."""
    trainer_overrides = {"default_root_dir": tmpdir}
    result_collection_reload(**trainer_overrides)
@RunIf(min_gpus=1)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7")
def test_result_collection_reload_1_gpu_ddp(tmpdir):
    """Fault-tolerant reload scenario with the ``ddp`` accelerator on a single GPU."""
    trainer_overrides = {"default_root_dir": tmpdir, "accelerator": "ddp", "gpus": 1}
    result_collection_reload(**trainer_overrides)
@RunIf(min_gpus=2, special=True)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.skipif(not _TORCH_GREATER_EQUAL_1_7, reason="Requires at least PyTorch 1.7")
def test_result_collection_reload_2_gpus(tmpdir):
    """Fault-tolerant reload scenario with the ``ddp`` accelerator across two GPUs."""
    trainer_overrides = {"default_root_dir": tmpdir, "accelerator": "ddp", "gpus": 2}
    result_collection_reload(**trainer_overrides)
def test_metric_collections(tmpdir):
    """Check that ``metric_attribute`` resolves to the dotted path of metrics nested in lists, dicts and
    ``MetricCollection`` containers."""

    class TestModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.metrics_list = ModuleList([DummyMetric() for _ in range(2)])
            self.metrics_dict = ModuleDict({"a": DummyMetric(), "b": DummyMetric()})
            self.metrics_collection_dict = MetricCollection({"a": DummyMetric(), "b": DummyMetric()})
            self.metrics_collection_dict_nested = ModuleDict(
                {"a": ModuleList([ModuleDict({"b": DummyMetric()}), DummyMetric()])}
            )

        def training_step(self, batch, batch_idx):
            loss = super().training_step(batch, batch_idx)

            nested = self.metrics_collection_dict_nested["a"]
            # logged under the names "a" to "h", in this order
            metrics = [
                self.metrics_list[0],
                self.metrics_list[1],
                self.metrics_dict["a"],
                self.metrics_dict["b"],
                self.metrics_collection_dict["a"],
                self.metrics_collection_dict["b"],
                nested[0]["b"],
                nested[1],
            ]
            for metric in metrics:
                metric(batch_idx)
            for name, metric in zip("abcdefgh", metrics):
                self.log(name, metric)
            return loss

        def on_train_epoch_end(self) -> None:
            results = self.trainer.fit_loop.epoch_loop._results
            expected_attributes = {
                "a": "metrics_list.0",
                "b": "metrics_list.1",
                "c": "metrics_dict.a",
                "d": "metrics_dict.b",
                "e": "metrics_collection_dict.a",
                "f": "metrics_collection_dict.b",
                "g": "metrics_collection_dict_nested.a.0.b",
                "h": "metrics_collection_dict_nested.a.1",
            }
            for name, attribute in expected_attributes.items():
                assert results[f"training_step.{name}"].meta.metric_attribute == attribute

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=2, limit_val_batches=0)
    trainer.fit(model)