Remove old test artifacts (#14574)

This commit is contained in:
Carlos Mocholí 2022-09-07 16:09:59 +02:00 committed by GitHub
parent 46519e2fc7
commit bcad90141a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 11 deletions

View File

@ -11,11 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import numpy as np
_TEST_ROOT = os.path.dirname(__file__)
_PROJECT_ROOT = os.path.dirname(_TEST_ROOT)
_TEMP_PATH = os.path.join(_PROJECT_ROOT, "test_temp")
@ -27,10 +24,6 @@ if _PROJECT_ROOT not in os.getenv("PYTHONPATH", ""):
splitter = ":" if os.environ.get("PYTHONPATH", "") else ""
os.environ["PYTHONPATH"] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}'
# generate a list of random seeds for each test
RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
if not os.path.isdir(_TEMP_PATH):
os.mkdir(_TEMP_PATH)
logging.basicConfig(level=logging.ERROR)

View File

@ -31,7 +31,7 @@ def run_model_test_without_loggers(
# correct result and ok accuracy
assert trainer.state.finished, f"Training failed with {trainer.state}"
model2 = load_model_from_checkpoint(trainer.logger, trainer.checkpoint_callback.best_model_path, type(model))
model2 = load_model_from_checkpoint(trainer.checkpoint_callback.best_model_path, type(model))
# test new model accuracy
test_loaders = model2.test_dataloader() if not data else data.test_dataloader()
@ -68,7 +68,7 @@ def run_model_test(
assert change_ratio > 0.03, f"the model is changed of {change_ratio}"
# test model loading
_ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))
_ = load_model_from_checkpoint(trainer.checkpoint_callback.best_model_path, type(model))
# test new model accuracy
test_loaders = model.test_dataloader() if not data else data.test_dataloader()

View File

@ -17,13 +17,14 @@ import re
from contextlib import contextmanager
from typing import Optional, Type
import numpy as np
import pytest
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loggers import TensorBoardLogger
from tests_pytorch import _TEMP_PATH, RANDOM_PORTS
from tests_pytorch import _TEMP_PATH
def get_default_logger(save_dir, version=None):
@ -52,7 +53,7 @@ def get_data_path(expt_logger, path_dir=None):
return path_expt
def load_model_from_checkpoint(logger, root_weights_dir, module_class=BoringModel):
def load_model_from_checkpoint(root_weights_dir, module_class=BoringModel):
trained_model = module_class.load_from_checkpoint(root_weights_dir)
assert trained_model is not None, "loading model failed"
return trained_model
@ -68,6 +69,10 @@ def reset_seed(seed=0):
seed_everything(seed)
# generate a list of random seeds for each test
RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
def set_random_main_port():
reset_seed()
port = RANDOM_PORTS.pop()