# lightning/tests/tests_pytorch/accelerators/test_cpu.py
#
# 80 lines
# 2.7 KiB
# Python

import os
from pathlib import Path
from typing import Any, Dict, Union
from unittest.mock import Mock
import lightning.pytorch as pl
import pytest
import torch
from lightning.fabric.plugins import TorchCheckpointIO
from lightning.pytorch import Trainer
from lightning.pytorch.accelerators import CPUAccelerator
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.strategies import SingleDeviceStrategy
from tests_pytorch.helpers.runif import RunIf
def test_restore_checkpoint_after_pre_setup_default():
    """Check that a single-device CPU strategy reports a falsy ``restore_checkpoint_after_setup`` by default."""
    cpu_accelerator = CPUAccelerator()
    cpu_device = torch.device("cpu")
    precision = Precision()
    strategy = SingleDeviceStrategy(accelerator=cpu_accelerator, device=cpu_device, precision_plugin=precision)

    restore_after_setup = strategy.restore_checkpoint_after_setup
    assert not restore_after_setup
def test_availability():
    """The CPU accelerator must report itself as available."""
    cpu_is_available = CPUAccelerator.is_available()
    assert cpu_is_available
@RunIf(psutil=True)
def test_get_device_stats(tmpdir):
    """Check that each expected CPU metric name occurs inside at least one entry of the reported device stats."""
    cpu_stats = CPUAccelerator().get_device_stats(Mock())
    expected_names = ("cpu_vm_percent", "cpu_percent", "cpu_swap_percent")
    for metric_name in expected_names:
        # Substring match: a stats entry only has to contain the metric name.
        assert any(metric_name in entry for entry in cpu_stats)
@pytest.mark.parametrize("restore_after_pre_setup", [True, False])
def test_restore_checkpoint_after_pre_setup(tmpdir, restore_after_pre_setup):
    """Verify that the checkpoint is loaded after ``setup`` exactly when ``restore_checkpoint_after_setup`` is True.

    A strategy subclass records whether ``setup`` has run and asserts, inside ``load_checkpoint``, that this flag
    equals the parametrized ``restore_after_pre_setup`` value. The check is exercised through ``fit``, ``test``,
    ``validate`` and ``predict``.

    """

    class SetupTrackingStrategy(SingleDeviceStrategy):
        # Class-level default; overridden on the instance once ``setup`` has run.
        setup_called = False

        @property
        def restore_checkpoint_after_setup(self) -> bool:
            return restore_after_pre_setup

        def setup(self, trainer: "pl.Trainer") -> None:
            super().setup(trainer)
            self.setup_called = True

        def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]:
            # Setup must already have happened if and only if restoring is deferred until after setup.
            assert self.setup_called == restore_after_pre_setup
            return super().load_checkpoint(checkpoint_path)

    boring_model = BoringModel()

    # Train once with a default trainer so there is a checkpoint file to restore from.
    seed_trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    seed_trainer.fit(boring_model)
    ckpt_file = os.path.join(tmpdir, "model.pt")
    seed_trainer.save_checkpoint(ckpt_file)

    strategy = SetupTrackingStrategy(
        accelerator=CPUAccelerator(),
        precision_plugin=Precision(),
        device=torch.device("cpu"),
        checkpoint_io=TorchCheckpointIO(),
    )
    assert strategy.restore_checkpoint_after_setup == restore_after_pre_setup

    resumed_trainer = Trainer(default_root_dir=tmpdir, strategy=strategy, fast_dev_run=True)
    resumed_trainer.fit(boring_model, ckpt_path=ckpt_file)

    stage_entry_points = (resumed_trainer.test, resumed_trainer.validate, resumed_trainer.predict)
    for run_stage in stage_entry_points:
        # Reset the flag so every stage checks the setup/load ordering on its own.
        strategy.setup_called = False
        run_stage(boring_model, ckpt_path=ckpt_file)