43 lines
1.3 KiB
Python
43 lines
1.3 KiB
Python
|
import importlib
|
||
|
import os
|
||
|
import tempfile
|
||
|
import types
|
||
|
from pathlib import Path
|
||
|
|
||
|
import pytest
|
||
|
|
||
|
import lightning
|
||
|
import lightning.store
|
||
|
|
||
|
|
||
|
def reload_package(package):
    """Reload ``package`` and, recursively, every submodule located under its directory.

    Re-executing the modules re-evaluates their module-level code, which is what
    lets a test fixture's mocks take effect in values computed at import time.

    Credit: https://stackoverflow.com/a/28516918/4521646

    Args:
        package: An already-imported module with a ``__file__``. Only child modules
            whose ``__file__`` lies under that file's directory are reloaded.

    Returns:
        None.
    """
    assert hasattr(package, "__package__")
    fn = package.__file__
    fn_dir = os.path.dirname(fn) + os.sep
    # Files already reloaded; guards against cycles and double reloads.
    module_visit = {fn}
    del fn

    def reload_recursive_ex(module):
        # Depth-first: reload this module, then descend into its child modules.
        importlib.reload(module)

        # Snapshot the namespace before iterating: reloading a child may import a
        # not-yet-loaded sibling submodule, which the import system sets as a new
        # attribute on this module -- iterating the live dict would then raise
        # "dictionary changed size during iteration".
        for module_child in list(vars(module).values()):
            if not isinstance(module_child, types.ModuleType):
                continue
            fn_child = getattr(module_child, "__file__", None)
            if (fn_child is not None) and fn_child.startswith(fn_dir) and fn_child not in module_visit:
                module_visit.add(fn_child)
                reload_recursive_ex(module_child)

    return reload_recursive_ex(package)
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="function", autouse=True)
def lit_home(monkeypatch):
    """Point ``Path.home()`` at a fresh temporary directory for every test.

    After patching, ``lightning.store`` is reloaded so code it ran at import time
    sees the temporary home instead of the real one.

    Yields:
        str: Path of ``.lightning`` inside the temporary home directory. The
        directory itself is not created by this fixture.
    """
    with tempfile.TemporaryDirectory() as tmp_dirname:
        # Mirror the real ``Path.home`` contract: a classmethod returning a path
        # object (not a plain str), callable on the class or on an instance, so
        # expressions like ``Path.home() / ".lightning"`` keep working.
        monkeypatch.setattr(Path, "home", classmethod(lambda cls: cls(tmp_dirname)))
        # we need to reload whole subpackage to apply the mock/fixture
        reload_package(lightning.store)
        yield os.path.join(tmp_dirname, ".lightning")
|