correct bugs
parent 607d682798
commit ffb8d47547
@@ -171,6 +171,9 @@ class DDPSpawnPlugin(ParallelPlugin):
             return None
         return [self.root_device.index]
 
+    def on_save(self, checkpoint: dict) -> dict:
+        return checkpoint
+
     def transfer_distrib_spawn_state_on_fit_end(self, results):
         # TODO: is there a better way than accessing callback through model -> trainer -> callback?
         best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
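The new `on_save` hook is deliberately a no-op in the base `DDPSpawnPlugin`: it receives the checkpoint dict and returns it unchanged, giving subclasses a single point to transform a checkpoint just before it is written. A minimal sketch of the pattern (the class names below are illustrative, not the actual plugin hierarchy):

    class BasePlugin:
        def on_save(self, checkpoint: dict) -> dict:
            # Default behaviour: pass the checkpoint through untouched.
            return checkpoint

    class CPUSavePlugin(BasePlugin):
        def on_save(self, checkpoint: dict) -> dict:
            # A subclass can rewrite entries before they hit disk,
            # e.g. move tensors to CPU so the file loads anywhere.
            return {k: v.cpu() if hasattr(v, "cpu") else v
                    for k, v in checkpoint.items()}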
@@ -183,7 +186,7 @@ class DDPSpawnPlugin(ParallelPlugin):
         # TODO: is there a better way than accessing trainer through model -> trainer?
         if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
             last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
-            atomic_save(self.lightning_module.state_dict(), last_path)
+            atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)
 
         # todo, pass complete checkpoint as state dictionary
         self.mp_queue.put(best_model_path)
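Routing the spawned worker's state dict through `self.on_save(...)` before `atomic_save` is what lets the TPU plugin below intercept the save. `atomic_save` itself follows the usual write-then-rename idea; a rough sketch of that pattern, assuming a POSIX filesystem (the real Lightning utility may differ in detail):

    import os
    import torch

    def atomic_save_sketch(obj, path: str) -> None:
        # Write to a side file first, then rename; os.replace is atomic,
        # so a crash mid-save never leaves a truncated checkpoint behind.
        tmp_path = path + ".part"
        torch.save(obj, tmp_path)
        os.replace(tmp_path, path)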
@@ -95,6 +95,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         Recommended on XLA Guide:
         https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
         """
+        print("on_save")
         return move_data_to_device(checkpoint, torch.device("cpu"))
 
     def broadcast(self, obj: object, src: int = 0) -> object:
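On TPU the weights are XLA tensors, which should be moved to CPU before serialization (per the XLA guide linked in the docstring), so this override copies the whole checkpoint across devices. A minimal sketch of what `move_data_to_device(checkpoint, torch.device("cpu"))` amounts to for nested containers (the real utility covers more collection types):

    import torch

    def to_cpu(obj):
        # Recursively walk dicts/lists/tuples and move tensors to CPU.
        if isinstance(obj, torch.Tensor):
            return obj.cpu()
        if isinstance(obj, dict):
            return {k: to_cpu(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return type(obj)(to_cpu(v) for v in obj)
        return obj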
@@ -544,6 +544,8 @@ class TrainLoop:
             self.trainer.batch_idx = batch_idx
 
+            print("batch_idx")
+
             # ------------------------------------
             # TRAINING_STEP + TRAINING_STEP_END
             # ------------------------------------
@@ -87,8 +87,8 @@ def test_model_tpu_cores_8(tmpdir):
         progress_bar_refresh_rate=0,
         max_epochs=1,
         tpu_cores=8,
-        limit_train_batches=0.4,
-        limit_val_batches=0.4,
+        limit_train_batches=4,
+        limit_val_batches=4,
     )
 
     model = EvalModelTemplate()
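The switch from 0.4 to 4 changes the meaning of the flag, not just its value: `Trainer` treats a float in [0.0, 1.0] as a fraction of the dataloader and an int as an absolute batch count, so the test now runs exactly four train and four validation batches per epoch regardless of dataset size. For example:

    from pytorch_lightning import Trainer

    # Fraction: use 40% of the available train batches each epoch.
    trainer = Trainer(limit_train_batches=0.4)

    # Absolute count: run exactly 4 batches each epoch, which keeps
    # an 8-core TPU test short and independent of the dataset length.
    trainer = Trainer(limit_train_batches=4, limit_val_batches=4)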