Merge pull request #64 from williamFalcon/cov

added test model to do also
This commit is contained in:
William Falcon 2019-08-07 13:04:47 -04:00 committed by GitHub
commit 0895a41fb9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 18 additions and 9 deletions

View File

@ -40,3 +40,5 @@ comment:
require_changes: false require_changes: false
behavior: default # update if exists else create new behavior: default # update if exists else create new
# branches: * # branches: *

View File

@ -17,7 +17,7 @@ removed until windows install issues resolved.
removed until codecov badge isn't empy. likely a config error showing nothing on master. removed until codecov badge isn't empy. likely a config error showing nothing on master.
[![codecov](https://codecov.io/gh/Borda/pytorch-lightning/branch/master/graph/badge.svg)](https://codecov.io/gh/Borda/pytorch-lightning) [![codecov](https://codecov.io/gh/Borda/pytorch-lightning/branch/master/graph/badge.svg)](https://codecov.io/gh/Borda/pytorch-lightning)
--> -->
[![Coverage](https://github.com/williamFalcon/pytorch-lightning/blob/master/coverage.svg)](https://github.com/williamFalcon/pytorch-lightning/tree/master/tests#running-coverage) [![Coverage](https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/_static/coverage.svg)](https://github.com/williamFalcon/pytorch-lightning/tree/master/tests#running-coverage)
[![CodeFactor](https://www.codefactor.io/repository/github/borda/pytorch-lightning/badge)](https://www.codefactor.io/repository/github/borda/pytorch-lightning) [![CodeFactor](https://www.codefactor.io/repository/github/borda/pytorch-lightning/badge)](https://www.codefactor.io/repository/github/borda/pytorch-lightning)
[![ReadTheDocs](https://readthedocs.org/projects/pytorch-lightning/badge/?version=latest)](https://pytorch-lightning.readthedocs.io/en/latest) [![ReadTheDocs](https://readthedocs.org/projects/pytorch-lightning/badge/?version=latest)](https://pytorch-lightning.readthedocs.io/en/latest)
[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/williamFalcon/pytorch-lightning/blob/master/LICENSE) [![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/williamFalcon/pytorch-lightning/blob/master/LICENSE)

View File

@ -15,7 +15,7 @@
<g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="11"> <g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="11">
<text x="31.5" y="15" fill="#010101" fill-opacity=".3">coverage</text> <text x="31.5" y="15" fill="#010101" fill-opacity=".3">coverage</text>
<text x="31.5" y="14">coverage</text> <text x="31.5" y="14">coverage</text>
<text x="80" y="15" fill="#010101" fill-opacity=".3">96%</text> <text x="80" y="15" fill="#010101" fill-opacity=".3">99%</text>
<text x="80" y="14">96%</text> <text x="80" y="14">99%</text>
</g> </g>
</svg> </svg>

Before

Width:  |  Height:  |  Size: 901 B

After

Width:  |  Height:  |  Size: 901 B

View File

@ -135,16 +135,16 @@ class LightningTestModel(LightningModule):
val_acc = val_acc.unsqueeze(0) val_acc = val_acc.unsqueeze(0)
# alternate possible outputs to test # alternate possible outputs to test
if self.trainer.batch_nb % 1 == 0: if batch_i % 1 == 0:
output = OrderedDict({ output = OrderedDict({
'val_loss': loss_val, 'val_loss': loss_val,
'val_acc': val_acc, 'val_acc': val_acc,
}) })
return output return output
if self.trainer.batch_nb % 2 == 0: if batch_i % 2 == 0:
return val_acc return val_acc
if self.trainer.batch_nb % 3 == 0: if batch_i % 3 == 0:
output = OrderedDict({ output = OrderedDict({
'val_loss': loss_val, 'val_loss': loss_val,
'val_acc': val_acc, 'val_acc': val_acc,
@ -232,7 +232,7 @@ class LightningTestModel(LightningModule):
return self.__dataloader(train=False) return self.__dataloader(train=False)
@staticmethod @staticmethod
def add_model_specific_args(parent_parser, root_dir): def add_model_specific_args(parent_parser, root_dir): # pragma: no cover
""" """
Parameters you define here will be available to your model through self.hparams Parameters you define here will be available to your model through self.hparams
:param parent_parser: :param parent_parser:

View File

@ -31,6 +31,8 @@ exclude_lines =
print(traceback.print_exc()) print(traceback.print_exc())
return * return *
raise Exception raise Exception
raise *
except *
warnings warnings
print print
raise RuntimeError raise RuntimeError
@ -42,6 +44,7 @@ omit =
pytorch_lightning/callbacks/pt_callbacks.py pytorch_lightning/callbacks/pt_callbacks.py
tests/test_models.py tests/test_models.py
pytorch_lightning/testing_models/lm_test_module.py pytorch_lightning/testing_models/lm_test_module.py
pytorch_lightning/utilities/arg_parse.py
[flake8] [flake8]
ignore = E731,W504,F401,F841 ignore = E731,W504,F401,F841

View File

@ -681,10 +681,14 @@ def get_hparams(continue_training=False, hpc_exp_number=0):
return hparams return hparams
def get_model(): def get_model(use_test_model=False):
# set up model with these hyperparams # set up model with these hyperparams
hparams = get_hparams() hparams = get_hparams()
model = LightningTemplateModel(hparams)
if use_test_model:
model = LightningTestModel(hparams)
else:
model = LightningTemplateModel(hparams)
return model, hparams return model, hparams