# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from copy import deepcopy

import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchmetrics import Metric

import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
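

# `DummyMetric` is a minimal torchmetrics `Metric` used throughout these tests:
# `update` accumulates every logged value into a single running sum, and
# `dist_reduce_fx="sum"` means the per-process states are summed whenever the
# metric is synced across processes.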
class DummyMetric(Metric):

    def __init__(self):
        super().__init__()
        self.add_state("x", torch.tensor(0), dist_reduce_fx="sum")

    def update(self, x):
        self.x += x

    def compute(self):
        return self.x


def _setup_ddp(rank, worldsize):
    import os

    os.environ["MASTER_ADDR"] = "localhost"
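    # MASTER_PORT is assumed to be set by the caller; the DDP test below calls
    # `tutils.set_random_master_port()` to pick a free port before spawning.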

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=worldsize)


def _ddp_test_fn(rank, worldsize):
    _setup_ddp(rank, worldsize)
    torch.tensor([1.0])

    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    metric_a = metric_a.to(f"cuda:{rank}")
    metric_b = metric_b.to(f"cuda:{rank}")
    metric_c = metric_c.to(f"cuda:{rank}")

    result = ResultCollection(True, torch.device(f"cuda:{rank}"))

    for _ in range(3):
        cumulative_sum = 0

        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log('h', 'a', metric_a, on_step=True, on_epoch=True)
            result.log('h', 'b', metric_b, on_step=False, on_epoch=True)
            result.log('h', 'c', metric_c, on_step=True, on_epoch=False)

            batch_log = result.metrics(True)[MetricSource.LOG]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)[MetricSource.LOG]
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x'])
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        assert epoch_log == {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize}
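

# With `dist_reduce_fx="sum"`, each process contributes its local running sum,
# which is why the synced epoch values above are `cumulative_sum * worldsize`.
# The test below drives `_ddp_test_fn` on two processes to exercise that path.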
@RunIf(skip_windows=True, min_gpus=2)
def test_result_reduce_ddp():
    """Make sure result logging works with DDP."""
    tutils.set_random_master_port()

    worldsize = 2
    mp.spawn(_ddp_test_fn, args=(worldsize,), nprocs=worldsize)
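

# Note on key names in the assertions below: a value logged with both
# `on_step=True` and `on_epoch=True` surfaces with `_step` and `_epoch`
# suffixes ("a_step"/"a_epoch"), while a value logged with only one of the two
# flags keeps its bare name ("b" and "c").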
def test_result_metric_integration():
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()

    result = ResultCollection(True, torch.device("cpu"))

    for _ in range(3):
        cumulative_sum = 0

        for i in range(5):
            metric_a(i)
            metric_b(i)
            metric_c(i)

            cumulative_sum += i

            result.log('h', 'a', metric_a, on_step=True, on_epoch=True)
            result.log('h', 'b', metric_b, on_step=False, on_epoch=True)
            result.log('h', 'c', metric_c, on_step=True, on_epoch=False)

            batch_log = result.metrics(True)[MetricSource.LOG]
            assert batch_log == {"a_step": i, "c": i}

        epoch_log = result.metrics(False)[MetricSource.LOG]
        result.reset()

        # assert metric state reset to default values
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}

    assert str(result) == (
        "ResultCollection(True, cpu, {"
        "'h.a': ResultMetric(value=DummyMetric()), "
        "'h.b': ResultMetric(value=DummyMetric()), "
        "'h.c': ResultMetric(value=DummyMetric())"
        "})"
    )
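

# `lightning_log` below stands in for what a `LightningModule.log` call does
# internally: `fx` identifies the hook doing the logging, and a hook's stored
# results are reset the first time it logs in a fresh iteration
# (`batch_idx in (None, 0)`), mirroring how the loops reuse the collection.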
def test_result_collection_simple_loop():
    result = ResultCollection(True, torch.device("cpu"))
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs)
        current_fx_name = fx

    lightning_log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True)
    lightning_log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True)
    for epoch in range(2):
        lightning_log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True)
        lightning_log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True)
        for batch_idx in range(2):
            lightning_log('c0', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
            lightning_log('c1', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
            lightning_log('c2', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
        batch_idx = None
        lightning_log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True)
        lightning_log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True)

        for k in ('a0.a', 'a1.a'):
            assert result[k].value == torch.tensor(0.), k
            assert result[k].cumulated_batch_size == torch.tensor(1.), k

        for k in ('b0.a', 'b1.a'):
            assert result[k].value == torch.tensor(1.) + epoch, k
            assert result[k].cumulated_batch_size == torch.tensor(1.), k

        for k in ('c0.a', 'c1.a', 'c2.a'):
            assert result[k].value == torch.tensor(4.) + epoch * 2, k
            assert result[k].cumulated_batch_size == torch.tensor(2.), k

        for k in ('d0.a', 'd1.a'):
            assert result[k].value == torch.tensor(3.) + epoch, k
            assert result[k].cumulated_batch_size == torch.tensor(1.), k
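

# Identity "sync" function used by the restoration test below: it lets the
# test attach a custom `sync_dist_fn` and then verify that `state_dict()`
# drops the function (it is not serialized with the results) while the live
# objects keep it.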
def my_sync_dist(x):
    return x


def test_result_collection_restoration(tmpdir):
    """This test makes sure metrics are properly reloaded on failure."""
    result = ResultCollection(True, torch.device("cpu"))
    metric_a = DummyMetric()
    metric_b = DummyMetric()
    metric_c = DummyMetric()
    metric_d = DummyMetric()
    current_fx_name = None
    batch_idx = None

    def lightning_log(fx, *args, **kwargs):
        nonlocal current_fx_name
        if current_fx_name != fx and batch_idx in (None, 0):
            result.reset(metrics=False, fx=fx)
        result.log(fx, *args, **kwargs, sync_dist_fn=my_sync_dist)
        current_fx_name = fx

    for _ in range(2):

        cumulative_sum = 0

        for i in range(3):

            a = metric_a(i)
            b = metric_b(i)
            c = metric_c(i)
            metric_d(i)

            cumulative_sum += i

            metric = metric_a if i < 1 else metric_d
            lightning_log('training_step', 'a', metric, on_step=True, on_epoch=True)
            lightning_log('training_step', 'b', metric_b, on_step=False, on_epoch=True)
            lightning_log('training_step', 'c', metric_c, on_step=True, on_epoch=False)
            lightning_log('training_step', 'a_1', a, on_step=True, on_epoch=True)
            lightning_log('training_step', 'b_1', b, on_step=False, on_epoch=True)
            lightning_log('training_step', 'c_1', {'1': c, '2': c}, on_step=True, on_epoch=False)

            batch_log = result.metrics(on_step=True)[MetricSource.LOG]
            assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
            assert set(batch_log['c_1']) == {'1', '2'}
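
            # Round-trip the collection through `state_dict` /
            # `load_state_dict` mid-loop, as a restart after a failure would.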
            result_copy = deepcopy(result)
            new_result = ResultCollection(True, torch.device("cpu"))
            state_dict = result.state_dict()
            # check the sync fn was dropped
            assert 'fn' not in state_dict['items']['training_step.a']['meta']['_sync']

            new_result.load_state_dict(state_dict)
            # should match
            assert result_copy == new_result
            # the sync fn has been kept
            assert result_copy['training_step.a'].meta.sync.fn == new_result['training_step.a'].meta.sync.fn

        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        epoch_log_copy = result_copy.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == epoch_log_copy

        lightning_log('train_epoch_end', 'a', metric_a, on_step=False, on_epoch=True)
        epoch_log = result.metrics(on_step=False)[MetricSource.LOG]
        assert epoch_log == {
            'a_1_epoch': 1,
            'a_epoch': cumulative_sum,
            'a': cumulative_sum,
            'b': cumulative_sum,
            'b_1': 1
        }

        # make sure can be pickled
        pickle.loads(pickle.dumps(result))
        # make sure can be torch.loaded
        filepath = str(tmpdir / 'result')
        torch.save(result, filepath)
        torch.load(filepath)

        # assert metric state reset to default values
        result.reset()
        assert metric_a.x == metric_a._defaults['x']
        assert metric_b.x == metric_b._defaults['x']
        assert metric_c.x == metric_c._defaults['x']

        batch_idx = None
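

# The test below runs a real `Trainer` fit loop and inspects the live
# `ResultCollection` (`trainer._results`) from inside `on_save_checkpoint`,
# checking device placement and the sync fn across a state-dict round trip.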
@pytest.mark.parametrize('device', ('cpu', pytest.param('cuda', marks=RunIf(min_gpus=1))))
def test_lightning_module_logging_result_collection(tmpdir, device):

    class LoggingModel(BoringModel):

        def __init__(self):
            super().__init__()
            self.metric = DummyMetric()

        def validation_step(self, batch, batch_idx):
            v = self.metric(batch_idx)
            self.log_dict({"v": v, "m": self.metric})
            return super().validation_step(batch, batch_idx)
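
        # `on_save_checkpoint` is a convenient hook to observe the results
        # while they are still populated for the running validation epoch.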
        def on_save_checkpoint(self, checkpoint) -> None:
            results = self.trainer._results
            state_dict = results.state_dict()

            # check device
            assert results['validation_step.v'].value.device.type == device
            assert state_dict['items']['validation_step.v']['value'].device.type == device

            # sync fn should be kept
            assert results['validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # sync fn dropped from the state dict
            assert 'fn' not in state_dict['items']['validation_step.v']['meta']['_sync']
            results.load_state_dict(state_dict)

            # check device after loading
            assert results['validation_step.v'].value.device.type == device

            # sync fn was preserved in the original result
            assert results['validation_step.v'].meta.sync.fn == self.trainer.training_type_plugin.reduce

            # default sync fn
            new_results = ResultCollection(False, device)
            new_results.load_state_dict(state_dict, map_location='cpu')
            assert new_results['validation_step.v'].meta.sync.fn == _Sync.no_op

            # check map location
            assert new_results['validation_step.v'].value.device.type == 'cpu'

    model = LoggingModel()
    ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=2,
        callbacks=[ckpt],
        gpus=1 if device == 'cuda' else 0,
    )
    trainer.fit(model)