corrent bugs
This commit is contained in:
parent
607d682798
commit
ffb8d47547
|
@ -171,6 +171,9 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
return None
|
||||
return [self.root_device.index]
|
||||
|
||||
def on_save(self, checkpoint: dict) -> dict:
|
||||
return checkpoint
|
||||
|
||||
def transfer_distrib_spawn_state_on_fit_end(self, results):
|
||||
# TODO: is there a better way than accessing callback through model -> trainer -> callback?
|
||||
best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
|
||||
|
@ -183,7 +186,7 @@ class DDPSpawnPlugin(ParallelPlugin):
|
|||
# TODO: is there a better way than accessing trainer through model -> trainer?
|
||||
if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
|
||||
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
|
||||
atomic_save(self.lightning_module.state_dict(), last_path)
|
||||
atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)
|
||||
|
||||
# todo, pass complete checkpoint as state dictionary
|
||||
self.mp_queue.put(best_model_path)
|
||||
|
|
|
@ -95,6 +95,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
|
|||
Recommended on XLA Guide:
|
||||
https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
|
||||
"""
|
||||
print("on_save")
|
||||
return move_data_to_device(checkpoint, torch.device("cpu"))
|
||||
|
||||
def broadcast(self, obj: object, src: int = 0) -> object:
|
||||
|
|
|
@ -544,6 +544,8 @@ class TrainLoop:
|
|||
|
||||
self.trainer.batch_idx = batch_idx
|
||||
|
||||
print("batch_idx")
|
||||
|
||||
# ------------------------------------
|
||||
# TRAINING_STEP + TRAINING_STEP_END
|
||||
# ------------------------------------
|
||||
|
|
|
@ -87,8 +87,8 @@ def test_model_tpu_cores_8(tmpdir):
|
|||
progress_bar_refresh_rate=0,
|
||||
max_epochs=1,
|
||||
tpu_cores=8,
|
||||
limit_train_batches=0.4,
|
||||
limit_val_batches=0.4,
|
||||
limit_train_batches=4,
|
||||
limit_val_batches=4,
|
||||
)
|
||||
|
||||
model = EvalModelTemplate()
|
||||
|
|
Loading…
Reference in New Issue