Replace meta_tags.csv with hparams.yaml (#1271)
* Add support for hierarchical dict
* Support nested Namespace
* Add docstring
* Migrate hparam flattening to each logger
* Modify URLs in CHANGELOG
* typo
* Simplify the conditional branch about Namespace
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* Update CHANGELOG.md
  Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
* added examples section to docstring
* renamed _dict -> input_dict
* meta_tags.csv -> hparams.yaml
* code style fixes
* add pyyaml
* remove unused import
* create the member NAME_HPARAMS_FILE
* improve tests
* Update tensorboard.py
* pass the local test w/o the relevant parts of Horovod
* formatting
* update dependencies
* fix dependencies
* Apply suggestions from code review
* add savings
* warn
* docstrings
* tests
* Apply suggestions from code review
* saving
* Apply suggestions from code review
* use default
* remove logging
* typo fixes
* update docs
* update CHANGELOG
* clean imports
* add blank lines
* Update pytorch_lightning/core/lightning.py
  Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* Update pytorch_lightning/core/lightning.py
  Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
* back to namespace
* add docs
* test fix
* update dependencies
* add space

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent 35fe2efe27
commit 22d7d03118
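
In user-facing code, the net effect of this commit is that hyperparameters travel through an `hparams.yaml` file instead of `meta_tags.csv`, and `load_from_checkpoint` gains an `hparams_file` argument while `tags_csv` is deprecated. A rough before/after sketch (the paths and `MyLightningModule` are placeholders taken from the docs changes below):

.. code-block:: python

    # before this commit: hyperparameters restored from the CSV tags file
    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
        tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
    )

    # after this commit: hyperparameters restored from the YAML file
    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
        hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
    )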
@@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Replace meta_tags.csv with hparams.yaml ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))
- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))

@@ -36,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Deprecated

- Deprecated `tags_csv` in favor of `hparams_file` ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))

### Removed

### Fixed

@@ -21,10 +21,9 @@ To run the test set on a pre-trained model, use this method.

.. code-block:: python

    model = MyLightningModule.load_from_metrics(
        weights_path='/path/to/pytorch_checkpoint.ckpt',
        tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
        on_gpu=True,
    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
        hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
        map_location=None
    )
@@ -13,6 +13,7 @@ dependencies:

  - pytorch>=1.1
  - tensorboard>=1.14
  - future>=0.17.1
  - pyyaml>=3.13

  # For dev and testing
  - tox
@@ -1,6 +1,7 @@

import collections
import inspect
import os
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence

@@ -16,7 +17,7 @@ from pytorch_lightning import _logger as log

from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams
from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -1438,29 +1439,49 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

            cls,
            checkpoint_path: str,
            map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
            tags_csv: Optional[str] = None,
            hparams_file: Optional[str] = None,
            tags_csv: Optional[str] = None,  # backward compatible, todo: remove in v0.9.0
            hparam_overrides: Optional[Dict] = None,
            *args, **kwargs
    ) -> 'LightningModule':
        r"""
        Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
        it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule`
        with an argument called ``hparams`` which is a :class:`~argparse.Namespace`
        (output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments).
        with an argument called ``hparams`` which is an object of :class:`~dict` or
        :class:`~argparse.Namespace` (output of :meth:`~argparse.ArgumentParser.parse_args`
        when parsing command line arguments).
        If you want `hparams` to have a hierarchical structure, you have to define it as :class:`~dict`.
        Any other arguments specified through \*args and \*\*kwargs will be passed to the model.

        Example:
            .. code-block:: python

                # define hparams as Namespace
                from argparse import Namespace
                hparams = Namespace(**{'learning_rate': 0.1})

                model = MyModel(hparams)

                class MyModel(LightningModule):
                    def __init__(self, hparams):
                    def __init__(self, hparams: Namespace):
                        self.learning_rate = hparams.learning_rate

                # ----------

                # define hparams as dict
                hparams = {
                    'drop_prob': 0.2,
                    'dataloader': {
                        'batch_size': 32
                    }
                }

                model = MyModel(hparams)

                class MyModel(LightningModule):
                    def __init__(self, hparams: dict):
                        self.learning_rate = hparams['learning_rate']

        Args:
            checkpoint_path: Path to checkpoint.
            model_args: Any keyword args needed to init the model.
@@ -1468,19 +1489,38 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

                If your checkpoint saved a GPU model and you now load on CPUs
                or a different number of GPUs, use this to map to the new setup.
                The behaviour is the same as in :func:`torch.load`.
            tags_csv: Optional path to a .csv file with two columns (key, value)
            hparams_file: Optional path to a .yaml file with hierarchical structure
                as in this example::

                    drop_prob: 0.2
                    dataloader:
                        batch_size: 32

                You most likely won't need this since Lightning will always save the hyperparameters
                to the checkpoint.
                However, if your checkpoint weights don't have the hyperparameters saved,
                use this method to pass in a .yaml file with the hparams you'd like to use.
                These will be converted into a :class:`~dict` and passed into your
                :class:`LightningModule` for use.

                If your model's `hparams` argument is :class:`~argparse.Namespace`
                and .yaml file has hierarchical structure, you need to refactor your model to treat
                `hparams` as :class:`~dict`.

                .csv files are acceptable here till v0.9.0, see tags_csv argument for detailed usage.
            tags_csv:
                .. warning:: .. deprecated:: 0.7.6

                    `tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0.

                Optional path to a .csv file with two columns (key, value)
                as in this example::

                    key,value
                    drop_prob,0.2
                    batch_size,32

                You most likely won't need this since Lightning will always save the hyperparameters
                to the checkpoint.
                However, if your checkpoint weights don't have the hyperparameters saved,
                use this method to pass in a .csv file with the hparams you'd like to use.
                These will be converted into a :class:`~argparse.Namespace` and passed into your
                :class:`LightningModule` for use.
                Use this method to pass in a .csv file with the hparams you'd like to use.
            hparam_overrides: A dictionary with keys to override in the hparams

        Return:
@@ -1502,7 +1542,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

            # or load weights and hyperparameters from separate files.
            MyLightningModule.load_from_checkpoint(
                'path/to/checkpoint.ckpt',
                tags_csv='/path/to/hparams_file.csv'
                hparams_file='/path/to/hparams_file.yaml'
            )

            # override some of the params with new values
@@ -1531,9 +1571,22 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

        # add the hparams from csv file to checkpoint
        if tags_csv is not None:
            hparams = load_hparams_from_tags_csv(tags_csv)
            hparams.__setattr__('on_gpu', False)
            checkpoint['hparams'] = vars(hparams)
            hparams_file = tags_csv
            rank_zero_warn('`tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0', DeprecationWarning)

        if hparams_file is not None:
            extension = hparams_file.split('.')[-1]
            if extension.lower() in ('csv',):
                hparams = load_hparams_from_tags_csv(hparams_file)
            elif extension.lower() in ('yml', 'yaml'):
                hparams = load_hparams_from_yaml(hparams_file)
            else:
                raise ValueError('.csv, .yml or .yaml is required for `hparams_file`')

            hparams['on_gpu'] = False

            # overwrite hparams by the given file
            checkpoint['hparams'] = hparams

        # override the hparam keys that were passed in
        if hparam_overrides is not None:
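
For backward compatibility, the deprecated `tags_csv` path above simply aliases `hparams_file` and warns. A hedged sketch of what callers observe until v0.9.0, assuming `rank_zero_warn` forwards to the standard `warnings` machinery (paths are placeholders):

.. code-block:: python

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        # still accepted, but routed through hparams_file internally
        model = MyLightningModule.load_from_checkpoint(
            checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
            tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
        )
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)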
@@ -1549,15 +1602,18 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

        if cls_takes_hparams:
            if ckpt_hparams is not None:
                is_namespace = checkpoint.get('hparams_type', 'namespace') == 'namespace'
                hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
                hparams_type = checkpoint.get('hparams_type', 'Namespace')
                if hparams_type.lower() == 'dict':
                    hparams = ckpt_hparams
                elif hparams_type.lower() == 'namespace':
                    hparams = Namespace(**ckpt_hparams)
            else:
                rank_zero_warn(
                    f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
                    " contains argument 'hparams'. Will pass in an empty Namespace instead."
                    " Did you forget to store your model hyperparameters in self.hparams?"
                )
                hparams = Namespace()
                hparams = {}
        else:  # The user's LightningModule does not define a hparams argument
            if ckpt_hparams is None:
                hparams = None
@@ -1568,7 +1624,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod

            )

        # load the state_dict on the model automatically
        if hparams:
        if cls_takes_hparams:
            kwargs.update(hparams=hparams)
        model = cls(*args, **kwargs)
        model.load_state_dict(checkpoint['state_dict'])
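
Putting the pieces together, loading with an external hparams file plus selective overrides looks roughly like this; the paths and keys are placeholders, and keys given in `hparam_overrides` replace the corresponding values loaded from the file:

.. code-block:: python

    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
        hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
        hparam_overrides={'drop_prob': 0.5},
    )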
@@ -1,9 +1,12 @@

import ast
import csv
import os
import yaml
from argparse import Namespace
from typing import Union, Dict, Any

from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn


class ModelIO(object):

@@ -79,30 +82,78 @@ def update_hparams(hparams: dict, updates: dict) -> None:

        hparams.update({k: v})


def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
    if not os.path.isfile(tags_csv):
        log.warning(f'Missing Tags: {tags_csv}.')
        return Namespace()
def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
    """Load hparams from a file.

    with open(tags_csv) as f:
        csv_reader = csv.reader(f, delimiter=',')
    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_csv = './testing-hparams.csv'
    >>> save_hparams_to_tags_csv(path_csv, hparams)
    >>> hparams_new = load_hparams_from_tags_csv(path_csv)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_csv)
    """
    if not os.path.isfile(tags_csv):
        rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
        return {}

    with open(tags_csv) as fp:
        csv_reader = csv.reader(fp, delimiter=',')
        tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
    ns = Namespace(**tags)
    return ns

    return tags


def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
    if not os.path.isdir(os.path.dirname(tags_csv)):
        raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')

    if isinstance(hparams, Namespace):
        hparams = vars(hparams)

    with open(tags_csv, 'w') as fp:
        fieldnames = ['key', 'value']
        writer = csv.DictWriter(fp, fieldnames=fieldnames)
        writer.writerow({'key': 'key', 'value': 'value'})
        for k, v in hparams.items():
            writer.writerow({'key': k, 'value': v})


def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
    """Load hparams from a file.

    >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
    >>> path_yaml = './testing-hparams.yaml'
    >>> save_hparams_to_yaml(path_yaml, hparams)
    >>> hparams_new = load_hparams_from_yaml(path_yaml)
    >>> vars(hparams) == hparams_new
    True
    >>> os.remove(path_yaml)
    """
    if not os.path.isfile(config_yaml):
        rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
        return {}

    with open(config_yaml) as fp:
        tags = yaml.load(fp, Loader=yaml.SafeLoader)

    return tags


def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
    if not os.path.isdir(os.path.dirname(config_yaml)):
        raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')

    if isinstance(hparams, Namespace):
        hparams = vars(hparams)

    with open(config_yaml, 'w', newline='') as fp:
        yaml.dump(hparams, fp)


def convert(val: str) -> Union[int, float, bool, str]:
    constructors = [int, float, str]

    if isinstance(val, str):
        if val.lower() == 'true':
            return True
        if val.lower() == 'false':
            return False

    for c in constructors:
        try:
            return c(val)
        except ValueError:
            pass
    return val
    try:
        return ast.literal_eval(val)
    except (ValueError, SyntaxError) as e:
        log.debug(e)
        return val
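
The new YAML helpers round-trip hyperparameters the same way the doctests above exercise them. A minimal sketch, assuming the module path used by the tests (`pytorch_lightning.core.saving`) and a throwaway file path:

.. code-block:: python

    import os
    from argparse import Namespace

    from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

    hparams = Namespace(drop_prob=0.2, batch_size=32)

    path_yaml = './tmp-hparams.yaml'          # placeholder path
    save_hparams_to_yaml(path_yaml, hparams)  # a Namespace is converted via vars()
    restored = load_hparams_from_yaml(path_yaml)

    assert restored == vars(hparams)          # loaded back as a plain dict
    os.remove(path_yaml)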
@@ -3,8 +3,8 @@ TensorBoard

-----------
"""

import csv
import os
import yaml
from argparse import Namespace
from typing import Optional, Dict, Union, Any
from warnings import warn

@@ -14,6 +14,7 @@ from pkg_resources import parse_version

from torch.utils.tensorboard import SummaryWriter

from pytorch_lightning import _logger as log
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only

@@ -42,7 +43,7 @@ class TensorBoardLogger(LightningLoggerBase):

        \**kwargs: Other arguments are passed directly to the :class:`SummaryWriter` constructor.

    """
    NAME_CSV_TAGS = 'meta_tags.csv'
    NAME_HPARAMS_FILE = 'hparams.yaml'

    def __init__(self,
                 save_dir: str,

@@ -55,7 +56,7 @@ class TensorBoardLogger(LightningLoggerBase):

        self._version = version

        self._experiment = None
        self.tags = {}
        self.hparams = {}
        self._kwargs = kwargs

    @property

@@ -104,8 +105,13 @@ class TensorBoardLogger(LightningLoggerBase):

    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
                        metrics: Optional[Dict[str, Any]] = None) -> None:
        params = self._convert_params(params)

        # store params to output
        self.hparams.update(params)

        # format params into the suitable for tensorboard
        params = self._flatten_dict(params)
        sanitized_params = self._sanitize_params(params)
        params = self._sanitize_params(params)

        if parse_version(torch.__version__) < parse_version("1.3.0"):
            warn(

@@ -118,7 +124,7 @@ class TensorBoardLogger(LightningLoggerBase):

        if metrics is None:
            metrics = {}
        exp, ssi, sei = hparams(sanitized_params, metrics)
        exp, ssi, sei = hparams(params, metrics)
        writer = self.experiment._get_file_writer()
        writer.add_summary(exp)
        writer.add_summary(ssi)

@@ -128,9 +134,6 @@ class TensorBoardLogger(LightningLoggerBase):

        # necessary for hparam comparison with metrics
        self.log_metrics(metrics)

        # some alternative should be added
        self.tags.update(sanitized_params)

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        for k, v in metrics.items():

@@ -152,15 +155,10 @@ class TensorBoardLogger(LightningLoggerBase):

        dir_path = self.save_dir

        # prepare the file path
        meta_tags_path = os.path.join(dir_path, self.NAME_CSV_TAGS)
        hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)

        # save the metatags file
        with open(meta_tags_path, 'w', newline='') as csvfile:
            fieldnames = ['key', 'value']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            writer.writerow({'key': 'key', 'value': 'value'})
            for k, v in self.tags.items():
                writer.writerow({'key': k, 'value': v})
        save_hparams_to_yaml(hparams_file, self.hparams)

    @rank_zero_only
    def finalize(self, status: str) -> None:
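
With this change the TensorBoard logger accumulates hyperparameters in `self.hparams` and writes a single `hparams.yaml` when `save()` runs (normally the Trainer calls it). A hedged usage sketch; the save_dir is a placeholder, and the exact output directory follows the logger's name/version layout:

.. code-block:: python

    from pytorch_lightning.loggers import TensorBoardLogger

    logger = TensorBoardLogger(save_dir='demo_logs', name='demo')  # placeholder location
    logger.log_hyperparams({'drop_prob': 0.2, 'dataloader': {'batch_size': 32}})
    logger.save()  # persists the stored hparams via save_hparams_to_yaml

    print(TensorBoardLogger.NAME_HPARAMS_FILE)  # 'hparams.yaml'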
@@ -105,10 +105,9 @@ Second case is where you load a model and run the test set

.. code-block:: python

    model = MyLightningModule.load_from_metrics(
        weights_path='/path/to/pytorch_checkpoint.ckpt',
        tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
        on_gpu=True,
    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
        hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
        map_location=None
    )
@@ -344,9 +344,16 @@ class TrainerIOMixin(ABC):

        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)
            is_namespace = isinstance(model.hparams, Namespace)
            checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
            checkpoint['hparams_type'] = 'namespace' if is_namespace else 'dict'
            checkpoint['hparams_type'] = model.hparams.__class__.__name__
            if checkpoint['hparams_type'] == 'dict':
                checkpoint['hparams'] = model.hparams
            elif checkpoint['hparams_type'] == 'Namespace':
                checkpoint['hparams'] = vars(model.hparams)
            else:
                raise ValueError(
                    'The acceptable hparams type is dict or argparse.Namespace,'
                    f' not {checkpoint["hparams_type"]}'
                )
        else:
            rank_zero_warn(
                "Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters."
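
The type tag saved here is what `load_from_checkpoint` reads back: a dict is stored as-is, a Namespace is stored as `vars(...)` and rebuilt on load. A small sketch of that convention with hypothetical helper names (`pack_hparams` and `unpack_hparams` are illustrations, not functions in the codebase):

.. code-block:: python

    from argparse import Namespace

    def pack_hparams(hparams):
        # mirrors the saving logic above: record the type name plus a plain dict
        hparams_type = hparams.__class__.__name__
        packed = vars(hparams) if hparams_type == 'Namespace' else hparams
        return {'hparams': packed, 'hparams_type': hparams_type}

    def unpack_hparams(checkpoint):
        # mirrors the loading logic: rebuild a Namespace only when one was saved
        if checkpoint.get('hparams_type', 'Namespace').lower() == 'namespace':
            return Namespace(**checkpoint['hparams'])
        return checkpoint['hparams']

    ckpt = pack_hparams(Namespace(drop_prob=0.2, batch_size=32))
    restored = unpack_hparams(ckpt)
    assert isinstance(restored, Namespace) and restored.batch_size == 32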
@@ -5,4 +5,4 @@ numpy>=1.16.4

torch>=1.1
tensorboard>=1.14
future>=0.17.1  # required for builtins in setup.py
pyyaml>=3.13
@@ -126,14 +126,14 @@ def get_data_path(expt_logger, path_dir=None):

def load_model(logger, root_weights_dir, module_class=EvalModelTemplate, path_expt=None):
    # load trained model
    path_expt_dir = get_data_path(logger, path_dir=path_expt)
    tags_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_CSV_TAGS)
    hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE)

    checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x]
    weights_dir = os.path.join(root_weights_dir, checkpoints[0])

    trained_model = module_class.load_from_checkpoint(
        checkpoint_path=weights_dir,
        tags_csv=tags_path
        hparams_file=hparams_path
    )

    assert trained_model is not None, 'loading model failed'
@@ -256,11 +256,11 @@ def test_model_saving_loading(tmpdir):

    trainer.save_checkpoint(new_weights_path)

    # load new model
    tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, 'hparams.yaml')
    model_2 = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        tags_csv=tags_path
        hparams_file=hparams_path
    )
    model_2.eval()
@@ -6,12 +6,14 @@ from argparse import Namespace

import pytest
import torch
import yaml

import tests.base.utils as tutils
from pytorch_lightning import Callback, LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate

@@ -69,9 +71,12 @@ def test_no_val_module(tmpdir):

    ckpt = torch.load(new_weights_path)
    assert 'hparams' in ckpt.keys(), 'hparams missing from checkpoints'

    # won't load without hparams in the ckpt
    # load new model
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, 'hparams.yaml')
    model_2 = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        hparams_file=hparams_path
    )
    model_2.eval()

@@ -100,11 +105,11 @@ def test_no_val_end_module(tmpdir):

    trainer.save_checkpoint(new_weights_path)

    # load new model
    tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
    tags_path = os.path.join(tags_path, 'meta_tags.csv')
    hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(hparams_path, 'hparams.yaml')
    model_2 = EvalModelTemplate.load_from_checkpoint(
        checkpoint_path=new_weights_path,
        tags_csv=tags_path
        hparams_file=hparams_path
    )
    model_2.eval()

@@ -184,6 +189,8 @@ def test_gradient_accumulation_scheduling(tmpdir):


def test_loading_meta_tags(tmpdir):
    """ test for backward compatibility to meta_tags.csv """
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()

@@ -193,12 +200,37 @@ def test_loading_meta_tags(tmpdir):

    logger.log_hyperparams(hparams)
    logger.save()

    # load tags
    # load hparams
    path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE)
    hparams = load_hparams_from_yaml(hparams_path)

    # save as legacy meta_tags.csv
    tags_path = os.path.join(path_expt_dir, 'meta_tags.csv')
    save_hparams_to_tags_csv(tags_path, hparams)

    tags = load_hparams_from_tags_csv(tags_path)

    assert tags.batch_size == 32 and tags.hidden_dim == 1000
    assert hparams == tags


def test_loading_yaml(tmpdir):
    tutils.reset_seed()

    hparams = EvalModelTemplate.get_default_hparams()

    # save tags
    logger = tutils.get_default_logger(tmpdir)
    logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0))
    logger.log_hyperparams(hparams)
    logger.save()

    # load hparams
    path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir)
    hparams_path = os.path.join(path_expt_dir, 'hparams.yaml')
    tags = load_hparams_from_yaml(hparams_path)

    assert tags['batch_size'] == 32 and tags['hidden_dim'] == 1000


def test_dp_output_reduce():