# lightning/docs/source-app/workflows/share_files_between_components/app.py
#
# 49 lines
# 1.4 KiB
# Python

import os
import torch
import lightning as L
from lightning.app.storage.path import Path
class ModelTraining(L.LightningWork):
    """Work that writes two placeholder checkpoint files into ``./checkpoints``."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``LightningWork`` and record the checkpoint directory."""
        super().__init__(*args, **kwargs)
        # Stored as a Lightning ``Path`` (not a plain string); presumably this is
        # what lets another component read the directory - confirm against the
        # Lightning storage docs.
        self.checkpoints_path = Path("./checkpoints")

    def run(self):
        """Create the checkpoint directory and save two fake checkpoints in it."""
        os.makedirs(self.checkpoints_path, exist_ok=True)
        # Template with a ``{}`` placeholder for the checkpoint number.
        filename_template = str(self.checkpoints_path / "checkpoint_{}.ckpt")
        for suffix in ("1", "2"):
            # Fake checkpoint: a small tensor stands in for real model weights.
            fake_checkpoint = torch.tensor([0, 1, 2, 3, 4])
            torch.save(fake_checkpoint, filename_template.format(suffix))
class ModelDeploy(L.LightningWork):
    """Work that loads two checkpoint files from a directory and prints them.

    Args:
        ckpt_path: Directory that contains the checkpoint files to load.
        *args: Extra positional arguments forwarded to ``LightningWork``.
        **kwargs: Extra keyword arguments forwarded to ``LightningWork``.
    """

    def __init__(self, ckpt_path, *args, **kwargs):
        # Forward the extra arguments to the base class; previously they were
        # accepted but silently discarded, unlike in ``ModelTraining``.
        super().__init__(*args, **kwargs)
        self.ckpt_path = ckpt_path

    def run(self):
        """Load the first two checkpoints (by sorted file name) and print them.

        Raises:
            IndexError: If the directory holds fewer than two entries.
        """
        # ``os.listdir`` returns entries in arbitrary order; sorting makes the
        # mapping between file names and the printed labels deterministic.
        ckpts = sorted(os.listdir(self.ckpt_path))
        checkpoint_1 = torch.load(os.path.join(self.ckpt_path, ckpts[0]))
        checkpoint_2 = torch.load(os.path.join(self.ckpt_path, ckpts[1]))
        print(f"Loaded checkpoint_1: {checkpoint_1}")
        print(f"Loaded checkpoint_2: {checkpoint_2}")
class LitApp(L.LightningFlow):
    """Flow that owns a training work and a deploy work sharing one checkpoint path."""

    def __init__(self):
        """Create both works, giving the deploy work the training work's checkpoint path."""
        super().__init__()
        self.train = ModelTraining()
        # The deploy work reads from the same path object the training work writes to.
        shared_checkpoints = self.train.checkpoints_path
        self.deploy = ModelDeploy(ckpt_path=shared_checkpoints)

    def run(self):
        """Run the training work, then the deploy work."""
        # Order matters: the checkpoints must be written before they are loaded.
        self.train.run()
        self.deploy.run()
# Module-level application object, built around a single LitApp flow instance.
app = L.LightningApp(LitApp())