lightning/pytorch_lightning/core/saving.py

66 lines
1.7 KiB
Python
Raw Normal View History

import csv
2020-03-12 16:41:37 +00:00
import os
from argparse import Namespace
from typing import Union, Dict, Any
from pytorch_lightning import _logger as log
2019-03-31 01:45:16 +00:00
class ModelIO:
    """Checkpoint-related hooks that a model may override.

    Every hook defined here is a no-op by default (its body is only a
    docstring, so it returns ``None``); subclasses override the ones they need.
    """

    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Give the model a chance to read from the checkpoint before ``state_dict`` is restored.

        Args:
            checkpoint: The checkpoint dictionary being loaded.
        """

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """Give the model a chance to add extra entries to the checkpoint.

        The ``state_dict`` is already present in ``checkpoint`` when this runs.

        Args:
            checkpoint: The checkpoint dictionary being saved; mutate it in place.
        """

    # -------------------------
    # OPTIONAL HOOKS
    # -------------------------
    def on_hpc_save(self, checkpoint: Dict[str, Any]) -> None:
        """Run custom logic right before the Slurm manager saves the model.

        Args:
            checkpoint: The checkpoint dictionary about to be saved.
        """

    def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
        """Run custom logic right before the Slurm manager loads the model.

        Args:
            checkpoint: The checkpoint dictionary about to be loaded.
        """
def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
    """Load hyperparameters from a two-column ``key,value`` CSV file.

    The first row is treated as a header and skipped. Each following row
    contributes the entry ``row[0] -> convert(row[1])``; any columns beyond
    the second are ignored, and rows with fewer than two fields (such as
    blank lines) are skipped.

    Args:
        tags_csv: Path to the CSV file.

    Returns:
        A ``Namespace`` with one attribute per data row. If the file does not
        exist, a warning is logged and an empty ``Namespace`` is returned.
    """
    if not os.path.isfile(tags_csv):
        log.warning(f'Missing Tags: {tags_csv}.')
        return Namespace()

    # The csv module requires newline='' so that quoted fields containing
    # line breaks are parsed correctly.
    with open(tags_csv, newline='') as f:
        csv_reader = csv.reader(f, delimiter=',')
        # Skip the header row; the default keeps an empty file from raising.
        next(csv_reader, None)
        # csv.reader yields [] for blank lines; indexing row[1] on those would
        # raise IndexError, so only rows with a key and a value are kept.
        tags = {row[0]: convert(row[1]) for row in csv_reader if len(row) >= 2}

    return Namespace(**tags)
def convert(val: str) -> Union[int, float, bool, str]:
    """Convert a tag value read from CSV into the most specific Python type.

    For strings, the checks run in this order:

    1. ``'true'`` / ``'false'`` (case-insensitive) become ``True`` / ``False``.
    2. Otherwise ``int(val)`` is tried, then ``float(val)``.
    3. If neither parses, the string is returned unchanged.

    Non-string values are returned unchanged.

    Args:
        val: The raw value, normally a CSV cell string.

    Returns:
        The parsed ``bool``, ``int`` or ``float``, or ``val`` itself.
    """
    # Values that are not strings are already typed. Pushing them through
    # int() would truncate floats (2.7 -> 2), turn bools into ints, and raise
    # an uncaught TypeError for None.
    if not isinstance(val, str):
        return val

    lowered = val.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False

    for constructor in (int, float):
        try:
            return constructor(val)
        except ValueError:
            pass
    return val