2020-10-13 11:18:07 +00:00
# Copyright The PyTorch Lightning team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
2021-06-17 07:08:22 +00:00
import pickle
from copy import deepcopy
import pytest
2020-10-08 02:54:32 +00:00
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
2021-03-16 14:55:31 +00:00
from torchmetrics import Metric
2020-10-08 02:54:32 +00:00
2021-02-08 10:52:02 +00:00
import tests.helpers.utils as tutils
2021-06-17 07:08:22 +00:00
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
from tests.helpers import BoringModel
2021-03-02 09:36:01 +00:00
from tests.helpers.runif import RunIf
2020-10-08 02:54:32 +00:00
class DummyMetric(Metric):
    """Minimal metric used by the tests: keeps a running sum of every value passed to ``update``.

    The single state ``x`` starts at ``torch.tensor(0)`` and is registered with
    ``dist_reduce_fx="sum"``.
    """

    def __init__(self):
        # The base class has to be initialised before any state is registered on it.
        super().__init__()
        self.add_state("x", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        """Add ``x`` to the accumulated state."""
        self.x += x

    def compute(self):
        """Return the accumulated sum."""
        return self.x
def _setup_ddp(rank, worldsize):
    """Initialise a ``gloo`` process group for the calling process.

    Args:
        rank: rank of the calling process within the group.
        worldsize: total number of processes in the group.
    """
    import os

    os.environ["MASTER_ADDR"] = "localhost"
    # The environment-variable rendezvous also needs a port. Keep whatever the
    # caller exported and only fall back to a fixed port when none is set, so
    # that every process of the group agrees on the same value.
    os.environ.setdefault("MASTER_PORT", "29500")

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)
def _ddp_test_fn(rank, worldsize):
    """Body executed by each spawned process of the DDP logging test.

    Three ``DummyMetric`` instances are moved to ``cuda:{rank}``, updated with
    ``0..4`` and logged into a ``ResultCollection``. Step-level logs must show
    the value of the current batch, epoch-level logs must show the sum over the
    epoch multiplied by ``worldsize``.

    Args:
        rank: index of the current process, also used as the CUDA device index.
        worldsize: number of processes taking part in the test.
    """
    _setup_ddp(rank, worldsize)

    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    metric_a = metric_a.to(f"cuda:{rank}")
    metric_b = metric_b.to(f"cuda:{rank}")
    metric_c = metric_c.to(f"cuda:{rank}")

    result = ResultCollection(True, torch.device(f"cuda:{rank}"))

    for _ in range(3):
        cumulative_sum = 0

        for i in range(5):
            # The metrics have to be updated before they are logged, otherwise
            # the per-step values asserted below can never equal ``i``.
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log('h', 'a', metric_a, on_step=True, on_epoch=True)
            result.log('h', 'b', metric_b, on_step=False, on_epoch=True)
            result.log('h', 'c', metric_c, on_step=True, on_epoch=False)

            batch_log = result.metrics(True)[MetricSource.LOG]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)[MetricSource.LOG]
        # NOTE(review): assumes ``reset()`` without arguments resets the logged
        # metrics as well - confirm against ``ResultCollection.reset``.
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x'])
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        assert epoch_log == {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize}
2020-10-08 02:54:32 +00:00
2021-06-08 20:20:17 +00:00
@RunIf(skip_windows=True, min_gpus=2)
def test_result_reduce_ddp():
    """Make sure result logging works with DDP"""
    # One worker per process; each receives its own index followed by the
    # total number of workers.
    worldsize = 2
    spawn_args = (worldsize, )
    mp.spawn(_ddp_test_fn, args=spawn_args, nprocs=worldsize)
2020-10-08 02:54:32 +00:00
def test_result_metric_integration():
    """Check step/epoch logging of ``DummyMetric`` objects through a CPU ``ResultCollection``.

    For three epochs the metrics are updated with ``0..4`` and logged under the
    hook name ``'h'``. Step logs must expose the current batch value, epoch logs
    the sum over the epoch, and the collection's string form must list the three
    logged entries.
    """
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    result = ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        cumulative_sum = 0

        for i in range(5):
            # Update the metrics before logging them; without this the
            # per-step values asserted below can never equal ``i``.
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log('h', 'a', metric_a, on_step=True, on_epoch=True)
            result.log('h', 'b', metric_b, on_step=False, on_epoch=True)
            result.log('h', 'c', metric_c, on_step=True, on_epoch=False)

            batch_log = result.metrics(True)[MetricSource.LOG]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)[MetricSource.LOG]
        # NOTE(review): assumes ``reset()`` without arguments resets the logged
        # metrics as well - confirm against ``ResultCollection.reset``.
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}

    assert str(result) == (
        "ResultCollection(True, cpu, {"
        "'h.a': ResultMetric(value=DummyMetric()), "
        "'h.b': ResultMetric(value=DummyMetric()), "
        "'h.c': ResultMetric(value=DummyMetric())"
        "})"
    )
