lightning/tests/core/test_metric_result_integrat...

331 lines
12 KiB
Python

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchmetrics import Metric
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
class DummyMetric(Metric):
    """Minimal metric that accumulates every value passed to ``update``.

    It holds a single state ``x`` whose default is ``torch.tensor(0)`` and
    which is registered with ``dist_reduce_fx="sum"``.
    """

    def __init__(self):
        super().__init__()
        self.add_state(name="x", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        """Add ``x`` to the accumulated state in place."""
        self.x += x

    def compute(self):
        """Return the accumulated state ``x``."""
        return self.x
def _setup_ddp(rank, worldsize):
    """Initialise the ``gloo`` process group for this process.

    ``MASTER_ADDR`` is set to ``localhost`` first; no other environment
    variable is set here.

    Args:
        rank: rank of the calling process, forwarded to ``init_process_group``.
        worldsize: forwarded to ``init_process_group`` as ``world_size``.
    """
    from os import environ

    environ["MASTER_ADDR"] = "localhost"

    # initialize the process group
    dist.init_process_group(backend="gloo", rank=rank, world_size=worldsize)
def _ddp_test_fn(rank, worldsize):
    """Per-process body of the DDP ``ResultCollection`` test.

    Joins the process group, places three ``DummyMetric`` instances on
    ``cuda:{rank}``, then runs three epochs of five steps. Each step logs the
    metrics under the hook name ``'h'`` and checks the step-level log; each
    epoch checks the epoch-level log and that ``result.reset()`` restored the
    metric states to their defaults.

    Args:
        rank: index of this process; also selects the CUDA device.
        worldsize: total number of processes in the group.
    """
    _setup_ddp(rank, worldsize)

    device = torch.device(f"cuda:{rank}")
    metric_a = DummyMetric().to(device)
    metric_b = DummyMetric().to(device)
    metric_c = DummyMetric().to(device)

    result = ResultCollection(True, device)

    for _ in range(3):
        cumulative_sum = 0
        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log('h', 'a', metric_a, on_step=True, on_epoch=True)
            result.log('h', 'b', metric_b, on_step=False, on_epoch=True)
            result.log('h', 'c', metric_c, on_step=True, on_epoch=False)

            # only the entries logged with ``on_step=True`` show up here
            batch_log = result.metrics(True)[MetricSource.LOG]
            assert batch_log == {"a_step": i, "c": i}

        # read the epoch values before resetting the collection
        epoch_log = result.metrics(False)[MetricSource.LOG]
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x'])
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        # every process feeds the same values, so the expected epoch value is
        # the local sum multiplied by the number of processes
        expected_total = cumulative_sum * worldsize
        assert epoch_log == {"b": expected_total, "a_epoch": expected_total}
@RunIf(skip_windows=True, min_gpus=2)
def test_result_reduce_ddp():
    """Check result logging under DDP by spawning two ``_ddp_test_fn`` processes."""
    tutils.set_random_master_port()

    num_processes = 2
    spawn_args = (num_processes, )
    mp.spawn(_ddp_test_fn, args=spawn_args, nprocs=num_processes)
def test_result_metric_integration():
    """Log three ``DummyMetric`` instances on CPU and check what comes back.

    For three epochs of five steps, verifies the step-level log, the
    epoch-level log, that ``result.reset()`` restores the metric states to
    their defaults, and finally the ``str`` representation of the collection.
    """
    metric_a, metric_b, metric_c = DummyMetric(), DummyMetric(), DummyMetric()
    all_metrics = (metric_a, metric_b, metric_c)

    result = ResultCollection(True, torch.device("cpu"))

    for _epoch in range(3):
        running_total = 0
        for step in range(5):
            for metric in all_metrics:
                metric(step)

            running_total += step

            result.log('h', 'a', metric_a, on_step=True, on_epoch=True)
            result.log('h', 'b', metric_b, on_step=False, on_epoch=True)
            result.log('h', 'c', metric_c, on_step=True, on_epoch=False)

            step_logs = result.metrics(True)[MetricSource.LOG]
            assert step_logs == {"a_step": step, "c": step}

        # read the epoch values before resetting the collection
        epoch_logs = result.metrics(False)[MetricSource.LOG]
        result.reset()

        # assert metric state reset to default values
        for metric in all_metrics:
            assert metric.x == metric._defaults['x']

        assert epoch_logs == {"b": running_total, "a_epoch": running_total}

    expected_repr = (
        "ResultCollection(True, cpu, {"
        "'h.a': ResultMetric(value=DummyMetric()), "
        "'h.b': ResultMetric(value=DummyMetric()), "
        "'h.c': ResultMetric(value=DummyMetric())"
        "})"
    )
    assert str(result) == expected_repr
def test_result_collection_simple_loop():
    """Log plain tensors from a sequence of hook names and check accumulation.

    ``lightning_log`` resets the entries of a hook when the hook name changes
    and ``batch_idx`` is ``None`` or ``0``, then logs. After each epoch the
    stored ``value`` and ``cumulated_batch_size`` of every key are checked.
    """
    result = ResultCollection(True, torch.device("cpu"))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        hook_changed = current_fx_name != fx
        if hook_changed and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        current_fx_name = fx

    for hook in ('a0', 'a1'):
        lightning_log(hook, 'a', torch.tensor(0.), on_step=True, on_epoch=True)

    for epoch in range(2):
        for hook in ('b0', 'b1'):
            lightning_log(hook, 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True)

        # ``batch_idx`` is read by ``lightning_log`` through the closure
        for batch_idx in range(2):
            for hook in ('c0', 'c1', 'c2'):
                lightning_log(hook, 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
        batch_idx = None

        for hook in ('d0', 'd1'):
            lightning_log(hook, 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True)

        # (keys, expected value, expected cumulated batch size)
        expectations = (
            (('a0.a', 'a1.a'), torch.tensor(0.), torch.tensor(1.)),
            (('b0.a', 'b1.a'), torch.tensor(1.) + epoch, torch.tensor(1.)),
            (('c0.a', 'c1.a', 'c2.a'), torch.tensor(4.) + epoch * 2, torch.tensor(2.)),
            (('d0.a', 'd1.a'), torch.tensor(3.) + epoch, torch.tensor(1.)),
        )
        for keys, expected_value, expected_batch_size in expectations:
            for k in keys:
                assert result[k].value == expected_value, k
                assert result[k].cumulated_batch_size == expected_batch_size, k
def my_sync_dist(x):
    """Return ``x`` unchanged."""
    return x
def test_result_collection_restoration(tmpdir):
    """Make sure metrics are properly reloaded on failure.

    After every logged step the collection is deep-copied and also rebuilt
    from its ``state_dict``; the two must compare equal. At the end of each
    epoch the collection is additionally round-tripped through ``pickle`` and
    through ``torch.save`` / ``torch.load``.
    """
    result = ResultCollection(True, torch.device("cpu"))
    metric_a, metric_b, metric_c, metric_d = (DummyMetric() for _ in range(4))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        hook_changed = current_fx_name != fx
        if hook_changed and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for _epoch in range(2):
        cumulative_sum = 0

        for step in range(3):
            out_a = metric_a(step)
            out_b = metric_b(step)
            out_c = metric_c(step)
            metric_d(step)

            cumulative_sum += step

            # key 'a' is logged with ``metric_a`` on the first step and ``metric_d`` afterwards
            metric_for_a = metric_d if step >= 1 else metric_a
            # (name, value, on_step, on_epoch)
            entries = (
                ('a', metric_for_a, True, True),
                ('b', metric_b, False, True),
                ('c', metric_c, True, False),
                ('a_1', out_a, True, True),
                ('b_1', out_b, False, True),
                ('c_1', {'1': out_c, '2': out_c}, True, False),
            )
            for name, value, on_step, on_epoch in entries:
                lightning_log('training_step', name, value, on_step=on_step, on_epoch=on_epoch)

            step_logs = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(step_logs) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(step_logs['c_1']) == {'1', '2'}

            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            saved_sync = state_dict['items']['training_step.a']['meta']['_sync']
            assert 'fn' not in saved_sync
            new_result.load_state_dict(state_dict)
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            copied_fn = result_copy['training_step.a'].meta.sync.fn
            assert copied_fn == new_result['training_step.a'].meta.sync.fn

        epoch_logs = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_logs_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_logs == epoch_logs_copy

        lightning_log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True)
        epoch_logs = result.metrics(on_step=False)[MetricSource.LOG]
        expected_epoch_logs = {
            'a_1_epoch': 1,
            'a_epoch': cumulative_sum,
            'a': cumulative_sum,
            'b': cumulative_sum,
            'b_1': 1,
        }
        assert epoch_logs == expected_epoch_logs

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / 'result')
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        for metric in (metric_a, metric_b, metric_c):
            assert metric.x == metric._defaults['x']

        batch_idx = None
