# lightning/tests/tests_fabric/helpers/models.py

from typing import Any, Iterator

import torch
import torch.nn as nn
from lightning.fabric import Fabric
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.utils.data import DataLoader, Dataset, IterableDataset


class RandomDataset(Dataset):
    def __init__(self, size: int, length: int) -> None:
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index: int) -> Tensor:
        return self.data[index]

    def __len__(self) -> int:
        return self.len


class RandomIterableDataset(IterableDataset):
    def __init__(self, size: int, count: int) -> None:
        self.count = count
        self.size = size

    def __iter__(self) -> Iterator[Tensor]:
        for _ in range(self.count):
            yield torch.randn(self.size)
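
# Usage note (illustrative, not from the original file): in the tests these datasets
# are wrapped in a standard torch DataLoader, e.g.
#   DataLoader(RandomDataset(32, 64), batch_size=2)
#   DataLoader(RandomIterableDataset(size=32, count=16))
# The batch sizes here are arbitrary examples.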


class BoringFabric(Fabric):
    def get_model(self) -> Module:
        return nn.Linear(32, 2)

    def get_optimizer(self, module: Module) -> Optimizer:
        return torch.optim.Adam(module.parameters(), lr=0.1)

    def get_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64))

    def step(self, model: Module, batch: Any) -> Tensor:
        output = model(batch)
        return torch.nn.functional.mse_loss(output, torch.ones_like(output))

    def after_backward(self, model: Module, optimizer: Optimizer) -> None:
        pass

    def after_optimizer_step(self, model: Module, optimizer: Optimizer) -> None:
        pass
    def run(self) -> None:
        with self.init_module():
            model = self.get_model()
        optimizer = self.get_optimizer(model)
        model, optimizer = self.setup(model, optimizer)

        dataloader = self.get_dataloader()
        dataloader = self.setup_dataloaders(dataloader)

        self.model = model
        self.optimizer = optimizer
        self.dataloader = dataloader

        model.train()

        data_iter = iter(dataloader)
        batch = next(data_iter)
        loss = self.step(model, batch)
        self.backward(loss)
        self.after_backward(model, optimizer)
        optimizer.step()
        self.after_optimizer_step(model, optimizer)
        optimizer.zero_grad()
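

# Minimal usage sketch (not part of the original helper file): the accelerator and
# devices arguments below are assumptions chosen for a quick local run; the actual
# tests construct BoringFabric with whatever accelerator/strategy/precision
# combination they exercise.
if __name__ == "__main__":
    fabric = BoringFabric(accelerator="cpu", devices=1)
    fabric.launch()  # no-op for a single CPU process
    fabric.run()  # executes the single training step defined above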