def test_result_collection_simple_loop():
    """Log plain tensors under several hook names over two epochs and check what is stored.

    After every epoch each ``'<hook>.a'`` entry is checked for its ``value`` and
    its ``cumulated_batch_size``.
    """
    result = ResultCollection(True, torch.device("cpu"))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # Reset the entries of ``fx`` when the hook name changes and we are
        # either outside of a batch loop or on its first batch.
        nonlocal current_fx_name
        hook_changed = current_fx_name != fx
        if hook_changed and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        current_fx_name = fx

    for hook in ('a0', 'a1'):
        lightning_log(hook, 'a', torch.tensor(0.), on_step=True, on_epoch=True)

    for epoch in range(2):
        for hook in ('b0', 'b1'):
            lightning_log(hook, 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True)

        # ``batch_idx`` is read by ``lightning_log`` through the closure, so the
        # loop variable must keep this exact name.
        for batch_idx in range(2):
            for hook in ('c0', 'c1', 'c2'):
                lightning_log(hook, 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
        batch_idx = None

        for hook in ('d0', 'd1'):
            lightning_log(hook, 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True)

        # (keys, expected value, expected cumulated batch size)
        expectations = (
            (('a0.a', 'a1.a'), torch.tensor(0.), torch.tensor(1.)),
            (('b0.a', 'b1.a'), torch.tensor(1.) + epoch, torch.tensor(1.)),
            (('c0.a', 'c1.a', 'c2.a'), torch.tensor(4.) + epoch * 2, torch.tensor(2.)),
            (('d0.a', 'd1.a'), torch.tensor(3.) + epoch, torch.tensor(1.)),
        )
        for keys, expected_value, expected_batch_size in expectations:
            for k in keys:
                assert result[k].value == expected_value, k
                assert result[k].cumulated_batch_size == expected_batch_size, k
2021-06-17 07:08:22 +00:00
def my_sync_dist(x, *_, **__):
    """Identity stand-in for a distributed sync function.

    Returns ``x`` unchanged. Any further positional or keyword arguments are
    accepted and ignored, so the function can be called the same way as a sync
    function that takes extra options.
    """
    return x
def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure.

    After every logged batch the collection is deep-copied and also rebuilt from
    its ``state_dict``; both must compare equal. At the end of each epoch the
    collection must survive a pickle and a ``torch.save``/``torch.load`` round trip.
    """
    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        # Reset the entries of ``fx`` when the hook name changes and we are
        # either outside of a batch loop or on its first batch.
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for _ in range(2):
        cumulative_sum = 0

        for i in range(3):
            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)

            cumulative_sum += i

            metric = metric_a if i < 1 else metric_d
            lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True)
            lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True)
            lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False)
            lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True)
            lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True)
            lightning_log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False)

            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log['c_1']) == {'1', '2'}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert 'fn' not in state_dict['items']['training_step.a']['meta']['_sync']

            # The fresh collection has to be populated from the state dict
            # before it can be compared with the copy or indexed by key.
            new_result.load_state_dict(state_dict)
            # should match
            assert result_copy == new_result
            # the copy and the restored collection must agree on the sync fn
            assert result_copy['training_step.a'].meta.sync.fn == new_result['training_step.a'].meta.sync.fn

        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == epoch_log_copy

        lightning_log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True)
        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == {
            'a_1_epoch': 1,
            'a_epoch': cumulative_sum,
            'a': cumulative_sum,
            'b': cumulative_sum,
            'b_1': 1,
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / 'result')
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        # NOTE(review): assumes ``reset()`` without arguments resets the logged
        # metrics as well - confirm against ``ResultCollection.reset``.
        result.reset()
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        batch_idx = None
@pytest.mark.parametrize('device', ('cpu', pytest.param('cuda', marks=RunIf(min_gpus=1))))
def test_lightning_module_logging_result_collection(tmpdir, device):
class LoggingModel(BoringModel):
def __init__(self):
self.metric = DummyMetric()
def validation_step(self, batch, batch_idx):
v = self.metric(batch_idx)
self.log_dict({"v": v, "m": self.metric})
return super().validation_step(batch, batch_idx)
def on_save_checkpoint(self, checkpoint) -> None:
results = self.trainer._results
state_dict = results.state_dict()
# check device
assert results['validation_step.v'].value.device.type == device
assert state_dict['items']['validation_step.v']['value'].device.type == device
# sync fn should be kept
assert results['validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce
# sync fn dropped from the state dict
assert 'fn' not in state_dict['items']['validation_step.v']['meta']['_sync']
# check device after loading
assert results['validation_step.v'].value.device.type == device
# sync fn was preserved in the original result
assert results['validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce
# default sync fn
new_results = ResultCollection(False, device)
new_results.load_state_dict(state_dict, map_location='cpu')
assert new_results['validation_step.v'].meta.sync.fn == _Sync.no_op
# check map location
assert new_results['validation_step.v'].value.device.type == 'cpu'
model = LoggingModel()
ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
trainer = Trainer(
gpus=1 if device == 'cuda' else 0,