import os
import sys

import pytest
import torch
from omegaconf import OmegaConf
from packaging import version

from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.lightning import CHECKPOINT_KEY_MODULE_ARGS
from tests.base import EvalModelTemplate
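
# These tests cover automatic collection of LightningModule constructor arguments
# (``auto_collect_arguments`` / ``module_arguments``) and their round-trip through
# checkpoints via ``CHECKPOINT_KEY_MODULE_ARGS``.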


class OmegaConfModel(EvalModelTemplate):

    def __init__(self, ogc):
        super().__init__()
        self.ogc = ogc
        self.size = ogc.list[0]


def test_class_nesting(tmpdir):

    class MyModule(LightningModule):
        def forward(self):
            return 0

    # make sure PL modules are always nn.Module
    a = MyModule()
    assert isinstance(a, torch.nn.Module)

    def test_outside():
        a = MyModule()
        print(a.module_arguments)

    class A:
        def test(self):
            a = MyModule()
            print(a.module_arguments)

        def test2(self):
            test_outside()

    # ensure module_arguments stays accessible when the module class is
    # instantiated from nested functions and methods
    test_outside()
    A().test2()
    A().test()


@pytest.mark.xfail(sys.version_info < (3, 8), reason='OmegaConf only for Python >= 3.8')
def test_omegaconf(tmpdir):
    conf = OmegaConf.create({"k": "v", "list": [15.4, {"a": "1", "b": "2"}]})
    model = OmegaConfModel(conf)

    # ensure ogc passed values correctly
    assert model.size == 15.4

    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
    result = trainer.fit(model)

    assert result == 1


class SubClassEvalModel(EvalModelTemplate):
    any_other_loss = torch.nn.CrossEntropyLoss()

    def __init__(self, *args, subclass_arg=1200, **kwargs):
        super().__init__(*args, **kwargs)
        self.subclass_arg = subclass_arg
        self.auto_collect_arguments()


class UnconventionalArgsEvalModel(EvalModelTemplate):
    """ A model that has unconventional names for "self", "*args" and "**kwargs". """

    def __init__(obj, *more_args, other_arg=300, **more_kwargs):
        # intentionally named obj
        super().__init__(*more_args, **more_kwargs)
        obj.other_arg = other_arg
        other_arg = 321  # rebind the local to make sure collection is not confused by it
        obj.auto_collect_arguments()


class SubSubClassEvalModel(SubClassEvalModel):
    # an empty subclass checks that arguments are still collected through inheritance
    pass


class AggSubClassEvalModel(SubClassEvalModel):

    def __init__(self, *args, my_loss=torch.nn.CrossEntropyLoss(), **kwargs):
        # the default loss instance is created once at class-definition time;
        # passing a different loss should be captured as an init argument
        super().__init__(*args, **kwargs)
        self.my_loss = my_loss
        self.auto_collect_arguments()


@pytest.mark.parametrize("cls", [
    EvalModelTemplate,
    SubClassEvalModel,
    SubSubClassEvalModel,
    AggSubClassEvalModel,
    UnconventionalArgsEvalModel,
])
def test_collect_init_arguments(tmpdir, cls):
    """ Test that the model automatically saves the arguments passed into the constructor """
    extra_args = dict(my_loss=torch.nn.CosineEmbeddingLoss()) if cls is AggSubClassEvalModel else {}

    model = cls(**extra_args)
    assert model.batch_size == 32
    model = cls(batch_size=179, **extra_args)
    assert model.batch_size == 179

    if isinstance(model, SubClassEvalModel):
        assert model.subclass_arg == 1200

    if isinstance(model, AggSubClassEvalModel):
        assert isinstance(model.my_loss, torch.nn.CosineEmbeddingLoss)

    # verify that the checkpoint saved the correct values
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_pct=0.5)
    trainer.fit(model)

    raw_checkpoint_path = os.listdir(trainer.checkpoint_callback.dirpath)
    raw_checkpoint_path = [x for x in raw_checkpoint_path if '.ckpt' in x][0]
    raw_checkpoint_path = os.path.join(trainer.checkpoint_callback.dirpath, raw_checkpoint_path)

    raw_checkpoint = torch.load(raw_checkpoint_path)
    assert CHECKPOINT_KEY_MODULE_ARGS in raw_checkpoint
    assert raw_checkpoint[CHECKPOINT_KEY_MODULE_ARGS]['batch_size'] == 179

    # verify that model loads correctly
    model = cls.load_from_checkpoint(raw_checkpoint_path)
    assert model.batch_size == 179

    if isinstance(model, AggSubClassEvalModel):
        # the loss instance is not restored from the checkpoint here, so the default applies
        assert isinstance(model.my_loss, torch.nn.CrossEntropyLoss)

    # verify that we can overwrite whatever we want
    model = cls.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
    assert model.batch_size == 99
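
# Conceptually, ``load_from_checkpoint`` can be thought of as rebuilding the model
# from the stored constructor arguments, roughly as follows (a simplified sketch,
# not the actual Lightning implementation; ``overrides`` is illustrative only):
#
#     ckpt = torch.load(raw_checkpoint_path)
#     init_args = {**ckpt[CHECKPOINT_KEY_MODULE_ARGS], **overrides}
#     model = cls(**init_args)
#     model.load_state_dict(ckpt['state_dict'])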


class LocalVariableModel1(EvalModelTemplate):
    """ This model has the super().__init__() call at the end. """

    def __init__(self, arg1, arg2, *args, **kwargs):
        self.argument1 = arg1  # arg2 intentionally not set
        arg1 = 'overwritten'
        local_var = 1234
        super().__init__(*args, **kwargs)  # this is intentionally here at the end


class LocalVariableModel2(EvalModelTemplate):
    """ This model has the auto_collect_arguments() call at the end. """

    def __init__(self, arg1, arg2, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.argument1 = arg1  # arg2 intentionally not set
        arg1 = 'overwritten'
        local_var = 1234
        self.auto_collect_arguments()  # this is intentionally here at the end


@pytest.mark.parametrize("cls", [
    LocalVariableModel1,
    LocalVariableModel2,
])
def test_collect_init_arguments_with_local_vars(cls):
    """ Tests that only the arguments are collected and not local variables. """
    model = cls(arg1=1, arg2=2)
    assert 'local_var' not in model.module_arguments
    assert model.module_arguments['arg1'] == 'overwritten'
    assert model.module_arguments['arg2'] == 2
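

# For reference, a minimal sketch of how argument collection by frame inspection can
# work, under the assumption that ``auto_collect_arguments`` inspects the calling
# ``__init__`` frame (this helper is hypothetical, not the Lightning implementation).
# It explains the behaviour asserted above: a reassigned argument keeps its latest
# value ('overwritten'), while plain locals such as ``local_var`` are never collected.
def _collect_init_args_sketch():
    """Collect the named arguments of the calling ``__init__`` frame."""
    import inspect
    frame = inspect.currentframe().f_back  # the caller's __init__ frame
    arg_names, _, kwargs_name, local_vars = inspect.getargvalues(frame)
    # skip the instance parameter (works for 'self' or any unconventional name)
    collected = {name: local_vars[name] for name in arg_names[1:]}
    if kwargs_name:
        collected.update(local_vars.get(kwargs_name, {}))
    return collected

# Called at the end of ``LocalVariableModel2.__init__(arg1=1, arg2=2)``, this sketch
# would return {'arg1': 'overwritten', 'arg2': 2}, matching the assertions above.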