lightning/pytorch_lightning/core/saving.py

68 lines
1.6 KiB
Python
Raw Normal View History

import os
import csv
import logging as log
from argparse import Namespace
2019-03-31 01:45:16 +00:00
class ModelIO:
    """Checkpoint-related hooks that a model may override.

    Every hook defined here is a no-op by default (its body is only a
    docstring); subclasses override the ones they care about.
    """

    def on_load_checkpoint(self, checkpoint):
        """Hook called with the checkpoint before ``state_dict`` is restored.

        Lets the model pull anything extra out of ``checkpoint`` ahead of the
        weights being loaded.

        :param checkpoint: the checkpoint being loaded.
        :return: None
        """

    def on_save_checkpoint(self, checkpoint):
        """Hook that lets the model add extra entries to ``checkpoint``.

        By the time this runs the ``state_dict`` is already in the checkpoint.

        :param checkpoint: the checkpoint being saved.
        :return: None
        """

    # -------------------------
    # OPTIONAL HOOKS
    # -------------------------

    def on_hpc_save(self, checkpoint):
        """Hook run immediately before the Slurm manager saves the model.

        :param checkpoint: the checkpoint about to be saved.
        :return: None
        """

    def on_hpc_load(self, checkpoint):
        """Hook run immediately before the Slurm manager loads the model.

        :param checkpoint: the checkpoint about to be loaded.
        :return: None
        """
def load_hparams_from_tags_csv(tags_csv):
    """Load hyperparameters stored as a ``key,value`` CSV file.

    The first row is treated as a header and skipped. For every following
    row, the first column becomes an attribute name and the second column is
    parsed with ``convert``.

    :param tags_csv: path to the CSV file.
    :return: ``argparse.Namespace`` with one attribute per data row; an empty
        Namespace (after logging a warning) if the file does not exist.
    """
    if not os.path.isfile(tags_csv):
        log.warning(f'Missing Tags: {tags_csv}.')
        return Namespace()

    tags = {}
    # newline='' is what the csv module requires so that quoted fields
    # containing newlines are read correctly.
    with open(tags_csv, newline='') as f:
        csv_reader = csv.reader(f, delimiter=',')
        # Skip the header row without materializing the whole file; the
        # default keeps an empty file from raising StopIteration.
        next(csv_reader, None)
        for row in csv_reader:
            # csv.reader yields [] for blank lines (e.g. a trailing newline);
            # indexing those used to raise IndexError.
            if not row:
                continue
            if len(row) < 2:
                log.warning(f'Skipping malformed tags row {row} in {tags_csv}.')
                continue
            tags[row[0]] = convert(row[1])

    return Namespace(**tags)
def convert(val):
    """Parse a string value read from a tags CSV into a Python value.

    ``'true'`` / ``'false'`` (case-insensitive) become ``True`` / ``False``.
    Otherwise ``int`` is tried, then ``float``; if neither parses, the string
    is returned unchanged. Non-string inputs are returned as-is.

    :param val: raw value, normally a ``str`` cell from the CSV.
    :return: ``bool``, ``int``, ``float``, or ``val`` itself.
    """
    # Only strings need parsing. Feeding other values through int() would
    # silently truncate floats (int(3.7) == 3) or raise TypeError (int(None)).
    if not isinstance(val, str):
        return val

    lowered = val.lower()
    if lowered == 'true':
        return True
    if lowered == 'false':
        return False

    for constructor in (int, float):
        try:
            return constructor(val)
        except ValueError:
            pass
    return val