remove parameterize from TPU tests (#2561)

* added base tests for tpu
William Falcon 2020-07-09 06:46:07 -04:00 committed by GitHub
parent 69cbb62774
commit a95ef5a4ac
2 changed files with 113 additions and 19 deletions
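
For context, the change unrolls pytest's `parametrize` decorator into one explicit test function per TPU configuration. A minimal sketch of the mechanics being removed (hypothetical test name, not from the diff):

```python
import pytest

# One decorated function expands into several collected test cases,
# one per listed value: tpu_cores=1, tpu_cores=[1], tpu_cores=8.
@pytest.mark.parametrize("tpu_cores", [1, [1], 8])
def test_example(tpu_cores):
    assert tpu_cores in (1, [1], 8)
```

Unrolled into separate functions, each configuration appears, passes, or fails independently in the test report.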

tests/base/develop_pipelines.py

@@ -50,7 +50,7 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi
     result = trainer.fit(model)

     # correct result and ok accuracy
-    assert result == 1, 'amp + ddp model failed to complete'
+    assert result == 1, 'trainer failed'

     # test model loading
     pretrained_model = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path)
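
The updated assertion relies on `Trainer.fit` returning `1` on successful completion, which was the convention in this era of Lightning (later versions return `None`). A minimal usage sketch under that assumption:

```python
from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate

model = EvalModelTemplate()
trainer = Trainer(fast_dev_run=True)
result = trainer.fit(model)  # Lightning 0.8.x: returns 1 when training completed
assert result == 1, 'trainer failed'
```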

tests/models/test_tpu.py

@@ -7,6 +7,8 @@ from pytorch_lightning import Trainer
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 import tests.base.develop_pipelines as tpipes
+from tests.base.datasets import TrialMNIST
+from torch.utils.data import DataLoader

 try:
     import torch_xla
@@ -21,14 +23,13 @@ else:

 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-@pytest.mark.parametrize("tpu_cores", [1, [1], 8])
-def test_base_tpu_model(tmpdir, tpu_cores):
+def test_base_tpu_model_1(tmpdir):
     """Make sure model trains on TPU."""
     trainer_options = dict(
         default_root_dir=tmpdir,
         progress_bar_refresh_rate=0,
         max_epochs=1,
-        tpu_cores=tpu_cores,
+        tpu_cores=1,
         limit_train_batches=0.4,
         limit_val_batches=0.4
     )
@@ -38,23 +39,104 @@ def test_base_tpu_model(tmpdir, tpu_cores):

 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-@pytest.mark.parametrize("tpu_cores", [1, [1], 8])
-def test_base_tpu_16bit_model(tmpdir, tpu_cores):
+def test_base_tpu_model_idx_1(tmpdir):
     """Make sure model trains on TPU."""
     trainer_options = dict(
         default_root_dir=tmpdir,
-        precision=16,
         progress_bar_refresh_rate=0,
         max_epochs=1,
-        tpu_cores=tpu_cores,
+        tpu_cores=[1],
+        limit_train_batches=0.4,
+        limit_val_batches=0.4
+    )
+
+    model = EvalModelTemplate()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+
+
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+def test_base_tpu_model_8(tmpdir):
+    """Make sure model trains on TPU."""
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        tpu_cores=8,
         limit_train_batches=0.4,
         limit_val_batches=0.4
     )

     model = EvalModelTemplate()
-    tpipes.run_model_test(trainer_options, model, on_gpu=False)
+
+    # 8 cores needs a big dataset
+    def long_train_loader():
+        dataset = DataLoader(TrialMNIST(download=True, num_samples=15000, digits=(0, 1, 2, 5, 8)), batch_size=32)
+        return dataset
+    model.train_dataloader = long_train_loader
+    model.val_dataloader = long_train_loader
+
+    tpipes.run_model_test(trainer_options, model, on_gpu=False, with_hpc=False)
+
+
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+def test_base_tpu_16bit_model_core_1(tmpdir):
+    """Make sure model trains on TPU."""
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        precision=16,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        tpu_cores=1,
+        limit_train_batches=0.4,
+        limit_val_batches=0.4
+    )
+
+    model = EvalModelTemplate()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False)
+    assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
+
+
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+def test_base_tpu_16bit_model_idx_core(tmpdir):
+    """Make sure model trains on TPU."""
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        precision=16,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        tpu_cores=[1],
+        limit_train_batches=0.4,
+        limit_val_batches=0.4
+    )
+
+    model = EvalModelTemplate()
+    tpipes.run_model_test(trainer_options, model, on_gpu=False)
+    assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
+
+
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+def test_base_tpu_16bit_model_8_cores(tmpdir):
+    """Make sure model trains on TPU."""
+    trainer_options = dict(
+        default_root_dir=tmpdir,
+        precision=16,
+        progress_bar_refresh_rate=0,
+        max_epochs=1,
+        tpu_cores=8,
+        limit_train_batches=0.4,
+        limit_val_batches=0.4
+    )
+
+    model = EvalModelTemplate()
+
+    # 8 cores needs a big dataset
+    def long_train_loader():
+        dataset = DataLoader(TrialMNIST(download=True, num_samples=15000, digits=(0, 1, 2, 5, 8)), batch_size=32)
+        return dataset
+    model.train_dataloader = long_train_loader
+    model.val_dataloader = long_train_loader
+
+    tpipes.run_model_test(trainer_options, model, on_gpu=False)
     assert os.environ.get('XLA_USE_BF16') == str(1), "XLA_USE_BF16 was not set in environment variables"
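
The "8 cores needs a big dataset" comment comes down to sharding: with 8 processes, each core sees roughly 1/8 of the batches, so a small dataset can leave a core with too few training steps. A rough back-of-the-envelope for the numbers used above (assuming batches are split evenly across cores):

```python
samples = 15000       # TrialMNIST num_samples in the test
batch_size = 32
cores = 8

batches_total = samples // batch_size         # 468 batches overall
batches_per_core = batches_total // cores     # ~58 batches per core
steps_per_core = int(batches_per_core * 0.4)  # ~23 steps with limit_train_batches=0.4
print(batches_total, batches_per_core, steps_per_core)
```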
@@ -80,11 +162,24 @@ def test_early_stop_checkpoints_on_tpu(tmpdir, tpu_cores, expected_device):

 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-@pytest.mark.parametrize(['tpu_cores', 'expected_device'], [
-    pytest.param([1], 'xla:1'),
-    pytest.param([8], 'xla:8'),
-])
-def test_single_tpu_core_model(tmpdir, tpu_cores, expected_device):
+def test_early_stop_checkpoints_on_tpu(tmpdir):
+    """Test if single TPU core training works"""
+    model = EvalModelTemplate()
+    trainer = Trainer(
+        early_stop_callback=True,
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        max_epochs=50,
+        limit_train_batches=10,
+        limit_val_batches=10,
+        tpu_cores=1,
+    )
+    trainer.fit(model)
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:1'
+
+
+@pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
+def test_single_tpu_core_model(tmpdir):
     """Test if single TPU core training works"""
     model = EvalModelTemplate()
     trainer = Trainer(
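
The `'xla:1'` assertion pokes a private torch_xla binding; the public way to get the current device is `xla_device()`. A minimal sketch, assuming `torch_xla` is installed on the host:

```python
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()  # e.g. a torch.device printing as 'xla:1' on a single core
# The tests compare the private binding's string form directly:
print(torch_xla._XLAC._xla_get_default_device())  # e.g. 'xla:1'
```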
@@ -93,15 +188,14 @@ def test_single_tpu_core_model(tmpdir, tpu_cores, expected_device):
         max_epochs=1,
         train_percent_check=0.1,
         val_percent_check=0.1,
-        tpu_cores=tpu_cores,
+        tpu_cores=8,
     )
     trainer.fit(model)
-    assert torch_xla._XLAC._xla_get_default_device() == expected_device
+    assert torch_xla._XLAC._xla_get_default_device() == 'xla:8'


-@pytest.mark.parametrize("tpu_cores", [1, 8])
 @pytest.mark.skipif(not TPU_AVAILABLE, reason="test requires TPU machine")
-def test_multi_core_tpu_model(tmpdir, tpu_cores):
+def test_multi_core_tpu_model(tmpdir):
     """Test if distributed TPU core training works"""
     model = EvalModelTemplate()
     trainer = Trainer(
@@ -109,7 +203,7 @@ def test_multi_core_tpu_model(tmpdir, tpu_cores):
         max_epochs=1,
         train_percent_check=0.4,
         val_percent_check=0.2,
-        tpu_cores=tpu_cores,
+        tpu_cores=[1, 8],
     )
     trainer.fit(model)
     assert trainer.tpu_id is None
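
For reference, the `tpu_cores` argument these tests exercise accepts either an int (how many cores to use) or a list selecting a specific core index. A sketch of the accepted shapes (my reading of the 0.8-era API, not taken from the diff):

```python
from pytorch_lightning import Trainer

Trainer(tpu_cores=1)    # train on one core, auto-selected
Trainer(tpu_cores=8)    # train on all 8 cores of the host
Trainer(tpu_cores=[1])  # train on core index 1 specifically
# The last test passes tpu_cores=[1, 8] and asserts trainer.tpu_id is None,
# i.e. a multi-element list is not treated as a single-core selection.
```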