# lightning/tests/accelerators/test_cpu.py
# (exported from a source-browser view: 90 lines, 3.6 KiB, Python)
import os
from pathlib import Path
from typing import Any, Dict, Union
import pytest
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import SingleDevicePlugin
# blame timestamp: 2021-08-13 16:35:31 +00:00
from pytorch_lightning.plugins.io.torch_plugin import TorchCheckpointIO
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from tests.helpers.boring_model import BoringModel
@pytest.mark.parametrize("delay_dispatch", [True, False])
def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch):
    """Test that a custom training type plugin which delays optimizer setup does not get its optimizers set up
    until ``pre_dispatch``.

    ``delay_dispatch`` is returned by the plugin's ``setup_optimizers_in_pre_dispatch`` property:
    - When True, ``trainer.optimizers`` must still be empty in ``on_fit_start``.
    - When False, it must already be populated in ``on_fit_start``.
    - In both cases it must be populated by ``on_fit_end``.
    """

    class TestModel(BoringModel):
        # Flipped by the hooks below and checked after ``fit``, so the assertions inside the hooks
        # cannot pass vacuously if a hook is never invoked.
        on_fit_start_called = False
        on_fit_end_called = False

        def on_fit_start(self):
            self.on_fit_start_called = True
            if delay_dispatch:
                # Ensure we haven't setup optimizers if we've delayed dispatch
                assert len(self.trainer.optimizers) == 0
            else:
                assert len(self.trainer.optimizers) > 0

        def on_fit_end(self):
            self.on_fit_end_called = True
            assert len(self.trainer.optimizers) > 0

    class CustomPlugin(SingleDevicePlugin):
        @property
        def setup_optimizers_in_pre_dispatch(self) -> bool:
            return delay_dispatch

    model = TestModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, plugins=CustomPlugin(device=torch.device("cpu")))
    trainer.fit(model)

    # Both hooks must have run; otherwise the optimizer checks above were never exercised.
    assert model.on_fit_start_called
    assert model.on_fit_end_called
def test_restore_checkpoint_after_pre_dispatch_default():
    """Check the default of ``restore_checkpoint_after_pre_dispatch``.

    A plain CPU ``SingleDevicePlugin`` and the ``CPUAccelerator`` built around it must both report a falsy value.
    """
    cpu_plugin = SingleDevicePlugin(torch.device("cpu"))
    cpu_accelerator = CPUAccelerator(training_type_plugin=cpu_plugin, precision_plugin=PrecisionPlugin())
    # Accelerator first, then plugin: each must default to not deferring checkpoint restoration.
    for owner in (cpu_accelerator, cpu_plugin):
        assert not owner.restore_checkpoint_after_pre_dispatch
@pytest.mark.parametrize("restore_after_pre_dispatch", [True, False])
def test_restore_checkpoint_after_pre_dispatch(tmpdir, restore_after_pre_dispatch):
    """Test to ensure that if restore_checkpoint_after_pre_dispatch is True, then we only load the state after
    pre-dispatch is called.

    The check is repeated for:
    - ``fit`` with ``resume_from_checkpoint``;
    - ``test``, ``validate`` and ``predict`` with ``ckpt_path``.

    When the flag is False, the plugin's ``load_checkpoint`` must instead be reached before ``pre_dispatch``.
    """

    class TestPlugin(SingleDevicePlugin):
        # Set by ``pre_dispatch`` below. The test resets it before each evaluation entry point.
        predispatched_called = False

        def pre_dispatch(self) -> None:
            super().pre_dispatch()
            self.predispatched_called = True

        @property
        def restore_checkpoint_after_pre_dispatch(self) -> bool:
            return restore_after_pre_dispatch

        def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # ``pre_dispatch`` must already have run exactly when deferred restoration was requested.
            # NOTE(review): nothing asserts that this method is actually invoked. If it never is, the
            # check passes vacuously — consider counting calls and asserting on the count.
            assert self.predispatched_called == restore_after_pre_dispatch
            return super().load_checkpoint(checkpoint_path)

    # Produce a checkpoint with a plain trainer so the custom plugin has something to restore.
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)

    plugin = TestPlugin(torch.device("cpu"), checkpoint_io=TorchCheckpointIO())
    accelerator = CPUAccelerator(training_type_plugin=plugin, precision_plugin=PrecisionPlugin())

    # Both the accelerator and its plugin must report the parametrized setting.
    assert accelerator.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch
    assert plugin.restore_checkpoint_after_pre_dispatch == restore_after_pre_dispatch

    trainer = Trainer(
        default_root_dir=tmpdir, accelerator=accelerator, fast_dev_run=True, resume_from_checkpoint=checkpoint_path
    )
    trainer.fit(model)

    for func in (trainer.test, trainer.validate, trainer.predict):
        # Clear the flag left over from the previous run so each entry point is checked on its own.
        accelerator.training_type_plugin.predispatched_called = False
        func(model, ckpt_path=checkpoint_path)