Replace meta_tags.csv with hparams.yaml (#1271)

* Add support for hierarchical dict

* Support nested Namespace

* Add docstring

* Migrate hparam flattening to each logger

* Modify URLs in CHANGELOG

* typo

* Simplify the conditional branch about Namespace

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* Update CHANGELOG.md

Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com>

* added examples section to docstring

* renamed _dict -> input_dict

* meta_tags.csv -> hparams.yaml

* code style fixes

* add pyyaml

* remove unused import

* create the member NAME_HPARAMS_FILE

* improve tests

* Update tensorboard.py

* pass the local tests without the Horovod-related parts

* formatting

* update dependencies

* fix dependencies

* Apply suggestions from code review

* add savings

* warn

* docstrings

* tests

* Apply suggestions from code review

* saving

* Apply suggestions from code review

* use default

* remove logging

* typo fixes

* update docs

* update CHANGELOG

* clean imports

* add blank lines

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update pytorch_lightning/core/lightning.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* back to namespace

* add docs

* test fix

* update dependencies

* add space

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
So Uchida 2020-05-13 22:05:15 +09:00 committed by GitHub
parent 35fe2efe27
commit 22d7d03118
12 changed files with 228 additions and 81 deletions

View File

@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed
- Replace meta_tags.csv with hparams.yaml ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))
- Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609))
- Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577))
@ -36,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated
- Deprecated `tags_csv` in favor of `hparams_file` ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))
### Removed
### Fixed

View File

@ -21,10 +21,9 @@ To run the test set on a pre-trained model, use this method.
.. code-block:: python
model = MyLightningModule.load_from_metrics(
weights_path='/path/to/pytorch_checkpoint.ckpt',
tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
on_gpu=True,
model = MyLightningModule.load_from_checkpoint(
checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
map_location=None
)

View File

@ -13,6 +13,7 @@ dependencies:
- pytorch>=1.1
- tensorboard>=1.14
- future>=0.17.1
- pyyaml>=3.13
# For dev and testing
- tox

View File

@ -1,6 +1,7 @@
import collections
import inspect
import os
import warnings
from abc import ABC, abstractmethod
from argparse import Namespace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
@ -16,7 +17,7 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.core.grads import GradInformation
from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams
from pytorch_lightning.core.properties import DeviceDtypeModuleMixin
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -1438,29 +1439,49 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
cls,
checkpoint_path: str,
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
tags_csv: Optional[str] = None,
hparams_file: Optional[str] = None,
tags_csv: Optional[str] = None, # backward compatible, todo: remove in v0.9.0
hparam_overrides: Optional[Dict] = None,
*args, **kwargs
) -> 'LightningModule':
r"""
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule`
with an argument called ``hparams`` which is a :class:`~argparse.Namespace`
(output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments).
with an argument called ``hparams`` which is a :class:`dict` or
:class:`~argparse.Namespace` (the output of :meth:`~argparse.ArgumentParser.parse_args`
when parsing command line arguments).
If you want ``hparams`` to have a hierarchical structure, you have to define it as a :class:`dict`.
Any other arguments specified through \*args and \*\*kwargs will be passed to the model.
Example:
.. code-block:: python
# define hparams as Namespace
from argparse import Namespace
hparams = Namespace(**{'learning_rate': 0.1})
model = MyModel(hparams)
class MyModel(LightningModule):
def __init__(self, hparams):
def __init__(self, hparams: Namespace):
self.learning_rate = hparams.learning_rate
# ----------
# define hparams as dict
hparams = {
'drop_prob': 0.2,
'dataloader': {
'batch_size': 32
}
}
model = MyModel(hparams)
class MyModel(LightningModule):
def __init__(self, hparams: dict):
self.drop_prob = hparams['drop_prob']
self.batch_size = hparams['dataloader']['batch_size']
Args:
checkpoint_path: Path to checkpoint.
model_args: Any keyword args needed to init the model.
@ -1468,19 +1489,38 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup.
The behaviour is the same as in :func:`torch.load`.
tags_csv: Optional path to a .csv file with two columns (key, value)
hparams_file: Optional path to a .yaml file with hierarchical structure
as in this example::
drop_prob: 0.2
dataloader:
batch_size: 32
You most likely won't need this since Lightning will always save the hyperparameters
to the checkpoint.
However, if your checkpoint weights don't have the hyperparameters saved,
use this method to pass in a .yaml file with the hparams you'd like to use.
These will be converted into a :class:`~dict` and passed into your
:class:`LightningModule` for use.
If your model's ``hparams`` argument is a :class:`~argparse.Namespace`
and the .yaml file has a hierarchical structure, you need to refactor your model to treat
``hparams`` as a :class:`dict`.
.csv files are acceptable here until v0.9.0; see the ``tags_csv`` argument for detailed usage.
tags_csv:
.. warning:: .. deprecated:: 0.7.6
`tags_csv` argument is deprecated in v0.7.6. Will be removed in v0.9.0.
Optional path to a .csv file with two columns (key, value)
as in this example::
key,value
drop_prob,0.2
batch_size,32
You most likely won't need this since Lightning will always save the hyperparameters
to the checkpoint.
However, if your checkpoint weights don't have the hyperparameters saved,
use this method to pass in a .csv file with the hparams you'd like to use.
These will be converted into a :class:`~argparse.Namespace` and passed into your
:class:`LightningModule` for use.
Use this method to pass in a .csv file with the hparams you'd like to use.
hparam_overrides: A dictionary with keys to override in the hparams
Return:
@ -1502,7 +1542,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
# or load weights and hyperparameters from separate files.
MyLightningModule.load_from_checkpoint(
'path/to/checkpoint.ckpt',
tags_csv='/path/to/hparams_file.csv'
hparams_file='/path/to/hparams_file.yaml'
)
# override some of the params with new values
@ -1531,9 +1571,22 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
# add the hparams from csv file to checkpoint
if tags_csv is not None:
hparams = load_hparams_from_tags_csv(tags_csv)
hparams.__setattr__('on_gpu', False)
checkpoint['hparams'] = vars(hparams)
hparams_file = tags_csv
rank_zero_warn('`tags_csv` argument is deprecated in v0.7.6. Will be removed in v0.9.0.', DeprecationWarning)
if hparams_file is not None:
extension = hparams_file.split('.')[-1]
if extension.lower() in ('csv',):
hparams = load_hparams_from_tags_csv(hparams_file)
elif extension.lower() in ('yml', 'yaml'):
hparams = load_hparams_from_yaml(hparams_file)
else:
raise ValueError('.csv, .yml or .yaml is required for `hparams_file`')
hparams['on_gpu'] = False
# overwrite hparams by the given file
checkpoint['hparams'] = hparams
# override the hparam keys that were passed in
if hparam_overrides is not None:
@ -1549,15 +1602,18 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
if cls_takes_hparams:
if ckpt_hparams is not None:
is_namespace = checkpoint.get('hparams_type', 'namespace') == 'namespace'
hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
hparams_type = checkpoint.get('hparams_type', 'Namespace')
if hparams_type.lower() == 'dict':
hparams = ckpt_hparams
elif hparams_type.lower() == 'namespace':
hparams = Namespace(**ckpt_hparams)
else:
rank_zero_warn(
f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__"
" contains argument 'hparams'. Will pass in an empty Namespace instead."
" Did you forget to store your model hyperparameters in self.hparams?"
)
hparams = Namespace()
hparams = {}
else: # The user's LightningModule does not define a hparams argument
if ckpt_hparams is None:
hparams = None
@ -1568,7 +1624,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod
)
# load the state_dict on the model automatically
if hparams:
if cls_takes_hparams:
kwargs.update(hparams=hparams)
model = cls(*args, **kwargs)
model.load_state_dict(checkpoint['state_dict'])
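For orientation, a minimal usage sketch of the loading path added above; the module class and file paths are hypothetical placeholders, not part of this diff:

.. code-block:: python

    # new style: hyperparameters are read from a (possibly hierarchical) YAML file
    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='/path/to/checkpoint.ckpt',
        hparams_file='/path/to/hparams.yaml',
    )

    # deprecated style, accepted until v0.9.0: `tags_csv` is forwarded to
    # `hparams_file` and a DeprecationWarning is emitted
    model = MyLightningModule.load_from_checkpoint(
        checkpoint_path='/path/to/checkpoint.ckpt',
        tags_csv='/path/to/meta_tags.csv',
    )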

View File

@ -1,9 +1,12 @@
import ast
import csv
import os
import yaml
from argparse import Namespace
from typing import Union, Dict, Any
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import rank_zero_warn
class ModelIO(object):
@ -79,30 +82,78 @@ def update_hparams(hparams: dict, updates: dict) -> None:
hparams.update({k: v})
def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
if not os.path.isfile(tags_csv):
log.warning(f'Missing Tags: {tags_csv}.')
return Namespace()
def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]:
"""Load hparams from a file.
with open(tags_csv) as f:
csv_reader = csv.reader(f, delimiter=',')
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_csv = './testing-hparams.csv'
>>> save_hparams_to_tags_csv(path_csv, hparams)
>>> hparams_new = load_hparams_from_tags_csv(path_csv)
>>> vars(hparams) == hparams_new
True
>>> os.remove(path_csv)
"""
if not os.path.isfile(tags_csv):
rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning)
return {}
with open(tags_csv) as fp:
csv_reader = csv.reader(fp, delimiter=',')
tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
ns = Namespace(**tags)
return ns
return tags
def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None:
if not os.path.isdir(os.path.dirname(tags_csv)):
raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.')
if isinstance(hparams, Namespace):
hparams = vars(hparams)
with open(tags_csv, 'w') as fp:
fieldnames = ['key', 'value']
writer = csv.DictWriter(fp, fieldnames=fieldnames)
writer.writerow({'key': 'key', 'value': 'value'})
for k, v in hparams.items():
writer.writerow({'key': k, 'value': v})
def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
"""Load hparams from a file.
>>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here')
>>> path_yaml = './testing-hparams.yaml'
>>> save_hparams_to_yaml(path_yaml, hparams)
>>> hparams_new = load_hparams_from_yaml(path_yaml)
>>> vars(hparams) == hparams_new
True
>>> os.remove(path_yaml)
"""
if not os.path.isfile(config_yaml):
rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning)
return {}
with open(config_yaml) as fp:
tags = yaml.load(fp, Loader=yaml.SafeLoader)
return tags
def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
if not os.path.isdir(os.path.dirname(config_yaml)):
raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
if isinstance(hparams, Namespace):
hparams = vars(hparams)
with open(config_yaml, 'w', newline='') as fp:
yaml.dump(hparams, fp)
def convert(val: str) -> Union[int, float, bool, str]:
constructors = [int, float, str]
if isinstance(val, str):
if val.lower() == 'true':
return True
if val.lower() == 'false':
return False
for c in constructors:
try:
return c(val)
except ValueError:
pass
return val
try:
return ast.literal_eval(val)
except (ValueError, SyntaxError) as e:
log.debug(e)
return val
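To complement the doctests above, a short sketch of the YAML round trip with a hierarchical dict; the temporary path is an illustrative placeholder:

.. code-block:: python

    from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml

    hparams = {'drop_prob': 0.2, 'dataloader': {'batch_size': 32}}
    save_hparams_to_yaml('/tmp/hparams.yaml', hparams)

    # the nested structure survives the round trip, unlike the flat key/value meta_tags.csv
    assert load_hparams_from_yaml('/tmp/hparams.yaml') == hparams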

View File

@ -3,8 +3,8 @@ TensorBoard
-----------
"""
import csv
import os
import yaml
from argparse import Namespace
from typing import Optional, Dict, Union, Any
from warnings import warn
@ -14,6 +14,7 @@ from pkg_resources import parse_version
from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning import _logger as log
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
@ -42,7 +43,7 @@ class TensorBoardLogger(LightningLoggerBase):
\**kwargs: Other arguments are passed directly to the :class:`SummaryWriter` constructor.
"""
NAME_CSV_TAGS = 'meta_tags.csv'
NAME_HPARAMS_FILE = 'hparams.yaml'
def __init__(self,
save_dir: str,
@ -55,7 +56,7 @@ class TensorBoardLogger(LightningLoggerBase):
self._version = version
self._experiment = None
self.tags = {}
self.hparams = {}
self._kwargs = kwargs
@property
@ -104,8 +105,13 @@ class TensorBoardLogger(LightningLoggerBase):
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
metrics: Optional[Dict[str, Any]] = None) -> None:
params = self._convert_params(params)
# store params to output
self.hparams.update(params)
# format params into a form suitable for TensorBoard
params = self._flatten_dict(params)
sanitized_params = self._sanitize_params(params)
params = self._sanitize_params(params)
if parse_version(torch.__version__) < parse_version("1.3.0"):
warn(
@ -118,7 +124,7 @@ class TensorBoardLogger(LightningLoggerBase):
if metrics is None:
metrics = {}
exp, ssi, sei = hparams(sanitized_params, metrics)
exp, ssi, sei = hparams(params, metrics)
writer = self.experiment._get_file_writer()
writer.add_summary(exp)
writer.add_summary(ssi)
@ -128,9 +134,6 @@ class TensorBoardLogger(LightningLoggerBase):
# necessary for hparam comparison with metrics
self.log_metrics(metrics)
# some alternative should be added
self.tags.update(sanitized_params)
@rank_zero_only
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
for k, v in metrics.items():
@ -152,15 +155,10 @@ class TensorBoardLogger(LightningLoggerBase):
dir_path = self.save_dir
# prepare the file path
meta_tags_path = os.path.join(dir_path, self.NAME_CSV_TAGS)
hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)
# save the hparams file
with open(meta_tags_path, 'w', newline='') as csvfile:
fieldnames = ['key', 'value']
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writerow({'key': 'key', 'value': 'value'})
for k, v in self.tags.items():
writer.writerow({'key': k, 'value': v})
save_hparams_to_yaml(hparams_file, self.hparams)
@rank_zero_only
def finalize(self, status: str) -> None:
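A sketch of the resulting logger behaviour; the save directory and values are illustrative only:

.. code-block:: python

    from pytorch_lightning.loggers import TensorBoardLogger

    logger = TensorBoardLogger('lightning_logs')
    logger.log_hyperparams({'drop_prob': 0.2, 'dataloader': {'batch_size': 32}})
    logger.save()
    # the experiment version directory now contains hparams.yaml (NAME_HPARAMS_FILE)
    # holding the original, unflattened dict; meta_tags.csv is no longer written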

View File

@ -105,10 +105,9 @@ Second case is where you load a model and run the test set
.. code-block:: python
model = MyLightningModule.load_from_metrics(
weights_path='/path/to/pytorch_checkpoint.ckpt',
tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv',
on_gpu=True,
model = MyLightningModule.load_from_checkpoint(
checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
map_location=None
)

View File

@ -344,9 +344,16 @@ class TrainerIOMixin(ABC):
if hasattr(model, "hparams"):
parsing.clean_namespace(model.hparams)
is_namespace = isinstance(model.hparams, Namespace)
checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams
checkpoint['hparams_type'] = 'namespace' if is_namespace else 'dict'
checkpoint['hparams_type'] = model.hparams.__class__.__name__
if checkpoint['hparams_type'] == 'dict':
checkpoint['hparams'] = model.hparams
elif checkpoint['hparams_type'] == 'Namespace':
checkpoint['hparams'] = vars(model.hparams)
else:
raise ValueError(
'The acceptable hparams type is dict or argparse.Namespace,'
f' not {checkpoint["hparams_type"]}'
)
else:
rank_zero_warn(
"Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters."

View File

@ -5,4 +5,4 @@ numpy>=1.16.4
torch>=1.1
tensorboard>=1.14
future>=0.17.1 # required for builtins in setup.py
pyyaml>=3.13

View File

@ -126,14 +126,14 @@ def get_data_path(expt_logger, path_dir=None):
def load_model(logger, root_weights_dir, module_class=EvalModelTemplate, path_expt=None):
# load trained model
path_expt_dir = get_data_path(logger, path_dir=path_expt)
tags_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_CSV_TAGS)
hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE)
checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x]
weights_dir = os.path.join(root_weights_dir, checkpoints[0])
trained_model = module_class.load_from_checkpoint(
checkpoint_path=weights_dir,
tags_csv=tags_path
hparams_file=hparams_path
)
assert trained_model is not None, 'loading model failed'

View File

@ -256,11 +256,11 @@ def test_model_saving_loading(tmpdir):
trainer.save_checkpoint(new_weights_path)
# load new model
tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
hparams_path = os.path.join(hparams_path, 'hparams.yaml')
model_2 = EvalModelTemplate.load_from_checkpoint(
checkpoint_path=new_weights_path,
tags_csv=tags_path
hparams_file=hparams_path
)
model_2.eval()

View File

@ -6,12 +6,14 @@ from argparse import Namespace
import pytest
import torch
import yaml
import tests.base.utils as tutils
from pytorch_lightning import Callback, LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
@ -69,9 +71,12 @@ def test_no_val_module(tmpdir):
ckpt = torch.load(new_weights_path)
assert 'hparams' in ckpt.keys(), 'hparams missing from checkpoints'
# won't load without hparams in the ckpt
# load new model
hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
hparams_path = os.path.join(hparams_path, 'hparams.yaml')
model_2 = EvalModelTemplate.load_from_checkpoint(
checkpoint_path=new_weights_path,
hparams_file=hparams_path
)
model_2.eval()
@ -100,11 +105,11 @@ def test_no_val_end_module(tmpdir):
trainer.save_checkpoint(new_weights_path)
# load new model
tags_path = tutils.get_data_path(logger, path_dir=tmpdir)
tags_path = os.path.join(tags_path, 'meta_tags.csv')
hparams_path = tutils.get_data_path(logger, path_dir=tmpdir)
hparams_path = os.path.join(hparams_path, 'hparams.yaml')
model_2 = EvalModelTemplate.load_from_checkpoint(
checkpoint_path=new_weights_path,
tags_csv=tags_path
hparams_file=hparams_path
)
model_2.eval()
@ -184,6 +189,8 @@ def test_gradient_accumulation_scheduling(tmpdir):
def test_loading_meta_tags(tmpdir):
""" test for backward compatibility to meta_tags.csv """
tutils.reset_seed()
hparams = EvalModelTemplate.get_default_hparams()
@ -193,12 +200,37 @@ def test_loading_meta_tags(tmpdir):
logger.log_hyperparams(hparams)
logger.save()
# load tags
# load hparams
path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir)
hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE)
hparams = load_hparams_from_yaml(hparams_path)
# save as legacy meta_tags.csv
tags_path = os.path.join(path_expt_dir, 'meta_tags.csv')
save_hparams_to_tags_csv(tags_path, hparams)
tags = load_hparams_from_tags_csv(tags_path)
assert tags.batch_size == 32 and tags.hidden_dim == 1000
assert hparams == tags
def test_loading_yaml(tmpdir):
tutils.reset_seed()
hparams = EvalModelTemplate.get_default_hparams()
# save tags
logger = tutils.get_default_logger(tmpdir)
logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0))
logger.log_hyperparams(hparams)
logger.save()
# load hparams
path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir)
hparams_path = os.path.join(path_expt_dir, 'hparams.yaml')
tags = load_hparams_from_yaml(hparams_path)
assert tags['batch_size'] == 32 and tags['hidden_dim'] == 1000
def test_dp_output_reduce():