2022-09-26 18:50:11 +00:00
|
|
|
from typing import Any, Iterator
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
ruff: replace isort with ruff +TPU (#17684)
* ruff: replace isort with ruff
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fixing & imports
* lines in warning test
* docs
* fix enum import
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fixing
* import
* fix lines
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* type ClusterEnvironment
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-09-26 15:54:55 +00:00
|
|
|
from lightning.fabric import Fabric
|
2022-09-26 18:50:11 +00:00
|
|
|
from torch import Tensor
|
|
|
|
from torch.nn import Module
|
|
|
|
from torch.optim import Optimizer
|
|
|
|
from torch.utils.data import DataLoader, Dataset, IterableDataset
|
|
|
|
|
|
|
|
|
|
|
|
class RandomDataset(Dataset):
    """Map-style dataset of pre-generated random feature vectors.

    Args:
        size: number of features per sample.
        length: total number of samples.
    """

    def __init__(self, size: int, length: int) -> None:
        # All samples are materialized once here; __getitem__ only indexes.
        self.data = torch.randn(length, size)
        self.len = length

    def __getitem__(self, index: int) -> Tensor:
        """Return the pre-generated sample at *index*."""
        return self.data[index]

    def __len__(self) -> int:
        """Return the number of samples."""
        return self.len
|
|
|
|
|
|
|
|
|
|
|
|
class RandomIterableDataset(IterableDataset):
    """Iterable-style dataset that samples a fresh random vector per step.

    Args:
        size: number of features per yielded tensor.
        count: how many tensors the iterator produces before exhausting.
    """

    def __init__(self, size: int, count: int) -> None:
        self.count = count
        self.size = size

    def __iter__(self) -> Iterator[Tensor]:
        # Nothing is cached: each iteration draws a new random tensor.
        remaining = self.count
        while remaining > 0:
            yield torch.randn(self.size)
            remaining -= 1
|
|
|
|
|
|
|
|
|
2023-01-10 15:02:05 +00:00
|
|
|
class BoringFabric(Fabric):
    """Minimal ``Fabric`` subclass that runs a single training step.

    Intended as a test fixture: subclasses can override the ``get_*``
    factories and the ``after_*`` hooks to observe or customize the loop.
    """

    def get_model(self) -> Module:
        """Return the model under test: a tiny linear layer."""
        return nn.Linear(32, 2)

    def get_optimizer(self, module: Module) -> Optimizer:
        """Return an Adam optimizer over the module's parameters."""
        return torch.optim.Adam(module.parameters(), lr=0.1)

    def get_dataloader(self) -> DataLoader:
        """Return a loader over 64 random 32-dimensional samples."""
        return DataLoader(RandomDataset(32, 64))

    def step(self, model: Module, batch: Any) -> Tensor:
        """Run a forward pass and return the MSE loss against all-ones."""
        prediction = model(batch)
        target = torch.ones_like(prediction)
        return torch.nn.functional.mse_loss(prediction, target)

    def after_backward(self, model: Module, optimizer: Optimizer) -> None:
        """Hook called right after the backward pass; no-op by default."""

    def after_optimizer_step(self, model: Module, optimizer: Optimizer) -> None:
        """Hook called right after the optimizer step; no-op by default."""

    def run(self) -> None:
        """Set up model, optimizer and dataloader, then run one step."""
        with self.init_module():
            net = self.get_model()
        optim = self.get_optimizer(net)
        net, optim = self.setup(net, optim)

        loader = self.setup_dataloaders(self.get_dataloader())

        # Keep references around so callers/tests can inspect them later.
        self.model = net
        self.optimizer = optim
        self.dataloader = loader

        net.train()

        # A single optimization step over the first batch.
        batch = next(iter(loader))
        loss = self.step(net, batch)
        self.backward(loss)
        self.after_backward(net, optim)
        optim.step()
        self.after_optimizer_step(net, optim)
        optim.zero_grad()
|