59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
from pathlib import Path
|
|
|
|
import optuna
|
|
from objective import ObjectiveWork
|
|
|
|
import lightning as L
|
|
from lightning.app.structures import Dict
|
|
|
|
|
|
class RootHPOFlow(L.LightningFlow):
    """Root flow that drives an Optuna study through ``ObjectiveWork`` instances.

    Trials are opened in batches of ``simultaneous_trials``. A new batch is
    only opened once every work launched so far has had its metric reported
    to the study, and the app is stopped once ``total_trials`` trials have
    been reported.

    Args:
        script_path: Forwarded to every ``ObjectiveWork`` as ``script_path``.
        data_dir: Forwarded to every ``ObjectiveWork`` as ``data_dir``.
        total_trials: Number of trials to complete before stopping the app.
        simultaneous_trials: Number of trials opened per batch.
    """

    def __init__(self, script_path, data_dir, total_trials, simultaneous_trials):
        super().__init__()
        self.script_path = script_path
        self.data_dir = data_dir
        self.total_trials = total_trials
        self.simultaneous_trials = simultaneous_trials
        # Exclusive upper bound of the trial indices opened so far; grows by
        # ``simultaneous_trials`` each time a batch has fully reported.
        self.num_trials = simultaneous_trials
        self._study = optuna.create_study()
        # Works keyed by "objective_work_<trial_idx>".
        self.ws = Dict()

    def run(self):
        """Launch pending trials, report finished ones, and advance or stop.

        For every trial index in the current window this creates the work if
        it does not exist yet, starts it with parameters asked from the study
        if it has not started, and tells the study its metric once available.
        When all works in the window have reported, the window is extended by
        ``simultaneous_trials``; when that window already covers
        ``total_trials``, the app is stopped instead.
        """
        # Cap the window so that no more than ``total_trials`` works are ever
        # created, even when total_trials is not a multiple of the batch size.
        window = min(self.num_trials, self.total_trials)

        reported = []
        for trial_idx in range(window):
            work_name = f"objective_work_{trial_idx}"
            if work_name not in self.ws:
                self.ws[work_name] = ObjectiveWork(
                    script_path=self.script_path,
                    data_dir=self.data_dir,
                    cloud_compute=L.CloudCompute("cpu"),
                )
            work = self.ws[work_name]

            if not work.has_started:
                trial = self._study.ask(ObjectiveWork.distributions())
                # Use the public trial number: it is the integer form that
                # ``Study.tell`` documents, unlike the private ``_trial_id``.
                work.run(trial_id=trial.number, **trial.params)

            # NOTE(review): a metric of exactly 0 is falsy and would never be
            # reported here -- confirm ObjectiveWork's initial ``metric`` value
            # before switching this to an explicit ``is not None`` check.
            if work.metric and not work.has_told_study:
                self._study.tell(work.trial_id, work.metric)
                work.has_told_study = True

            reported.append(work.has_told_study)

        if not all(reported):
            # Some trial in the window is still running; check again next tick.
            return

        # Only stop AFTER the final batch has reported. Checking the counter
        # before scheduling (as was done previously) skipped the last batch.
        if window >= self.total_trials:
            self.stop()
            return

        self.num_trials += self.simultaneous_trials
|
|
|
|
|
|
if __name__ == "__main__":
    # The training script is expected to sit next to this file.
    training_script = Path(__file__).parent / "pl_script.py"

    # Root flow: 6 trials overall, opened 2 at a time.
    hpo_flow = RootHPOFlow(
        script_path=str(training_script),
        data_dir="data/hymenoptera_data_version_0",
        total_trials=6,
        simultaneous_trials=2,
    )

    app = L.LightningApp(hpo_flow)
|