Merge pull request #64 from williamFalcon/cov
added test model to do also
This commit is contained in:
commit
0895a41fb9
|
@ -40,3 +40,5 @@ comment:
|
||||||
require_changes: false
|
require_changes: false
|
||||||
behavior: default # update if exists else create new
|
behavior: default # update if exists else create new
|
||||||
# branches: *
|
# branches: *
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ removed until windows install issues resolved.
|
||||||
removed until codecov badge isn't empy. likely a config error showing nothing on master.
|
removed until codecov badge isn't empy. likely a config error showing nothing on master.
|
||||||
[![codecov](https://codecov.io/gh/Borda/pytorch-lightning/branch/master/graph/badge.svg)](https://codecov.io/gh/Borda/pytorch-lightning)
|
[![codecov](https://codecov.io/gh/Borda/pytorch-lightning/branch/master/graph/badge.svg)](https://codecov.io/gh/Borda/pytorch-lightning)
|
||||||
-->
|
-->
|
||||||
[![Coverage](https://github.com/williamFalcon/pytorch-lightning/blob/master/coverage.svg)](https://github.com/williamFalcon/pytorch-lightning/tree/master/tests#running-coverage)
|
[![Coverage](https://github.com/williamFalcon/pytorch-lightning/blob/master/docs/source/_static/coverage.svg)](https://github.com/williamFalcon/pytorch-lightning/tree/master/tests#running-coverage)
|
||||||
[![CodeFactor](https://www.codefactor.io/repository/github/borda/pytorch-lightning/badge)](https://www.codefactor.io/repository/github/borda/pytorch-lightning)
|
[![CodeFactor](https://www.codefactor.io/repository/github/borda/pytorch-lightning/badge)](https://www.codefactor.io/repository/github/borda/pytorch-lightning)
|
||||||
[![ReadTheDocs](https://readthedocs.org/projects/pytorch-lightning/badge/?version=latest)](https://pytorch-lightning.readthedocs.io/en/latest)
|
[![ReadTheDocs](https://readthedocs.org/projects/pytorch-lightning/badge/?version=latest)](https://pytorch-lightning.readthedocs.io/en/latest)
|
||||||
[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/williamFalcon/pytorch-lightning/blob/master/LICENSE)
|
[![license](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/williamFalcon/pytorch-lightning/blob/master/LICENSE)
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
<g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="11">
|
<g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="11">
|
||||||
<text x="31.5" y="15" fill="#010101" fill-opacity=".3">coverage</text>
|
<text x="31.5" y="15" fill="#010101" fill-opacity=".3">coverage</text>
|
||||||
<text x="31.5" y="14">coverage</text>
|
<text x="31.5" y="14">coverage</text>
|
||||||
<text x="80" y="15" fill="#010101" fill-opacity=".3">96%</text>
|
<text x="80" y="15" fill="#010101" fill-opacity=".3">99%</text>
|
||||||
<text x="80" y="14">96%</text>
|
<text x="80" y="14">99%</text>
|
||||||
</g>
|
</g>
|
||||||
</svg>
|
</svg>
|
Before Width: | Height: | Size: 901 B After Width: | Height: | Size: 901 B |
|
@ -135,16 +135,16 @@ class LightningTestModel(LightningModule):
|
||||||
val_acc = val_acc.unsqueeze(0)
|
val_acc = val_acc.unsqueeze(0)
|
||||||
|
|
||||||
# alternate possible outputs to test
|
# alternate possible outputs to test
|
||||||
if self.trainer.batch_nb % 1 == 0:
|
if batch_i % 1 == 0:
|
||||||
output = OrderedDict({
|
output = OrderedDict({
|
||||||
'val_loss': loss_val,
|
'val_loss': loss_val,
|
||||||
'val_acc': val_acc,
|
'val_acc': val_acc,
|
||||||
})
|
})
|
||||||
return output
|
return output
|
||||||
if self.trainer.batch_nb % 2 == 0:
|
if batch_i % 2 == 0:
|
||||||
return val_acc
|
return val_acc
|
||||||
|
|
||||||
if self.trainer.batch_nb % 3 == 0:
|
if batch_i % 3 == 0:
|
||||||
output = OrderedDict({
|
output = OrderedDict({
|
||||||
'val_loss': loss_val,
|
'val_loss': loss_val,
|
||||||
'val_acc': val_acc,
|
'val_acc': val_acc,
|
||||||
|
@ -232,7 +232,7 @@ class LightningTestModel(LightningModule):
|
||||||
return self.__dataloader(train=False)
|
return self.__dataloader(train=False)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def add_model_specific_args(parent_parser, root_dir):
|
def add_model_specific_args(parent_parser, root_dir): # pragma: no cover
|
||||||
"""
|
"""
|
||||||
Parameters you define here will be available to your model through self.hparams
|
Parameters you define here will be available to your model through self.hparams
|
||||||
:param parent_parser:
|
:param parent_parser:
|
||||||
|
|
|
@ -31,6 +31,8 @@ exclude_lines =
|
||||||
print(traceback.print_exc())
|
print(traceback.print_exc())
|
||||||
return *
|
return *
|
||||||
raise Exception
|
raise Exception
|
||||||
|
raise *
|
||||||
|
except *
|
||||||
warnings
|
warnings
|
||||||
print
|
print
|
||||||
raise RuntimeError
|
raise RuntimeError
|
||||||
|
@ -42,6 +44,7 @@ omit =
|
||||||
pytorch_lightning/callbacks/pt_callbacks.py
|
pytorch_lightning/callbacks/pt_callbacks.py
|
||||||
tests/test_models.py
|
tests/test_models.py
|
||||||
pytorch_lightning/testing_models/lm_test_module.py
|
pytorch_lightning/testing_models/lm_test_module.py
|
||||||
|
pytorch_lightning/utilities/arg_parse.py
|
||||||
|
|
||||||
[flake8]
|
[flake8]
|
||||||
ignore = E731,W504,F401,F841
|
ignore = E731,W504,F401,F841
|
||||||
|
|
|
@ -681,10 +681,14 @@ def get_hparams(continue_training=False, hpc_exp_number=0):
|
||||||
return hparams
|
return hparams
|
||||||
|
|
||||||
|
|
||||||
def get_model():
|
def get_model(use_test_model=False):
|
||||||
# set up model with these hyperparams
|
# set up model with these hyperparams
|
||||||
hparams = get_hparams()
|
hparams = get_hparams()
|
||||||
model = LightningTemplateModel(hparams)
|
|
||||||
|
if use_test_model:
|
||||||
|
model = LightningTestModel(hparams)
|
||||||
|
else:
|
||||||
|
model = LightningTemplateModel(hparams)
|
||||||
|
|
||||||
return model, hparams
|
return model, hparams
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue