lightning/tests/parity_fabric/test_parity_simple.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

154 lines
4.6 KiB
Python
Raw Normal View History

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from copy import deepcopy
from typing import Callable
import pytest
import torch
import torch.distributed
import torch.nn.functional
from lightning.fabric.fabric import Fabric
from parity_fabric.models import ConvNet
from parity_fabric.utils import (
cuda_reset,
get_model_input_dtype,
is_cuda_memory_close,
is_state_dict_equal,
is_timing_close,
make_deterministic,
)
from tests_fabric.helpers.runif import RunIf
def train_torch(
    move_to_device: Callable,
    precision_context,
    input_dtype=torch.float32,
):
    """Run the reference training loop with raw PyTorch (no Fabric wrappers).

    Args:
        move_to_device: Callable applied to the model and to every input/label batch
            to place them on the target device (the test passes ``fabric.to_device``).
        precision_context: Zero-argument callable returning a context manager; only the
            forward pass runs inside it (the test passes ``fabric.autocast``).
        input_dtype: dtype the inputs are cast to right before the forward pass.

    Returns:
        A tuple ``(state_dict, iteration_timings, memory_stats)``: the final model
        weights, a 1-D tensor of per-iteration wall-clock durations in seconds, and a
        dict with ``torch.cuda.memory_stats()`` snapshots under ``"start"`` and ``"end"``.
    """
    make_deterministic(warn_only=True)

    net = ConvNet()
    net = move_to_device(net)
    train_loader = net.get_dataloader()
    optim = net.get_optimizer()
    criterion = net.get_loss_function()

    memory_stats = {"start": torch.cuda.memory_stats()}

    net.train()
    step_durations = []
    batches = iter(train_loader)
    for _step in range(net.num_steps):
        # NOTE(review): no torch.cuda.synchronize() before reading the clock, so on CUDA
        # this may mostly capture host-side launch time; train_fabric measures the same way.
        tic = time.perf_counter()

        raw_inputs, raw_labels = next(batches)
        inputs = move_to_device(raw_inputs)
        labels = move_to_device(raw_labels)

        optim.zero_grad()
        with precision_context():
            # The dtype cast happens inside the precision context, just like the forward.
            outputs = net(inputs.to(input_dtype))
        # Loss is computed on float32 outputs regardless of the autocast output dtype.
        loss = criterion(outputs.float(), labels)
        loss.backward()
        optim.step()

        step_durations.append(time.perf_counter() - tic)

    memory_stats["end"] = torch.cuda.memory_stats()

    return net.state_dict(), torch.tensor(step_durations), memory_stats
def train_fabric(fabric):
    """Run the same training loop as ``train_torch``, but routed through Fabric.

    Fabric sets up the model/optimizer pair and the dataloader, and performs the
    backward pass via ``fabric.backward``.

    Args:
        fabric: A configured ``Fabric`` instance.

    Returns:
        A tuple ``(state_dict, iteration_timings, memory_stats)``: the final model
        weights, a 1-D tensor of per-iteration wall-clock durations in seconds, and a
        dict with ``torch.cuda.memory_stats()`` snapshots under ``"start"`` and ``"end"``.

    Raises:
        AssertionError: If the final weights equal the initial weights (no training happened).
    """
    make_deterministic(warn_only=True)

    net = ConvNet()
    # Deep-copied so the later "weights changed" check compares against a frozen snapshot.
    weights_before = deepcopy(net.state_dict())
    optim = net.get_optimizer()
    net, optim = fabric.setup(net, optim)
    # Attribute access on the Fabric-wrapped module reaches the ConvNet helpers.
    train_loader = fabric.setup_dataloaders(net.get_dataloader())
    criterion = net.get_loss_function()

    memory_stats = {"start": torch.cuda.memory_stats()}

    net.train()
    step_durations = []
    batches = iter(train_loader)
    for _step in range(net.num_steps):
        tic = time.perf_counter()

        # No manual device move here: the Fabric-set-up dataloader handles placement.
        inputs, labels = next(batches)
        optim.zero_grad()
        loss = criterion(net(inputs), labels)
        fabric.backward(loss)
        optim.step()

        step_durations.append(time.perf_counter() - tic)

    memory_stats["end"] = torch.cuda.memory_stats()

    # Sanity check that the optimizer actually updated the parameters.
    assert not is_state_dict_equal(weights_before, net.state_dict())
    return net.state_dict(), torch.tensor(step_durations), memory_stats
@pytest.mark.flaky(reruns=3)
@pytest.mark.usefixtures("reset_deterministic_algorithm", "reset_cudnn_benchmark")
@pytest.mark.parametrize(
    ("precision", "accelerator"),
    [
        (32, "cpu"),
        pytest.param(32, "cuda", marks=RunIf(min_cuda_gpus=1)),
        # pytest.param(16, "cuda", marks=RunIf(min_cuda_gpus=1)),  # TODO: requires GradScaler
        pytest.param("bf16", "cpu", marks=RunIf(skip_windows=True)),
        pytest.param("bf16", "cuda", marks=RunIf(min_cuda_gpus=1, bf16_cuda=True)),
        pytest.param(32, "mps", marks=RunIf(mps=True)),
    ],
)
def test_parity_single_device(precision, accelerator):
    """Check that single-device Fabric training matches an equivalent raw-PyTorch loop.

    Compares final weights, per-iteration timing (within tolerance), and, on CUDA only,
    the start/end CUDA memory snapshots of both runs.
    """
    input_dtype = get_model_input_dtype(precision)

    # Fabric run first; its device-move and autocast helpers are then reused by the
    # raw-PyTorch run so both use the same placement and precision handling.
    cuda_reset()
    fabric = Fabric(precision=precision, accelerator=accelerator, devices=1)
    fabric_weights, fabric_timings, fabric_memory = train_fabric(fabric)

    cuda_reset()
    torch_weights, torch_timings, torch_memory = train_torch(
        fabric.to_device, precision_context=fabric.autocast, input_dtype=input_dtype
    )

    # Final weights must be identical.
    assert is_state_dict_equal(torch_weights, fabric_weights)
    # Per-iteration time must be close (Fabric should add negligible overhead).
    assert is_timing_close(torch_timings, fabric_timings, rtol=1e-2, atol=0.1)
    # CUDA memory snapshots are only compared for GPU runs.
    if accelerator == "cuda":
        for snapshot in ("start", "end"):
            assert is_cuda_memory_close(torch_memory[snapshot], fabric_memory[snapshot])