tests for legacy checkpoints (#5223)
* wip * generate * clean * tests * copy * download * download * download * download * download * download * download * download * download * download * download * flake8 * extend * aws * extension * pull * pull * pull * pull * pull * pull * pull * try * try * try * got it * Apply suggestions from code review
This commit is contained in:
parent
4c6f36e6e1
commit
72525f0a83
|
@ -39,6 +39,14 @@ steps:
|
|||
# when Image has defined CUDa version we can switch to this package spec "nvidia-dali-cuda${CUDA_VERSION%%.*}0"
|
||||
- pip install --extra-index-url https://developer.download.nvidia.com/compute/redist nvidia-dali-cuda100 --upgrade-strategy only-if-needed
|
||||
- pip list
|
||||
# todo: remove unzip install after new nigtly docker is created
|
||||
- apt-get update -qq
|
||||
- apt-get install -y --no-install-recommends unzip
|
||||
# get legacy checkpoints
|
||||
- wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/
|
||||
- unzip -o legacy/checkpoints.zip -d legacy/
|
||||
- ls -l legacy/checkpoints/
|
||||
# testing...
|
||||
- python -m coverage run --source pytorch_lightning -m pytest pytorch_lightning tests -v --durations=25 # --flake8
|
||||
# Running special tests
|
||||
- sh tests/special_tests.sh
|
||||
|
|
|
@ -34,10 +34,21 @@ jobs:
|
|||
# todo this probably does not work with docker images, rather cache dockers
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: Datasets # This path is specific to Ubuntu
|
||||
# Look to see if there is a cache hit for the corresponding requirements file
|
||||
path: Datasets
|
||||
key: pl-dataset
|
||||
|
||||
- name: Pull checkpoints from S3
|
||||
# todo: consider adding coma caching, but ATM all models have less then 100KB
|
||||
run: |
|
||||
# todo: remove unzip install after new nigtly docker is created
|
||||
apt-get update -qq
|
||||
apt-get install -y --no-install-recommends unzip
|
||||
# enter legacy and update checkpoints from S3
|
||||
cd legacy
|
||||
curl https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip --output checkpoints.zip
|
||||
unzip -o checkpoints.zip
|
||||
ls -l checkpoints/
|
||||
|
||||
- name: Tests
|
||||
run: |
|
||||
# NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003
|
||||
|
|
|
@ -104,6 +104,16 @@ jobs:
|
|||
restore-keys: |
|
||||
${{ runner.os }}-pip-py${{ matrix.python-version }}-${{ matrix.requires }}-
|
||||
|
||||
- name: Pull checkpoints from S3
|
||||
# todo: consider adding some caching, but ATM all models have less then 100KB
|
||||
run: |
|
||||
cd legacy
|
||||
# wget is simpler but does not work on Windows
|
||||
python -c "from urllib.request import urlretrieve ; urlretrieve('https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip', 'checkpoints.zip')"
|
||||
ls -l .
|
||||
unzip -o checkpoints.zip
|
||||
ls -l checkpoints/
|
||||
|
||||
- name: Install dependencies
|
||||
env:
|
||||
# MAKEFLAGS: "-j2"
|
||||
|
@ -136,8 +146,7 @@ jobs:
|
|||
- name: Cache datasets
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: Datasets # This path is specific to Ubuntu
|
||||
# Look to see if there is a cache hit for the corresponding requirements file
|
||||
path: Datasets
|
||||
key: pl-dataset
|
||||
|
||||
- name: Tests
|
||||
|
|
|
@ -35,7 +35,7 @@ jobs:
|
|||
with:
|
||||
time: 5m
|
||||
|
||||
# We do this, since failures on test.pypi aren't that bad
|
||||
# We do this, since failures on test.pypi aren't that bad
|
||||
- name: Publish to Test PyPI
|
||||
uses: pypa/gh-action-pypi-publish@v1.4.1
|
||||
with:
|
||||
|
|
|
@ -5,7 +5,7 @@ on: # Trigger the workflow on push or pull request, but only for the master bra
|
|||
push:
|
||||
branches: [master, "release/*"] # include release branches like release/1.0.x
|
||||
release:
|
||||
types: [created, "release/*"]
|
||||
types: [created]
|
||||
|
||||
|
||||
jobs:
|
||||
|
@ -61,3 +61,51 @@ jobs:
|
|||
with:
|
||||
user: __token__
|
||||
password: ${{ secrets.pypi_password }}
|
||||
|
||||
# Note: This uses an internal pip API and may not always work
|
||||
# https://github.com/actions/cache/blob/master/examples.md#multiple-oss-in-a-workflow
|
||||
- name: Cache pip
|
||||
uses: actions/cache@v2
|
||||
with:
|
||||
path: ~/.cache/pip
|
||||
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
|
||||
restore-keys: ${{ runner.os }}-pip-
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -r requirements.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet
|
||||
pip install virtualenv
|
||||
pip install awscli
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@v1
|
||||
with:
|
||||
aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
aws-secret-access-key: ${{ secrets.AWS_SECRET_KEY_ID }}
|
||||
aws-region: us-east-1
|
||||
|
||||
- name: Pull files from S3
|
||||
run: |
|
||||
aws s3 cp --recursive s3://pl-public-data/legacy/checkpoints/ legacy/checkpoints/ # --acl public-read
|
||||
ls -l legacy/checkpoints/
|
||||
|
||||
- name: Generate checkpoint
|
||||
if: startsWith(github.event.ref, 'refs/tags') || github.event_name == 'release'
|
||||
run: |
|
||||
virtualenv vEnv --system-site-packages
|
||||
source vEnv/bin/activate
|
||||
pip install dist/*
|
||||
|
||||
pl_ver=$(python -c "import pytorch_lightning as pl ; print(pl.__version__)" 2>&1)
|
||||
# generate checkpoint to this version
|
||||
bash legacy/generate_checkpoints.sh $pl_ver
|
||||
|
||||
deactivate
|
||||
rm -rf vEnv
|
||||
|
||||
- name: Push files to S3
|
||||
run: |
|
||||
aws s3 sync legacy/checkpoints/ s3://pl-public-data/legacy/checkpoints/
|
||||
cd legacy
|
||||
zip -r checkpoints.zip checkpoints
|
||||
aws s3 cp checkpoints.zip s3://pl-public-data/legacy/ --acl public-read
|
||||
|
|
|
@ -27,6 +27,7 @@ timit_data/
|
|||
# C extensions
|
||||
*.so
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
|
||||
# Distribution / packaging
|
||||
|
@ -126,11 +127,14 @@ ENV/
|
|||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
# pytest
|
||||
.pytest_cache/
|
||||
|
||||
# data
|
||||
.data/
|
||||
Datasets/
|
||||
mnist/
|
||||
legacy/checkpoints/
|
||||
|
||||
# pl tests
|
||||
ml-runs/
|
||||
|
|
|
@ -69,3 +69,4 @@ prune temp*
|
|||
prune test*
|
||||
prune benchmark*
|
||||
prune dockers
|
||||
prune legacy
|
||||
|
|
|
@ -40,7 +40,9 @@ RUN apt-get update -qq && \
|
|||
build-essential \
|
||||
cmake \
|
||||
git \
|
||||
wget \
|
||||
curl \
|
||||
unzip \
|
||||
ca-certificates \
|
||||
&& \
|
||||
|
||||
|
|
|
@ -45,6 +45,8 @@ RUN apt-get update -qq && \
|
|||
cmake \
|
||||
git \
|
||||
wget \
|
||||
curl \
|
||||
unzip \
|
||||
ca-certificates \
|
||||
software-properties-common \
|
||||
&& \
|
||||
|
|
|
@ -23,6 +23,12 @@ MAINTAINER PyTorchLightning <https://github.com/PyTorchLightning>
|
|||
|
||||
COPY ./ ./pytorch-lightning/
|
||||
|
||||
# Pull the legacy checkpoints
|
||||
RUN cd pytorch-lightning && \
|
||||
wget https://pl-public-data.s3.amazonaws.com/legacy/checkpoints.zip -P legacy/ && \
|
||||
unzip -o legacy/checkpoints.zip -d legacy/ && \
|
||||
ls -l legacy/checkpoints/
|
||||
|
||||
# If using this image for tests, intall more dependencies and don"t delete the source code where the tests live.
|
||||
RUN \
|
||||
# Install pytorch-lightning at the current PR, plus dependencies.
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
#!/bin/bash
|
||||
# Sample call:
|
||||
# bash generate_checkpoints.sh 1.0.2 1.0.3 1.0.4
|
||||
|
||||
LEGACY_PATH="$( cd "$(dirname "$0")" >/dev/null 2>&1 ; pwd -P )"
|
||||
|
||||
echo $LEGACY_PATH
|
||||
# install some PT version here so it does not need to reinstalled for each env
|
||||
pip install virtualenv "torch==1.5" --quiet --no-cache-dir
|
||||
|
||||
ENV_PATH="$LEGACY_PATH/vEnv"
|
||||
|
||||
# iterate over all arguments assuming that each argument is version
|
||||
for ver in "$@"
|
||||
do
|
||||
echo "processing version: $ver"
|
||||
# mkdir "$LEGACY_PATH/$ver"
|
||||
|
||||
# create local env
|
||||
echo $ENV_PATH
|
||||
virtualenv $ENV_PATH --system-site-packages
|
||||
# activate and install PL version
|
||||
source "$ENV_PATH/bin/activate"
|
||||
pip install "pytorch_lightning==$ver" --quiet -U --no-cache-dir
|
||||
|
||||
python --version
|
||||
pip --version
|
||||
pip list | grep torch
|
||||
|
||||
python "$LEGACY_PATH/zero_training.py"
|
||||
cp "$LEGACY_PATH/zero_training.py" ${LEGACY_PATH}/checkpoints/${ver}
|
||||
|
||||
mv ${LEGACY_PATH}/checkpoints/${ver}/lightning_logs/version_0/checkpoints/*.ckpt ${LEGACY_PATH}/checkpoints/${ver}/
|
||||
rm -rf ${LEGACY_PATH}/checkpoints/${ver}/lightning_logs
|
||||
|
||||
deactivate
|
||||
# clear env
|
||||
rm -rf $ENV_PATH
|
||||
|
||||
done
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
||||
PATH_LEGACY = os.path.dirname(__file__)
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
def __init__(self, size, length: int = 100):
|
||||
self.len = length
|
||||
self.data = torch.randn(length, size)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.data[index]
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
|
||||
class DummyModel(pl.LightningModule):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.Linear(32, 2)
|
||||
|
||||
def forward(self, x):
|
||||
return self.layer(x)
|
||||
|
||||
def _loss(self, batch, prediction):
|
||||
# An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
|
||||
return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))
|
||||
|
||||
def _step(self, batch, batch_idx):
|
||||
output = self.layer(batch)
|
||||
loss = self._loss(batch, output)
|
||||
return loss
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
return self._step(batch, batch_idx)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
self._step(batch, batch_idx)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
self._step(batch, batch_idx)
|
||||
|
||||
def configure_optimizers(self):
|
||||
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
|
||||
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
|
||||
return [optimizer], [lr_scheduler]
|
||||
|
||||
def train_dataloader(self):
|
||||
return torch.utils.data.DataLoader(RandomDataset(32, 64))
|
||||
|
||||
def val_dataloader(self):
|
||||
return torch.utils.data.DataLoader(RandomDataset(32, 64))
|
||||
|
||||
def test_dataloader(self):
|
||||
return torch.utils.data.DataLoader(RandomDataset(32, 64))
|
||||
|
||||
|
||||
def main_train(dir_path, max_epochs: int = 5):
|
||||
|
||||
trainer = pl.Trainer(
|
||||
default_root_dir=dir_path,
|
||||
checkpoint_callback=True,
|
||||
max_epochs=max_epochs,
|
||||
)
|
||||
|
||||
model = DummyModel()
|
||||
trainer.fit(model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
path_dir = os.path.join(PATH_LEGACY, 'checkpoints', str(pl.__version__))
|
||||
main_train(path_dir)
|
2
setup.py
2
setup.py
|
@ -69,7 +69,7 @@ setup(
|
|||
url=pytorch_lightning.__homepage__,
|
||||
download_url='https://github.com/PyTorchLightning/pytorch-lightning',
|
||||
license=pytorch_lightning.__license__,
|
||||
packages=find_packages(exclude=['tests', 'tests/*', 'benchmarks']),
|
||||
packages=find_packages(exclude=['tests', 'tests/*', 'benchmarks', 'legacy', 'legacy/*']),
|
||||
|
||||
long_description=_load_long_description(PATH_ROOT),
|
||||
long_description_content_type='text/markdown',
|
||||
|
|
|
@ -18,6 +18,8 @@ import numpy as np
|
|||
TEST_ROOT = os.path.dirname(__file__)
|
||||
PROJECT_ROOT = os.path.dirname(TEST_ROOT)
|
||||
TEMP_PATH = os.path.join(PROJECT_ROOT, 'test_temp')
|
||||
DATASETS_PATH = os.path.join(PROJECT_ROOT, 'Datasets')
|
||||
LEGACY_PATH = os.path.join(PROJECT_ROOT, 'legacy')
|
||||
|
||||
# todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages
|
||||
if PROJECT_ROOT not in os.getenv('PYTHONPATH', ""):
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import glob
|
||||
import os
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from tests import LEGACY_PATH
|
||||
|
||||
LEGACY_CHECKPOINTS_PATH = os.path.join(LEGACY_PATH, 'checkpoints')
|
||||
CHECKPOINT_EXTENSION = ".ckpt"
|
||||
|
||||
|
||||
# todo: add more legacy checkpoints :]
|
||||
@pytest.mark.parametrize("pl_version", [
|
||||
"0.10.0", "1.0.0", "1.0.1", "1.0.2", "1.0.3", "1.0.4", "1.0.5", "1.0.6", "1.0.7", "1.0.8"
|
||||
])
|
||||
def test_resume_legacy_checkpoints(tmpdir, pl_version):
|
||||
path_dir = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
|
||||
|
||||
# todo: make this as mock, so it is cleaner...
|
||||
orig_sys_paths = list(sys.path)
|
||||
sys.path.insert(0, path_dir)
|
||||
from zero_training import DummyModel
|
||||
|
||||
path_ckpts = sorted(glob.glob(os.path.join(path_dir, f'*{CHECKPOINT_EXTENSION}')))
|
||||
assert path_ckpts, 'No checkpoints found in folder "%s"' % path_dir
|
||||
path_ckpt = path_ckpts[-1]
|
||||
|
||||
model = DummyModel.load_from_checkpoint(path_ckpt)
|
||||
trainer = Trainer(default_root_dir=tmpdir, max_epochs=6)
|
||||
result = trainer.fit(model)
|
||||
assert result
|
||||
|
||||
# todo
|
||||
# model = DummyModel()
|
||||
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, resume_from_checkpoint=path_ckpt)
|
||||
# result = trainer.fit(model)
|
||||
# assert result
|
||||
|
||||
sys.path = orig_sys_paths
|
Loading…
Reference in New Issue