From 22d7d031188ecaad636cbad30a43af94abc52bd9 Mon Sep 17 00:00:00 2001 From: So Uchida Date: Wed, 13 May 2020 22:05:15 +0900 Subject: [PATCH] Replace meta_tags.csv with hparams.yaml (#1271) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add support for hierarchical dict * Support nested Namespace * Add docstring * Migrate hparam flattening to each logger * Modify URLs in CHANGELOG * typo * Simplify the conditional branch about Namespace Co-Authored-By: Jirka Borovec * Update CHANGELOG.md Co-Authored-By: Jirka Borovec * added examples section to docstring * renamed _dict -> input_dict * mata_tags.csv -> hparams.yaml * code style fixes * add pyyaml * remove unused import * create the member NAME_HPARAMS_FILE * improve tests * Update tensorboard.py * pass the local test w/o relavents of Horovod * formatting * update dependencies * fix dependencies * Apply suggestions from code review * add savings * warn * docstrings * tests * Apply suggestions from code review * saving * Apply suggestions from code review * use default * remove logging * typo fixes * update docs * update CHANGELOG * clean imports * add blank lines * Update pytorch_lightning/core/lightning.py Co-authored-by: Adrian Wälchli * Update pytorch_lightning/core/lightning.py Co-authored-by: Adrian Wälchli * back to namespace * add docs * test fix * update dependencies * add space Co-authored-by: Jirka Borovec Co-authored-by: Adrian Wälchli --- CHANGELOG.md | 4 + docs/source/test_set.rst | 7 +- environment.yml | 1 + pytorch_lightning/core/lightning.py | 96 ++++++++++++++++---- pytorch_lightning/core/saving.py | 95 ++++++++++++++----- pytorch_lightning/loggers/tensorboard.py | 28 +++--- pytorch_lightning/trainer/evaluation_loop.py | 7 +- pytorch_lightning/trainer/training_io.py | 13 ++- requirements.txt | 2 +- tests/base/utils.py | 4 +- tests/models/test_restore.py | 6 +- tests/trainer/test_trainer.py | 46 ++++++++-- 12 files changed, 228 insertions(+), 81 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f2fcf48e25..270cf30ed5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Replace mata_tags.csv with hparams.yaml ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271)) + - Reduction when `batch_size < num_gpus` ([#1609](https://github.com/PyTorchLightning/pytorch-lightning/pull/1609)) - Updated LightningTemplateModel to look more like Colab example ([#1577](https://github.com/PyTorchLightning/pytorch-lightning/pull/1577)) @@ -36,6 +38,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Deprecated +- Deprecated `tags_csv` in favor of `hparams_file` ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271)) + ### Removed ### Fixed diff --git a/docs/source/test_set.rst b/docs/source/test_set.rst index 7dfe40ddaa..aa2a6e4e9d 100644 --- a/docs/source/test_set.rst +++ b/docs/source/test_set.rst @@ -21,10 +21,9 @@ To run the test set on a pre-trained model, use this method. .. code-block:: python - model = MyLightningModule.load_from_metrics( - weights_path='/path/to/pytorch_checkpoint.ckpt', - tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv', - on_gpu=True, + model = MyLightningModule.load_from_checkpoint( + checkpoint_path='/path/to/pytorch_checkpoint.ckpt', + hparams_file='/path/to/test_tube/experiment/version/hparams.yaml', map_location=None ) diff --git a/environment.yml b/environment.yml index 45e0e3da30..4b06b5b0d3 100644 --- a/environment.yml +++ b/environment.yml @@ -13,6 +13,7 @@ dependencies: - pytorch>=1.1 - tensorboard>=1.14 - future>=0.17.1 + - pyyaml>=3.13 # For dev and testing - tox diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index d9c2e18620..98966bb988 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1,6 +1,7 @@ import collections import inspect import os +import warnings from abc import ABC, abstractmethod from argparse import Namespace from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence @@ -16,7 +17,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.core.grads import GradInformation from pytorch_lightning.core.hooks import ModelHooks from pytorch_lightning.core.memory import ModelSummary -from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, update_hparams +from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv, load_hparams_from_yaml, update_hparams from pytorch_lightning.core.properties import DeviceDtypeModuleMixin from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -1438,29 +1439,49 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod cls, checkpoint_path: str, map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None, - tags_csv: Optional[str] = None, + hparams_file: Optional[str] = None, + tags_csv: Optional[str] = None, # backward compatible, todo: remove in v0.9.0 hparam_overrides: Optional[Dict] = None, *args, **kwargs ) -> 'LightningModule': r""" Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the hyperparameters in the checkpoint if you initialized your :class:`LightningModule` - with an argument called ``hparams`` which is a :class:`~argparse.Namespace` - (output of :meth:`~argparse.ArgumentParser.parse_args` when parsing command line arguments). + with an argument called ``hparams`` which is an object of :class:`~dict` or + :class:`~argparse.Namespace` (output of :meth:`~argparse.ArgumentParser.parse_args` + when parsing command line arguments). + If you want `hparams` to have a hierarchical structure, you have to define it as :class:`~dict`. Any other arguments specified through \*args and \*\*kwargs will be passed to the model. Example: .. code-block:: python + # define hparams as Namespace from argparse import Namespace hparams = Namespace(**{'learning_rate': 0.1}) model = MyModel(hparams) class MyModel(LightningModule): - def __init__(self, hparams): + def __init__(self, hparams: Namespace): self.learning_rate = hparams.learning_rate + # ---------- + + # define hparams as dict + hparams = { + drop_prob: 0.2, + dataloader: { + batch_size: 32 + } + } + + model = MyModel(hparams) + + class MyModel(LightningModule): + def __init__(self, hparams: dict): + self.learning_rate = hparams['learning_rate'] + Args: checkpoint_path: Path to checkpoint. model_args: Any keyword args needed to init the model. @@ -1468,19 +1489,38 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in :func:`torch.load`. - tags_csv: Optional path to a .csv file with two columns (key, value) + hparams_file: Optional path to a .yaml file with hierarchical structure + as in this example:: + + drop_prob: 0.2 + dataloader: + batch_size: 32 + + You most likely won't need this since Lightning will always save the hyperparameters + to the checkpoint. + However, if your checkpoint weights don't have the hyperparameters saved, + use this method to pass in a .yaml file with the hparams you'd like to use. + These will be converted into a :class:`~dict` and passed into your + :class:`LightningModule` for use. + + If your model's `hparams` argument is :class:`~argparse.Namespace` + and .yaml file has hierarchical structure, you need to refactor your model to treat + `hparams` as :class:`~dict`. + + .csv files are acceptable here till v0.9.0, see tags_csv argument for detailed usage. + tags_csv: + .. warning:: .. deprecated:: 0.7.6 + + `tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0. + + Optional path to a .csv file with two columns (key, value) as in this example:: key,value drop_prob,0.2 batch_size,32 - You most likely won't need this since Lightning will always save the hyperparameters - to the checkpoint. - However, if your checkpoint weights don't have the hyperparameters saved, - use this method to pass in a .csv file with the hparams you'd like to use. - These will be converted into a :class:`~argparse.Namespace` and passed into your - :class:`LightningModule` for use. + Use this method to pass in a .csv file with the hparams you'd like to use. hparam_overrides: A dictionary with keys to override in the hparams Return: @@ -1502,7 +1542,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod # or load weights and hyperparameters from separate files. MyLightningModule.load_from_checkpoint( 'path/to/checkpoint.ckpt', - tags_csv='/path/to/hparams_file.csv' + hparams_file='/path/to/hparams_file.yaml' ) # override some of the params with new values @@ -1531,9 +1571,22 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod # add the hparams from csv file to checkpoint if tags_csv is not None: - hparams = load_hparams_from_tags_csv(tags_csv) - hparams.__setattr__('on_gpu', False) - checkpoint['hparams'] = vars(hparams) + hparams_file = tags_csv + rank_zero_warn('`tags_csv` argument is deprecated in v0.7.6. Will be removed v0.9.0', DeprecationWarning) + + if hparams_file is not None: + extension = hparams_file.split('.')[-1] + if extension.lower() in ('csv'): + hparams = load_hparams_from_tags_csv(hparams_file) + elif extension.lower() in ('yml', 'yaml'): + hparams = load_hparams_from_yaml(hparams_file) + else: + raise ValueError('.csv, .yml or .yaml is required for `hparams_file`') + + hparams['on_gpu'] = False + + # overwrite hparams by the given file + checkpoint['hparams'] = hparams # override the hparam keys that were passed in if hparam_overrides is not None: @@ -1549,15 +1602,18 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod if cls_takes_hparams: if ckpt_hparams is not None: - is_namespace = checkpoint.get('hparams_type', 'namespace') == 'namespace' - hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams + hparams_type = checkpoint.get('hparams_type', 'Namespace') + if hparams_type.lower() == 'dict': + hparams = ckpt_hparams + elif hparams_type.lower() == 'namespace': + hparams = Namespace(**ckpt_hparams) else: rank_zero_warn( f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__" " contains argument 'hparams'. Will pass in an empty Namespace instead." " Did you forget to store your model hyperparameters in self.hparams?" ) - hparams = Namespace() + hparams = {} else: # The user's LightningModule does not define a hparams argument if ckpt_hparams is None: hparams = None @@ -1568,7 +1624,7 @@ class LightningModule(ABC, DeviceDtypeModuleMixin, GradInformation, ModelIO, Mod ) # load the state_dict on the model automatically - if hparams: + if cls_takes_hparams: kwargs.update(hparams=hparams) model = cls(*args, **kwargs) model.load_state_dict(checkpoint['state_dict']) diff --git a/pytorch_lightning/core/saving.py b/pytorch_lightning/core/saving.py index 7a28a40d4e..adf782e6d4 100644 --- a/pytorch_lightning/core/saving.py +++ b/pytorch_lightning/core/saving.py @@ -1,9 +1,12 @@ +import ast import csv import os +import yaml from argparse import Namespace from typing import Union, Dict, Any from pytorch_lightning import _logger as log +from pytorch_lightning.utilities import rank_zero_warn class ModelIO(object): @@ -79,30 +82,78 @@ def update_hparams(hparams: dict, updates: dict) -> None: hparams.update({k: v}) -def load_hparams_from_tags_csv(tags_csv: str) -> Namespace: - if not os.path.isfile(tags_csv): - log.warning(f'Missing Tags: {tags_csv}.') - return Namespace() +def load_hparams_from_tags_csv(tags_csv: str) -> Dict[str, Any]: + """Load hparams from a file. - with open(tags_csv) as f: - csv_reader = csv.reader(f, delimiter=',') + >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') + >>> path_csv = './testing-hparams.csv' + >>> save_hparams_to_tags_csv(path_csv, hparams) + >>> hparams_new = load_hparams_from_tags_csv(path_csv) + >>> vars(hparams) == hparams_new + True + >>> os.remove(path_csv) + """ + if not os.path.isfile(tags_csv): + rank_zero_warn(f'Missing Tags: {tags_csv}.', RuntimeWarning) + return {} + + with open(tags_csv) as fp: + csv_reader = csv.reader(fp, delimiter=',') tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]} - ns = Namespace(**tags) - return ns + + return tags + + +def save_hparams_to_tags_csv(tags_csv: str, hparams: Union[dict, Namespace]) -> None: + if not os.path.isdir(os.path.dirname(tags_csv)): + raise RuntimeError(f'Missing folder: {os.path.dirname(tags_csv)}.') + + if isinstance(hparams, Namespace): + hparams = vars(hparams) + + with open(tags_csv, 'w') as fp: + fieldnames = ['key', 'value'] + writer = csv.DictWriter(fp, fieldnames=fieldnames) + writer.writerow({'key': 'key', 'value': 'value'}) + for k, v in hparams.items(): + writer.writerow({'key': k, 'value': v}) + + +def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]: + """Load hparams from a file. + + >>> hparams = Namespace(batch_size=32, learning_rate=0.001, data_root='./any/path/here') + >>> path_yaml = './testing-hparams.yaml' + >>> save_hparams_to_yaml(path_yaml, hparams) + >>> hparams_new = load_hparams_from_yaml(path_yaml) + >>> vars(hparams) == hparams_new + True + >>> os.remove(path_yaml) + """ + if not os.path.isfile(config_yaml): + rank_zero_warn(f'Missing Tags: {config_yaml}.', RuntimeWarning) + return {} + + with open(config_yaml) as fp: + tags = yaml.load(fp, Loader=yaml.SafeLoader) + + return tags + + +def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None: + if not os.path.isdir(os.path.dirname(config_yaml)): + raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.') + + if isinstance(hparams, Namespace): + hparams = vars(hparams) + + with open(config_yaml, 'w', newline='') as fp: + yaml.dump(hparams, fp) def convert(val: str) -> Union[int, float, bool, str]: - constructors = [int, float, str] - - if isinstance(val, str): - if val.lower() == 'true': - return True - if val.lower() == 'false': - return False - - for c in constructors: - try: - return c(val) - except ValueError: - pass - return val + try: + return ast.literal_eval(val) + except (ValueError, SyntaxError) as e: + log.debug(e) + return val diff --git a/pytorch_lightning/loggers/tensorboard.py b/pytorch_lightning/loggers/tensorboard.py index 90e24624a3..62965d7832 100644 --- a/pytorch_lightning/loggers/tensorboard.py +++ b/pytorch_lightning/loggers/tensorboard.py @@ -3,8 +3,8 @@ TensorBoard ----------- """ -import csv import os +import yaml from argparse import Namespace from typing import Optional, Dict, Union, Any from warnings import warn @@ -14,6 +14,7 @@ from pkg_resources import parse_version from torch.utils.tensorboard import SummaryWriter from pytorch_lightning import _logger as log +from pytorch_lightning.core.saving import save_hparams_to_yaml from pytorch_lightning.loggers.base import LightningLoggerBase from pytorch_lightning.utilities import rank_zero_only @@ -42,7 +43,7 @@ class TensorBoardLogger(LightningLoggerBase): \**kwargs: Other arguments are passed directly to the :class:`SummaryWriter` constructor. """ - NAME_CSV_TAGS = 'meta_tags.csv' + NAME_HPARAMS_FILE = 'hparams.yaml' def __init__(self, save_dir: str, @@ -55,7 +56,7 @@ class TensorBoardLogger(LightningLoggerBase): self._version = version self._experiment = None - self.tags = {} + self.hparams = {} self._kwargs = kwargs @property @@ -104,8 +105,13 @@ class TensorBoardLogger(LightningLoggerBase): def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None) -> None: params = self._convert_params(params) + + # store params to output + self.hparams.update(params) + + # format params into the suitable for tensorboard params = self._flatten_dict(params) - sanitized_params = self._sanitize_params(params) + params = self._sanitize_params(params) if parse_version(torch.__version__) < parse_version("1.3.0"): warn( @@ -118,7 +124,7 @@ class TensorBoardLogger(LightningLoggerBase): if metrics is None: metrics = {} - exp, ssi, sei = hparams(sanitized_params, metrics) + exp, ssi, sei = hparams(params, metrics) writer = self.experiment._get_file_writer() writer.add_summary(exp) writer.add_summary(ssi) @@ -128,9 +134,6 @@ class TensorBoardLogger(LightningLoggerBase): # necessary for hparam comparison with metrics self.log_metrics(metrics) - # some alternative should be added - self.tags.update(sanitized_params) - @rank_zero_only def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: for k, v in metrics.items(): @@ -152,15 +155,10 @@ class TensorBoardLogger(LightningLoggerBase): dir_path = self.save_dir # prepare the file path - meta_tags_path = os.path.join(dir_path, self.NAME_CSV_TAGS) + hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE) # save the metatags file - with open(meta_tags_path, 'w', newline='') as csvfile: - fieldnames = ['key', 'value'] - writer = csv.DictWriter(csvfile, fieldnames=fieldnames) - writer.writerow({'key': 'key', 'value': 'value'}) - for k, v in self.tags.items(): - writer.writerow({'key': k, 'value': v}) + save_hparams_to_yaml(hparams_file, self.hparams) @rank_zero_only def finalize(self, status: str) -> None: diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 23f7d0edef..bbed181e3d 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -105,10 +105,9 @@ Second case is where you load a model and run the test set .. code-block:: python - model = MyLightningModule.load_from_metrics( - weights_path='/path/to/pytorch_checkpoint.ckpt', - tags_csv='/path/to/test_tube/experiment/version/meta_tags.csv', - on_gpu=True, + model = MyLightningModule.load_from_checkpoint( + checkpoint_path='/path/to/pytorch_checkpoint.ckpt', + hparams_file='/path/to/test_tube/experiment/version/hparams.yaml', map_location=None ) diff --git a/pytorch_lightning/trainer/training_io.py b/pytorch_lightning/trainer/training_io.py index 437a89e424..1aedbee1a2 100644 --- a/pytorch_lightning/trainer/training_io.py +++ b/pytorch_lightning/trainer/training_io.py @@ -344,9 +344,16 @@ class TrainerIOMixin(ABC): if hasattr(model, "hparams"): parsing.clean_namespace(model.hparams) - is_namespace = isinstance(model.hparams, Namespace) - checkpoint['hparams'] = vars(model.hparams) if is_namespace else model.hparams - checkpoint['hparams_type'] = 'namespace' if is_namespace else 'dict' + checkpoint['hparams_type'] = model.hparams.__class__.__name__ + if checkpoint['hparams_type'] == 'dict': + checkpoint['hparams'] = model.hparams + elif checkpoint['hparams_type'] == 'Namespace': + checkpoint['hparams'] = vars(model.hparams) + else: + raise ValueError( + 'The acceptable hparams type is dict or argparse.Namespace,', + f' not {checkpoint["hparams_type"]}' + ) else: rank_zero_warn( "Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters." diff --git a/requirements.txt b/requirements.txt index 81441b367e..d4bf8cad3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,4 @@ numpy>=1.16.4 torch>=1.1 tensorboard>=1.14 future>=0.17.1 # required for builtins in setup.py - +pyyaml>=3.13 diff --git a/tests/base/utils.py b/tests/base/utils.py index 973972d60b..dbf2666694 100644 --- a/tests/base/utils.py +++ b/tests/base/utils.py @@ -126,14 +126,14 @@ def get_data_path(expt_logger, path_dir=None): def load_model(logger, root_weights_dir, module_class=EvalModelTemplate, path_expt=None): # load trained model path_expt_dir = get_data_path(logger, path_dir=path_expt) - tags_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_CSV_TAGS) + hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE) checkpoints = [x for x in os.listdir(root_weights_dir) if '.ckpt' in x] weights_dir = os.path.join(root_weights_dir, checkpoints[0]) trained_model = module_class.load_from_checkpoint( checkpoint_path=weights_dir, - tags_csv=tags_path + hparams_file=hparams_path ) assert trained_model is not None, 'loading model failed' diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index db86db04a4..73e7362655 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -256,11 +256,11 @@ def test_model_saving_loading(tmpdir): trainer.save_checkpoint(new_weights_path) # load new model - tags_path = tutils.get_data_path(logger, path_dir=tmpdir) - tags_path = os.path.join(tags_path, 'meta_tags.csv') + hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = os.path.join(hparams_path, 'hparams.yaml') model_2 = EvalModelTemplate.load_from_checkpoint( checkpoint_path=new_weights_path, - tags_csv=tags_path + hparams_file=hparams_path ) model_2.eval() diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index e6cdc65338..1c340d8ccd 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -6,12 +6,14 @@ from argparse import Namespace import pytest import torch +import yaml import tests.base.utils as tutils from pytorch_lightning import Callback, LightningModule from pytorch_lightning import Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint -from pytorch_lightning.core.lightning import load_hparams_from_tags_csv +from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml, save_hparams_to_tags_csv +from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import EvalModelTemplate @@ -69,9 +71,12 @@ def test_no_val_module(tmpdir): ckpt = torch.load(new_weights_path) assert 'hparams' in ckpt.keys(), 'hparams missing from checkpoints' - # won't load without hparams in the ckpt + # load new model + hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = os.path.join(hparams_path, 'hparams.yaml') model_2 = EvalModelTemplate.load_from_checkpoint( checkpoint_path=new_weights_path, + hparams_file=hparams_path ) model_2.eval() @@ -100,11 +105,11 @@ def test_no_val_end_module(tmpdir): trainer.save_checkpoint(new_weights_path) # load new model - tags_path = tutils.get_data_path(logger, path_dir=tmpdir) - tags_path = os.path.join(tags_path, 'meta_tags.csv') + hparams_path = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = os.path.join(hparams_path, 'hparams.yaml') model_2 = EvalModelTemplate.load_from_checkpoint( checkpoint_path=new_weights_path, - tags_csv=tags_path + hparams_file=hparams_path ) model_2.eval() @@ -184,6 +189,8 @@ def test_gradient_accumulation_scheduling(tmpdir): def test_loading_meta_tags(tmpdir): + """ test for backward compatibility to meta_tags.csv """ + tutils.reset_seed() hparams = EvalModelTemplate.get_default_hparams() @@ -193,12 +200,37 @@ def test_loading_meta_tags(tmpdir): logger.log_hyperparams(hparams) logger.save() - # load tags + # load hparams path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = os.path.join(path_expt_dir, TensorBoardLogger.NAME_HPARAMS_FILE) + hparams = load_hparams_from_yaml(hparams_path) + + # save as legacy meta_tags.csv tags_path = os.path.join(path_expt_dir, 'meta_tags.csv') + save_hparams_to_tags_csv(tags_path, hparams) + tags = load_hparams_from_tags_csv(tags_path) - assert tags.batch_size == 32 and tags.hidden_dim == 1000 + assert hparams == tags + + +def test_loading_yaml(tmpdir): + tutils.reset_seed() + + hparams = EvalModelTemplate.get_default_hparams() + + # save tags + logger = tutils.get_default_logger(tmpdir) + logger.log_hyperparams(Namespace(some_str='a_str', an_int=1, a_float=2.0)) + logger.log_hyperparams(hparams) + logger.save() + + # load hparams + path_expt_dir = tutils.get_data_path(logger, path_dir=tmpdir) + hparams_path = os.path.join(path_expt_dir, 'hparams.yaml') + tags = load_hparams_from_yaml(hparams_path) + + assert tags['batch_size'] == 32 and tags['hidden_dim'] == 1000 def test_dp_output_reduce():