lightning/examples/app_components/python/pl_script.py

66 lines
1.9 KiB
Python
Raw Normal View History

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer
class RandomDataset(Dataset):
    """Map-style dataset holding ``length`` random feature vectors.

    All samples are drawn once, at construction time, with ``torch.randn``,
    so indexing the same position repeatedly returns the same tensor values.

    Args:
        size: Number of features in each sample (second dimension of the table).
        length: Number of samples (first dimension of the table).
    """

    def __init__(self, size: int, length: int):
        # Pre-sample the full (length, size) table rather than generating per item.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self) -> int:
        """Return the number of samples given as ``length`` at construction."""
        return self.len

    def __getitem__(self, index):
        """Return row ``index`` of the pre-sampled data table."""
        return self.data[index]
class BoringModel(LightningModule):
    """Minimal LightningModule: one ``Linear(32, 2)`` layer fed with random data.

    A single loader over ``RandomDataset(32, 64)`` backs every stage (train,
    validate, test, predict), so each ``Trainer`` entry point can be exercised
    without a real dataset.
    """

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        """Apply the single linear layer to ``x``."""
        return self.layer(x)

    def loss(self, batch, prediction):
        """Return the MSE between ``prediction`` and an all-ones target.

        ``batch`` is accepted to match the call sites but is not used.
        """
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        target = torch.ones_like(prediction)
        return torch.nn.functional.mse_loss(prediction, target)

    def _forward_and_loss(self, batch):
        # Shared by the train/val/test hooks: run the model, then score its output.
        return self.loss(batch, self(batch))

    def training_step(self, batch, batch_idx):
        """Return the training loss under the ``"loss"`` key."""
        return {"loss": self._forward_and_loss(batch)}

    def validation_step(self, batch, batch_idx):
        """Return the validation loss under the ``"x"`` key."""
        return {"x": self._forward_and_loss(batch)}

    def test_step(self, batch, batch_idx):
        """Return the test loss under the ``"y"`` key."""
        return {"y": self._forward_and_loss(batch)}

    def configure_optimizers(self):
        """Build SGD (lr=0.1) over the layer plus a StepLR scheduler (step_size=1)."""
        sgd = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.StepLR(sgd, step_size=1)
        return [sgd], [scheduler]

    def train_dataloader(self):
        """Return a DataLoader (default batch size) over 64 random 32-feature samples."""
        return DataLoader(RandomDataset(32, 64))

    # Every other stage reuses the very same loader-building function.
    val_dataloader = train_dataloader
    test_dataloader = train_dataloader
    predict_dataloader = train_dataloader
if __name__ == "__main__":
    model = BoringModel()
    # Two CPU devices under the DDP strategy; max_epochs=1 keeps the run short.
    trainer = Trainer(max_epochs=1, accelerator="cpu", devices=2, strategy="ddp")
    # Drive each Trainer entry point in order: fit, validate, test, predict.
    for stage in (trainer.fit, trainer.validate, trainer.test, trainer.predict):
        stage(model)