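"""
Tests that the training loop handles dict returns from
training_step / training_step_end / training_epoch_end.
"""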
import pytest
import torch

from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel
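
# In the dict-return interface exercised below, training_step (and, when
# defined, training_step_end) returns a dict holding the 'loss' tensor plus
# optional 'progress_bar' and 'log' sub-dicts; any extra keys (such as the
# 'train_step_test' and 'train_step_end' entries asserted on below) are
# passed through to the epoch-end output unchanged.
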
@pytest.mark.spawn
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="test requires GPU machine")
def test_training_step_dict(tmpdir):
    """
    Tests that training_step alone can return a dict
    (no training_step_end, no training_epoch_end).
    """
    model = DeterministicModel()
    model.training_step = model.training_step_dict_return
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        precision=16,
        gpus=1,
        weights_summary=None,
    )
    trainer.fit(model)

    # make sure correct steps were called
    assert model.training_step_called
    assert not model.training_step_end_called
    assert not model.training_epoch_end_called

    # make sure training outputs what is expected
    # (grab the first batch from the train dataloader)
    for batch_idx, batch in enumerate(model.train_dataloader()):
        break

    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert out.batch_log_metrics['log_acc1'] == 12.0
    assert out.batch_log_metrics['log_acc2'] == 7.0

    train_step_out = out.training_step_output_for_epoch_end
    pbar_metrics = train_step_out['progress_bar']
    assert 'log' in train_step_out
    assert 'progress_bar' in train_step_out
    assert train_step_out['train_step_test'] == 549
    assert pbar_metrics['pbar_acc1'] == 17.0
    assert pbar_metrics['pbar_acc2'] == 19.0

    # make sure the optimizer closure returns the correct things
    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)  # = 171.0
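

# Note: relative to the plain dict return above (log 12.0 / 7.0, pbar 17.0 / 19.0),
# the *_for_step_end_dict / training_step_end_dict pair appears to add 2.0 to
# every metric; the 14.0 / 9.0 and 19.0 / 21.0 assertions below encode that.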
def test_training_step_with_step_end(tmpdir):
    """
    Checks training_step + training_step_end
    """
    model = DeterministicModel()
    model.training_step = model.training_step_for_step_end_dict
    model.training_step_end = model.training_step_end_dict
    model.val_dataloader = None

    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, weights_summary=None)
    trainer.fit(model)

    # make sure correct steps were called
    assert model.training_step_called
    assert model.training_step_end_called
    assert not model.training_epoch_end_called

    # make sure training outputs what is expected
    for batch_idx, batch in enumerate(model.train_dataloader()):
        break

    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert out.batch_log_metrics['log_acc1'] == 14.0
    assert out.batch_log_metrics['log_acc2'] == 9.0

    train_step_end_out = out.training_step_output_for_epoch_end
    pbar_metrics = train_step_end_out['progress_bar']
    assert 'train_step_end' in train_step_end_out
    assert pbar_metrics['pbar_acc1'] == 19.0
    assert pbar_metrics['pbar_acc2'] == 21.0


def test_full_training_loop_dict(tmpdir):
    """
    Checks training_step + training_step_end + training_epoch_end
    """
    model = DeterministicModel()
    model.training_step = model.training_step_for_step_end_dict
    model.training_step_end = model.training_step_end_dict
    model.training_epoch_end = model.training_epoch_end_dict
    model.val_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model)

    # make sure correct steps were called
    assert model.training_step_called
    assert model.training_step_end_called
    assert model.training_epoch_end_called

    # assert epoch end metrics were added
    assert trainer.callback_metrics['epoch_end_log_1'] == 178
    assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234

    # make sure training outputs what is expected
    for batch_idx, batch in enumerate(model.train_dataloader()):
        break

    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert out.batch_log_metrics['log_acc1'] == 14.0
    assert out.batch_log_metrics['log_acc2'] == 9.0

    train_step_end_out = out.training_step_output_for_epoch_end
    pbar_metrics = train_step_end_out['progress_bar']
    assert pbar_metrics['pbar_acc1'] == 19.0
    assert pbar_metrics['pbar_acc2'] == 21.0


def test_train_step_epoch_end(tmpdir):
    """
    Checks training_step + training_epoch_end (no training_step_end)
    """
    model = DeterministicModel()
    model.training_step = model.training_step_dict_return
    model.training_step_end = None
    model.training_epoch_end = model.training_epoch_end_dict
    model.val_dataloader = None

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, weights_summary=None)
    trainer.fit(model)

    # make sure correct steps were called
    assert model.training_step_called
    assert not model.training_step_end_called
    assert model.training_epoch_end_called

    # assert epoch end metrics were added
    assert trainer.callback_metrics['epoch_end_log_1'] == 178
    assert trainer.progress_bar_metrics['epoch_end_pbar_1'] == 234

    # make sure training outputs what is expected
    for batch_idx, batch in enumerate(model.train_dataloader()):
        break

    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert out.batch_log_metrics['log_acc1'] == 12.0
    assert out.batch_log_metrics['log_acc2'] == 7.0

    train_step_end_out = out.training_step_output_for_epoch_end
    pbar_metrics = train_step_end_out['progress_bar']
    assert pbar_metrics['pbar_acc1'] == 17.0
    assert pbar_metrics['pbar_acc2'] == 19.0