@pytest.mark.parametrize('device', ('cpu', pytest.param('cuda', marks=RunIf(min_gpus=1))))
def test_lightning_module_logging_result_collection(tmpdir, device):
    """Inspect the trainer's result collection from ``on_save_checkpoint``.

    A ``BoringModel`` subclass logs a tensor (``"v"``) and a ``DummyMetric``
    (``"m"``) in ``validation_step``. When a checkpoint is saved, the hook
    checks the device of the stored value, that the sync fn is absent from
    the state dict but kept on the live collection, and that loading into a
    fresh collection honours ``map_location``.
    """

    class LoggingModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric = DummyMetric()

        def validation_step(self, batch, batch_idx):
            v = self.metric(batch_idx)
            self.log_dict({"v": v, "m": self.metric})
            return super().validation_step(batch, batch_idx)

        def on_save_checkpoint(self, checkpoint) -> None:
            key = 'validation_step.v'
            results = self.trainer._results
            state_dict = results.state_dict()
            saved_entry = state_dict['items'][key]

            # check device
            assert results[key].value.device.type == device
            assert saved_entry['value'].device.type == device

            # sync fn should be kept
            assert results[key].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # sync fn dropped from the state dict
            assert 'fn' not in saved_entry['meta']['_sync']
            results.load_state_dict(state_dict)

            # check device after loading
            assert results[key].value.device.type == device

            # sync fn was preserved in the original result
            assert results[key].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # default sync fn
            new_results = ResultCollection(False, device)
            new_results.load_state_dict(state_dict, map_location='cpu')
            assert new_results[key].meta.sync.fn == _Sync.no_op

            # check map location
            assert new_results[key].value.device.type == 'cpu'

    model = LoggingModel()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, save_last=True)

    if device == 'cuda':
        num_gpus = 1
    else:
        num_gpus = 0

    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[checkpoint_callback],
        gpus=num_gpus,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)