2020-10-13 11:18:07 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2021-06-17 07:08:22 +00:00
|
|
|
import pickle
|
|
|
|
from copy import deepcopy
|
|
|
|
|
|
|
|
import pytest
|
2020-10-08 02:54:32 +00:00
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.multiprocessing as mp
|
2021-03-16 14:55:31 +00:00
|
|
|
from torchmetrics import Metric
|
2020-10-08 02:54:32 +00:00
|
|
|
|
2021-02-08 10:52:02 +00:00
|
|
|
import tests.helpers.utils as tutils
|
2021-06-17 07:08:22 +00:00
|
|
|
from pytorch_lightning import Trainer
|
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
|
|
|
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
|
|
|
|
from tests.helpers import BoringModel
|
2021-03-02 09:36:01 +00:00
|
|
|
from tests.helpers.runif import RunIf
|
2020-10-08 02:54:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
class DummyMetric(Metric):
    """Minimal ``Metric`` holding a running sum of every value passed to it.

    Its single state ``x`` starts at ``torch.tensor(0)`` and is registered with
    ``dist_reduce_fx="sum"``, so synchronised values are summed across processes.
    """

    def __init__(self):
        super(DummyMetric, self).__init__()
        self.add_state(name="x", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        """Add ``x`` to the running sum (in place)."""
        self.x += x

    def compute(self):
        """Return the running sum accumulated so far."""
        return self.x
|
|
|
|
|
|
|
|
|
|
|
|
def _setup_ddp(rank, worldsize):
    """Join the default ``gloo`` process group as ``rank`` of ``worldsize`` workers.

    Only ``MASTER_ADDR`` is set here. ``MASTER_PORT`` presumably comes from the parent
    process environment (``test_result_reduce_ddp`` calls ``tutils.set_random_master_port()``
    before spawning) — confirm against that helper.
    """
    import os

    os.environ.update(MASTER_ADDR="localhost")

    # initialize the process group
    dist.init_process_group(backend="gloo", world_size=worldsize, rank=rank)
|
|
|
|
|
|
|
|
|
|
|
|
def _ddp_test_fn(rank, worldsize):
    """Worker body spawned by ``test_result_reduce_ddp`` (one process per GPU).

    Each process logs three ``DummyMetric`` objects into a ``ResultCollection`` on
    ``cuda:{rank}`` for 3 epochs of 5 steps. Step-level values must be the process-local
    ones. Epoch-level values must be the sum over all ``worldsize`` processes. The
    process group is always destroyed on exit, even when an assertion fails.
    """
    _setup_ddp(rank, worldsize)
    try:
        device = torch.device(f"cuda:{rank}")
        metric_a = DummyMetric().to(device)
        metric_b = DummyMetric().to(device)
        metric_c = DummyMetric().to(device)

        result = ResultCollection(True, device)

        for _ in range(3):
            cumulative_sum = 0
            for i in range(5):
                metric_a(i)
                metric_b(i)
                metric_c(i)

                cumulative_sum += i

                result.log("h", "a", metric_a, on_step=True, on_epoch=True)
                result.log("h", "b", metric_b, on_step=False, on_epoch=True)
                result.log("h", "c", metric_c, on_step=True, on_epoch=False)

                # step-level values are not reduced: every rank sees its own ``i``
                batch_log = result.metrics(True)[MetricSource.LOG]
                assert batch_log == {"a_step": i, "c": i}

            epoch_log = result.metrics(False)[MetricSource.LOG]
            result.reset()

            # assert metric state reset to default values
            assert metric_a.x == metric_a._defaults["x"], (metric_a.x, metric_a._defaults["x"])
            assert metric_b.x == metric_b._defaults["x"]
            assert metric_c.x == metric_c._defaults["x"]

            # epoch-level values are summed over all ranks (``dist_reduce_fx="sum"``)
            assert epoch_log == {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize}
    finally:
        # release the gloo process group so a failing worker does not leak it
        dist.destroy_process_group()
|
2020-10-08 02:54:32 +00:00
|
|
|
|
|
|
|
|
2021-06-08 20:20:17 +00:00
|
|
|
@RunIf(skip_windows=True, min_gpus=2)
def test_result_reduce_ddp():
    """Spawn two workers and check that result logging reduces metric values under DDP."""
    tutils.set_random_master_port()

    n_procs = 2
    mp.spawn(fn=_ddp_test_fn, args=(n_procs,), nprocs=n_procs)
|
2020-10-08 02:54:32 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_result_metric_integration():
    """Log three ``Metric`` objects with different ``on_step``/``on_epoch`` flags on CPU.

    Checks the step-level logs, the epoch-level logs, that ``reset`` restores every metric
    state to its default, and the ``str`` representation of the collection.
    """
    metric_a, metric_b, metric_c = DummyMetric(), DummyMetric(), DummyMetric()
    all_metrics = (metric_a, metric_b, metric_c)

    result = ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        running_total = 0
        for step in range(5):
            for m in all_metrics:
                m(step)
            running_total += step

            result.log("h", "a", metric_a, on_step=True, on_epoch=True)
            result.log("h", "b", metric_b, on_step=False, on_epoch=True)
            result.log("h", "c", metric_c, on_step=True, on_epoch=False)

            step_logs = result.metrics(True)[MetricSource.LOG]
            assert step_logs == {"a_step": step, "c": step}

        epoch_logs = result.metrics(False)[MetricSource.LOG]
        result.reset()

        # after ``reset`` every metric state must be back at its registered default
        for m in all_metrics:
            assert m.x == m._defaults["x"]

        assert epoch_logs == {"b": running_total, "a_epoch": running_total}

    expected_repr = (
        "ResultCollection(True, cpu, {"
        "'h.a': ResultMetric('a', value=DummyMetric()), "
        "'h.b': ResultMetric('b', value=DummyMetric()), "
        "'h.c': ResultMetric('c', value=DummyMetric())"
        "})"
    )
    assert str(result) == expected_repr
|
|
|
|
|
|
|
|
|
|
|
|
def test_result_collection_simple_loop():
    """Drive a ``ResultCollection`` through a fake two-epoch loop and check accumulated values.

    The local ``log`` helper resets a hook's stored results whenever logging switches to a
    different hook name. The exception is inside a batch loop past batch 0, where
    ``batch_idx`` is neither ``None`` nor ``0``.
    """
    result = ResultCollection(True, torch.device("cpu"))
    previous_fx = None
    batch_idx = None

    def log(fx, *args, **kwargs):
        nonlocal previous_fx
        switched_hook = previous_fx != fx
        # ``batch_idx`` is read from the enclosing scope; the loop below rebinds it
        if switched_hook and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        previous_fx = fx

    for fx in ("a0", "a1"):
        log(fx, "a", torch.tensor(0.0), on_step=True, on_epoch=True)

    for epoch in range(2):
        for fx in ("b0", "b1"):
            log(fx, "a", torch.tensor(1.0) + epoch, on_step=True, on_epoch=True)
        for batch_idx in range(2):
            for fx in ("c0", "c1", "c2"):
                log(fx, "a", torch.tensor(2.0) + epoch, on_step=True, on_epoch=True)
        batch_idx = None
        for fx in ("d0", "d1"):
            log(fx, "a", torch.tensor(3.0) + epoch, on_step=False, on_epoch=True)

        # (keys, expected accumulated value, expected cumulated batch size)
        expectations = (
            (("a0.a", "a1.a"), torch.tensor(0.0), torch.tensor(1.0)),
            (("b0.a", "b1.a"), torch.tensor(1.0) + epoch, torch.tensor(1.0)),
            (("c0.a", "c1.a", "c2.a"), torch.tensor(4.0) + epoch * 2, torch.tensor(2.0)),
            (("d0.a", "d1.a"), torch.tensor(3.0) + epoch, torch.tensor(1.0)),
        )
        for keys, expected_value, expected_size in expectations:
            for k in keys:
                assert result[k].value == expected_value, k
                assert result[k].cumulated_batch_size == expected_size, k
|
2021-06-17 07:08:22 +00:00
|
|
|
|
|
|
|
|
2021-06-25 19:16:11 +00:00
|
|
|
def my_sync_dist(x, *_ignored_args, **_ignored_kwargs):
    """Identity stand-in for a distributed sync function.

    Returns ``x`` unchanged and ignores any extra positional or keyword arguments.
    """
    synced = x
    return synced
|
|
|
|
|
|
|
|
|
|
|
|
def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure.

    After every step, the ``ResultCollection`` is deep-copied, serialized with
    ``state_dict`` and reloaded into a fresh collection. The reloaded collection must equal
    the copy. At the end of each epoch the collection must also survive ``pickle`` and
    ``torch.save``/``torch.load``, and ``reset`` must restore metric states to their defaults.
    """
    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for epoch in range(2):
        cumulative_sum = 0

        for i in range(3):
            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            # the metric object logged under "a" is swapped after the first step
            metric = metric_a if i < 1 else metric_d
            lightning_log("training_step", "a", metric, on_step=True, on_epoch=True, metric_attribute="metric")
            lightning_log("training_step", "b", metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
            lightning_log("training_step", "c", metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
            lightning_log("training_step", "a_1", a, on_step=True, on_epoch=True)
            lightning_log("training_step", "b_1", b, on_step=False, on_epoch=True)
            lightning_log("training_step", "c_1", {"1": c, "2": c}, on_step=True, on_epoch=False)

            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log["c_1"]) == {"1", "2"}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert "fn" not in state_dict["items"]["training_step.a"]["meta"]["_sync"]

            assert not new_result.result_metrics
            # 7 entries come from ``training_step`` (the "c_1" dict presumably adds one per key).
            # From the second epoch on, the "train_epoch_end.a" entry logged at the end of the
            # previous epoch is also present. Parenthesized on purpose: ``7 + epoch > 0`` is a
            # chained comparison whose second half is always true.
            assert len(result.result_metrics) == 7 + (epoch > 0)

            new_result.load_state_dict(
                state_dict, metrics={"metric": metric, "metric_b": metric_b, "metric_c": metric_c}
            )
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy["training_step.a"].meta.sync.fn == new_result["training_step.a"].meta.sync.fn

        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == epoch_log_copy

        lightning_log("train_epoch_end", "a", metric_a, on_step=False, on_epoch=True)
        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == {
            "a_1_epoch": 1,
            "a_epoch": cumulative_sum,
            "a": cumulative_sum,
            "b": cumulative_sum,
            "b_1": 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / "result")
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults["x"]
        assert metric_b.x == metric_b._defaults["x"]
        assert metric_c.x == metric_c._defaults["x"]

        batch_idx = None
|
|
|
|
|
|
|
|
|
2021-07-26 11:37:35 +00:00
|
|
|
@pytest.mark.parametrize("device", ("cpu", pytest.param("cuda", marks=RunIf(min_gpus=1))))
def test_lightning_module_logging_result_collection(tmpdir, device):
    """Check device placement and sync-fn handling of the trainer's ``ResultCollection``.

    The checks run inside ``on_save_checkpoint`` and cover three things:
    - the state dict keeps values on ``device`` and drops the sync function;
    - reloading keeps the trainer's reduce function;
    - a fresh collection loaded with ``map_location="cpu"`` falls back to ``_Sync.no_op``.

    Because every check lives in that hook, the test also asserts the hook actually ran.
    """

    class LoggingModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.metric = DummyMetric()
            # counts ``on_save_checkpoint`` calls so the test cannot pass vacuously
            self.save_checkpoint_calls = 0

        def validation_step(self, batch, batch_idx):
            v = self.metric(batch_idx)
            self.log_dict({"v": v, "m": self.metric})
            return super().validation_step(batch, batch_idx)

        def on_save_checkpoint(self, checkpoint) -> None:
            self.save_checkpoint_calls += 1
            results = self.trainer._results
            # simplify logic
            state_dict = results.state_dict(drop_value=False)

            # check device
            assert results["validation_step.v"].value.device.type == device
            assert state_dict["items"]["validation_step.v"]["value"].device.type == device

            # sync fn should be kept
            assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # sync fn dropped from the state dict
            assert "fn" not in state_dict["items"]["validation_step.v"]["meta"]["_sync"]
            results.load_state_dict(state_dict)

            # check device after loading
            assert results["validation_step.v"].value.device.type == device

            # sync fn was preserved in the original result
            assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # default sync fn
            new_results = ResultCollection(False, device)
            new_results.load_state_dict(state_dict, map_location="cpu")
            assert new_results["validation_step.v"].meta.sync.fn == _Sync.no_op

            # check map location
            assert new_results["validation_step.v"].value.device.type == "cpu"

    model = LoggingModel()
    ckpt = ModelCheckpoint(dirpath=tmpdir, save_on_train_epoch_end=False)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[ckpt],
        gpus=1 if device == "cuda" else 0,
    )
    trainer.fit(model)

    # all assertions above live in ``on_save_checkpoint``; make sure it was actually exercised
    assert model.save_checkpoint_calls > 0
|