# Source: lightning/docs/source-app/levels/intermediate/scripts/two_work_comms.py
# (36 lines, 1.2 KiB, Python)

# app.py
import lightning as L
import time
class TrainComponent(L.LightningWork):
    """Work that fakes a long training run and periodically "saves" checkpoints.

    The most recent fake checkpoint path is kept in ``last_checkpoint_path``.
    ``ContinuousDeployment.run`` in this file reads that attribute.
    """

    def __init__(self, *args, **kwargs):
        # All arguments (e.g. ``parallel=True``) go straight to LightningWork.
        super().__init__(*args, **kwargs)
        # Stays None until the first fake checkpoint is recorded at step 0.
        self.last_checkpoint_path = None

    def run(self):
        """Simulate 1000 training steps, sleeping 1 s per step.

        Each step prints a fake loss. On every 10th step, starting at step 0,
        a fake checkpoint path is stored in ``self.last_checkpoint_path``
        and announced on stdout.
        """
        for step in range(1000):
            time.sleep(1.0)
            # The tiny epsilon avoids dividing by zero at step 0.
            fake_loss = round(1/(step + 0.00001), 4)
            # `step` and `fake_loss` are named in the `{x=}` f-strings below,
            # so these local names are part of the printed output.
            print(f'{step=}: {fake_loss=} ')
            if step % 10:
                continue
            self.last_checkpoint_path = f'/some/path/{step=}_{fake_loss=}'
            print(f'TRAIN COMPONENT: saved new checkpoint: {self.last_checkpoint_path}')
class ModelDeploymentComponent(L.LightningWork):
    """Work that pretends to (re)deploy a model from a checkpoint path."""

    def run(self, new_checkpoint):
        """Announce that the model would be loaded from ``new_checkpoint``.

        Nothing is actually loaded; this only prints a message.
        """
        message = f'DEPLOY COMPONENT: load new model from checkpoint: {new_checkpoint}'
        print(message)
class ContinuousDeployment(L.LightningFlow):
    """Root flow that connects a training work to a deployment work.

    Both works are created with ``parallel=True``. Presumably this means
    ``self.train.run()`` does not block until training finishes, and the
    app calls this ``run`` repeatedly. TODO confirm against the Lightning
    docs on parallel works.
    """

    def __init__(self) -> None:
        super().__init__()
        self.train = TrainComponent(parallel=True)
        self.model_deployment = ModelDeploymentComponent(parallel=True)

    def run(self):
        """Start or continue training; once a checkpoint exists, deploy it."""
        self.train.run()
        # No checkpoint has been recorded yet, so there is nothing to deploy.
        if not self.train.last_checkpoint_path:
            return
        self.model_deployment.run(self.train.last_checkpoint_path)
# Module-level app object, with ContinuousDeployment as the root flow.
app = L.LightningApp(root=ContinuousDeployment())