corrent bugs

This commit is contained in:
Your Name 2021-02-09 15:31:58 +00:00
parent 607d682798
commit ffb8d47547
4 changed files with 9 additions and 3 deletions

View File

@ -171,6 +171,9 @@ class DDPSpawnPlugin(ParallelPlugin):
return None return None
return [self.root_device.index] return [self.root_device.index]
def on_save(self, checkpoint: dict) -> dict:
return checkpoint
def transfer_distrib_spawn_state_on_fit_end(self, results): def transfer_distrib_spawn_state_on_fit_end(self, results):
# TODO: is there a better way than accessing callback through model -> trainer -> callback? # TODO: is there a better way than accessing callback through model -> trainer -> callback?
best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path best_model_path = self.lightning_module.trainer.checkpoint_callback.best_model_path
@ -183,7 +186,7 @@ class DDPSpawnPlugin(ParallelPlugin):
# TODO: is there a better way than accessing trainer through model -> trainer? # TODO: is there a better way than accessing trainer through model -> trainer?
if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0: if not self.lightning_module.trainer.testing and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
atomic_save(self.lightning_module.state_dict(), last_path) atomic_save(self.on_save(self.lightning_module.state_dict()), last_path)
# todo, pass complete checkpoint as state dictionary # todo, pass complete checkpoint as state dictionary
self.mp_queue.put(best_model_path) self.mp_queue.put(best_model_path)

View File

@ -95,6 +95,7 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
Recommended on XLA Guide: Recommended on XLA Guide:
https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors https://github.com/pytorch/xla/blob/master/API_GUIDE.md#saving-and-loading-xla-tensors
""" """
print("on_save")
return move_data_to_device(checkpoint, torch.device("cpu")) return move_data_to_device(checkpoint, torch.device("cpu"))
def broadcast(self, obj: object, src: int = 0) -> object: def broadcast(self, obj: object, src: int = 0) -> object:

View File

@ -544,6 +544,8 @@ class TrainLoop:
self.trainer.batch_idx = batch_idx self.trainer.batch_idx = batch_idx
print("batch_idx")
# ------------------------------------ # ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END # TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------ # ------------------------------------

View File

@ -87,8 +87,8 @@ def test_model_tpu_cores_8(tmpdir):
progress_bar_refresh_rate=0, progress_bar_refresh_rate=0,
max_epochs=1, max_epochs=1,
tpu_cores=8, tpu_cores=8,
limit_train_batches=0.4, limit_train_batches=4,
limit_val_batches=0.4, limit_val_batches=4,
) )
model = EvalModelTemplate() model = EvalModelTemplate()