lightning/tests/base/model_utilities.py

from torch.utils.data import DataLoader

from tests.base.datasets import TrialMNIST


class ModelTemplateData:
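    """Mixin that provides a TrialMNIST dataloader for the test template model.

    The composing class is expected to define ``data_root`` and ``batch_size``.
    """
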
    def dataloader(self, train: bool, num_samples: int = 100):
        dataset = TrialMNIST(root=self.data_root, train=train, num_samples=num_samples, download=True)

        loader = DataLoader(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=0,
            shuffle=train,
        )
        return loader


class ModelTemplateUtils:

    def get_output_metric(self, output, name):
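        """Return metric ``name`` from ``output``.

        ``output`` is either a single result dict or a list of per-batch dicts,
        in which case the metric is averaged over the batches.
        """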
        if isinstance(output, dict):
            val = output[name]
        else:  # two levels deep -> per dataloader and per batch
            val = sum(out[name] for out in output) / len(output)
        return val