Remove old test artifacts (#14574)
This commit is contained in:
parent
46519e2fc7
commit
bcad90141a
|
@ -11,11 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
_TEST_ROOT = os.path.dirname(__file__)
|
||||
_PROJECT_ROOT = os.path.dirname(_TEST_ROOT)
|
||||
_TEMP_PATH = os.path.join(_PROJECT_ROOT, "test_temp")
|
||||
|
@ -27,10 +24,6 @@ if _PROJECT_ROOT not in os.getenv("PYTHONPATH", ""):
|
|||
splitter = ":" if os.environ.get("PYTHONPATH", "") else ""
|
||||
os.environ["PYTHONPATH"] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}'
|
||||
|
||||
# generate a list of random seeds for each test
|
||||
RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
|
||||
|
||||
if not os.path.isdir(_TEMP_PATH):
|
||||
os.mkdir(_TEMP_PATH)
|
||||
|
||||
logging.basicConfig(level=logging.ERROR)
|
||||
|
|
|
@ -31,7 +31,7 @@ def run_model_test_without_loggers(
|
|||
# correct result and ok accuracy
|
||||
assert trainer.state.finished, f"Training failed with {trainer.state}"
|
||||
|
||||
model2 = load_model_from_checkpoint(trainer.logger, trainer.checkpoint_callback.best_model_path, type(model))
|
||||
model2 = load_model_from_checkpoint(trainer.checkpoint_callback.best_model_path, type(model))
|
||||
|
||||
# test new model accuracy
|
||||
test_loaders = model2.test_dataloader() if not data else data.test_dataloader()
|
||||
|
@ -68,7 +68,7 @@ def run_model_test(
|
|||
assert change_ratio > 0.03, f"the model is changed of {change_ratio}"
|
||||
|
||||
# test model loading
|
||||
_ = load_model_from_checkpoint(logger, trainer.checkpoint_callback.best_model_path, type(model))
|
||||
_ = load_model_from_checkpoint(trainer.checkpoint_callback.best_model_path, type(model))
|
||||
|
||||
# test new model accuracy
|
||||
test_loaders = model.test_dataloader() if not data else data.test_dataloader()
|
||||
|
|
|
@ -17,13 +17,14 @@ import re
|
|||
from contextlib import contextmanager
|
||||
from typing import Optional, Type
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.demos.boring_classes import BoringModel
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from tests_pytorch import _TEMP_PATH, RANDOM_PORTS
|
||||
from tests_pytorch import _TEMP_PATH
|
||||
|
||||
|
||||
def get_default_logger(save_dir, version=None):
|
||||
|
@ -52,7 +53,7 @@ def get_data_path(expt_logger, path_dir=None):
|
|||
return path_expt
|
||||
|
||||
|
||||
def load_model_from_checkpoint(logger, root_weights_dir, module_class=BoringModel):
|
||||
def load_model_from_checkpoint(root_weights_dir, module_class=BoringModel):
|
||||
trained_model = module_class.load_from_checkpoint(root_weights_dir)
|
||||
assert trained_model is not None, "loading model failed"
|
||||
return trained_model
|
||||
|
@ -68,6 +69,10 @@ def reset_seed(seed=0):
|
|||
seed_everything(seed)
|
||||
|
||||
|
||||
# generate a list of random seeds for each test
|
||||
RANDOM_PORTS = list(np.random.randint(12000, 19000, 1000))
|
||||
|
||||
|
||||
def set_random_main_port():
|
||||
reset_seed()
|
||||
port = RANDOM_PORTS.pop()
|
||||
|
|
Loading…
Reference in New Issue