# lightning/examples/app_hpo/app_wo_ui.py
# Source listing metadata: 59 lines, 1.9 KiB, Python.

from pathlib import Path
import optuna
from objective import ObjectiveWork
import lightning as L
from lightning.app.structures import Dict
class RootHPOFlow(L.LightningFlow):
    """Root flow that runs an Optuna hyper-parameter search as Lightning works.

    Trials are launched in batches of ``simultaneous_trials``. Each trial is an
    ``ObjectiveWork`` that runs ``script_path`` on ``data_dir`` with parameters
    sampled from ``ObjectiveWork.distributions()``. A new batch is started only
    once every trial launched so far has reported its metric to the study. The
    app exits once ``total_trials`` trials have been reported.

    Args:
        script_path: Path of the training script handed to every ``ObjectiveWork``.
        data_dir: Data directory handed to every ``ObjectiveWork``.
        total_trials: Total number of trials to run before exiting.
        simultaneous_trials: Number of trials launched per batch; must be >= 1.

    Raises:
        ValueError: If ``simultaneous_trials`` is smaller than 1.
    """

    def __init__(self, script_path, data_dir, total_trials, simultaneous_trials):
        super().__init__()
        if simultaneous_trials < 1:
            # With zero (or fewer) trials per batch the search could never advance.
            raise ValueError(
                f"`simultaneous_trials` must be >= 1, got {simultaneous_trials}."
            )
        self.script_path = script_path
        self.data_dir = data_dir
        self.total_trials = total_trials
        self.simultaneous_trials = simultaneous_trials
        # Number of trials unlocked so far; grows by one batch at a time in ``run``.
        self.num_trials = simultaneous_trials
        # NOTE(review): underscore-prefixed, presumably so the study object is not
        # part of the flow's tracked state — confirm against Lightning state docs.
        self._study = optuna.create_study()
        # One ObjectiveWork per trial, keyed "objective_work_<idx>".
        self.ws = Dict()

    def run(self):
        """Launch pending trials, report finished ones, and advance the search.

        The flow's ``run`` is invoked repeatedly, so all progress is kept in
        ``self.num_trials`` and in each work's ``has_told_study`` flag.
        """
        # Never launch more than ``total_trials`` works, even when the last batch
        # would overshoot (e.g. total_trials=5, simultaneous_trials=2).
        num_active = min(self.num_trials, self.total_trials)

        has_told_study = []
        for trial_idx in range(num_active):
            work_name = f"objective_work_{trial_idx}"
            if work_name not in self.ws:
                self.ws[work_name] = ObjectiveWork(
                    script_path=self.script_path,
                    data_dir=self.data_dir,
                    cloud_compute=L.CloudCompute("cpu"),
                )
            work = self.ws[work_name]

            if not work.has_started:
                trial = self._study.ask(ObjectiveWork.distributions())
                # ``Study.tell`` identifies trials by their public ``number``; the
                # private ``_trial_id`` is a storage id not guaranteed to match it.
                work.run(trial_id=trial.number, **trial.params)

            # Compare against None so a legitimate metric of 0.0 is still reported.
            # Assumes ObjectiveWork.metric stays None until the trial finishes — confirm.
            if work.metric is not None and not work.has_told_study:
                self._study.tell(work.trial_id, work.metric)
                work.has_told_study = True

            has_told_study.append(work.has_told_study)

        if all(has_told_study):
            if num_active >= self.total_trials:
                # Every requested trial has been reported to the study: stop the app.
                self._exit()
                return
            # Current batch is done: unlock the next one.
            self.num_trials += self.simultaneous_trials
if __name__ == "__main__":
    # Training script that lives next to this file; each trial runs it.
    training_script = Path(__file__).parent / "pl_script.py"

    # Search over 6 trials in total, executing 2 of them at a time.
    hpo_flow = RootHPOFlow(
        script_path=str(training_script),
        data_dir="data/hymenoptera_data_version_0",
        total_trials=6,
        simultaneous_trials=2,
    )

    # Wrap the root flow in a LightningApp bound to the module-level name ``app``.
    app = L.LightningApp(hpo_flow)