diff --git a/tests/tests_pytorch/__init__.py b/tests/tests_pytorch/__init__.py index 9039a6e4b1..2731ae3124 100644 --- a/tests/tests_pytorch/__init__.py +++ b/tests/tests_pytorch/__init__.py @@ -11,11 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import logging import os -import numpy as np - _TEST_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_TEST_ROOT) _TEMP_PATH = os.path.join(_PROJECT_ROOT, "test_temp") @@ -27,10 +24,6 @@ if _PROJECT_ROOT not in os.getenv("PYTHONPATH", ""): splitter = ":" if os.environ.get("PYTHONPATH", "") else "" os.environ["PYTHONPATH"] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}' -# generate a list of random seeds for each test -RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000)) if not os.path.isdir(_TEMP_PATH): os.mkdir(_TEMP_PATH) - -logging.basicConfig(level=logging.ERROR) diff --git a/tests/tests_pytorch/helpers/pipelines.py b/tests/tests_pytorch/helpers/pipelines.py index 3de3d75563..3cbc49f11c 100644 --- a/tests/tests_pytorch/helpers/pipelines.py +++ b/tests/tests_pytorch/helpers/pipelines.py @@ -31,7 +31,7 @@ def run_model_test_without_loggers( # correct result and ok accuracy assert trainer.state.finished, f"Training failed with {trainer.state}" - model2 = load_model_from_checkpoint(trainer.logger, trainer.checkpoint_callback.best_model_path, type(model)) + model2 = load_model_from_checkpoint(trainer.checkpoint_callback.best_model_path, type(model)) # test new model accuracy test_loaders = model2.test_dataloader() if not data else data.test_dataloader() @@ -68,7 +68,7 @@ def run_model_test( assert change_ratio > 0.03, f"the model is changed of {change_ratio}" # test model loading - _ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model)) + _ = load_model_from_checkpoint(trainer.checkpoint_callback.best_model_path, type(model)) # test new model accuracy test_loaders = model.test_dataloader() if not data else data.test_dataloader() diff --git a/tests/tests_pytorch/helpers/utils.py b/tests/tests_pytorch/helpers/utils.py index 54503bf75b..18393e2193 100644 --- a/tests/tests_pytorch/helpers/utils.py +++ b/tests/tests_pytorch/helpers/utils.py @@ -17,13 +17,14 @@ import re from contextlib import contextmanager from typing import Optional, Type +import numpy as np import pytest from pytorch_lightning import seed_everything from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.loggers import TensorBoardLogger -from tests_pytorch import _TEMP_PATH, RANDOM_PORTS +from tests_pytorch import _TEMP_PATH def get_default_logger(save_dir, version=None): @@ -52,7 +53,7 @@ def get_data_path(expt_logger, path_dir=None): return path_expt -def load_model_from_checkpoint(logger, root_weights_dir, module_class=BoringModel): +def load_model_from_checkpoint(root_weights_dir, module_class=BoringModel): trained_model = module_class.load_from_checkpoint(root_weights_dir) assert trained_model is not None, "loading model failed" return trained_model @@ -68,6 +69,10 @@ def reset_seed(seed=0): seed_everything(seed) +# generate a list of random seeds for each test +RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000)) + + def set_random_main_port(): reset_seed() port = RANDOM_PORTS.pop()