# Runs the Fabric and PyTorch test suites on TPUs: renders a jsonnet job
# template, submits it to a GKE cluster, waits for the Kubernetes job to
# finish, and uploads the coverage report that the job prints to its logs.
name: Test PyTorch - TPU

on:
  push:
    branches: [master, "release/*"]
  # NOTE(review): `pull_request_target` runs with access to repository secrets,
  # while the checkout step below uses the PR head commit. Confirm that
  # untrusted PR content cannot reach the GKE / Codecov credentials.
  pull_request_target:
    branches: [master, "release/*"]
    # added `ready_for_review` since draft is skipped
    types: [opened, reopened, ready_for_review, synchronize]
    paths:
      - ".actions/**"
      - ".github/workflows/tpu-tests.yml"
      - "dockers/base-xla/*"
      - "requirements/fabric/**"
      - "src/lightning_fabric/**"
      - "tests/tests_fabric/**"
      - "requirements/pytorch/**"
      - "src/pytorch_lightning/**"
      - "tests/tests_pytorch/**"
      - "setup.cfg"  # includes pytest config
      - "!requirements/*/docs.txt"
      - "!*.md"
      - "!**/*.md"

# One run per workflow/ref; superseded runs are cancelled everywhere except on
# `master` and `release/*` branches.
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref }}
  cancel-in-progress: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}

env:
  PROJECT_ID: ${{ secrets.GKE_PROJECT }}
  GKE_CLUSTER: lightning-cluster
  GKE_ZONE: us-central1-a

defaults:
  run:
    shell: bash

jobs:
  test-on-tpus:
    runs-on: ubuntu-22.04
    # Draft pull requests are skipped.
    if: github.event.pull_request.draft == false
    env:
      # Quoted so the version is never parsed as a float.
      PYTHON_VER: "3.7"
    strategy:
      fail-fast: false
      max-parallel: 1  # run sequential
      matrix:
        # TODO: add also lightning
        pkg-name: ["fabric", "pytorch"]
    timeout-minutes: 100  # should match the timeout in `tpu_workflow.jsonnet`

    steps:
      - uses: actions/checkout@v3
        with:
          ref: ${{ github.event.pull_request.head.sha }}

      - uses: actions/setup-python@v4
        with:
          python-version: ${{ env.PYTHON_VER }}

      - name: Checkout ml-testing-accelerators
        # NOTE(review): the repository URL was missing from `git clone`; it is
        # restored here to the upstream project that owns the pinned commit.
        # Confirm before merging.
        run: |
          git clone https://github.com/GoogleCloudPlatform/ml-testing-accelerators.git
          cd ml-testing-accelerators
          git fetch origin 5e88ac24f631c27045e62f0e8d5dfcf34e425e25:stable
          git checkout stable

      - uses: actions/setup-go@v3
        with:
          go-version: '1.19'

      - name: Install jsonnet
        # NOTE(review): the package path was missing from `go install`; restored
        # to the go-jsonnet command package. Confirm the desired version pin.
        run: go install github.com/google/go-jsonnet/cmd/jsonnet@latest

      - name: Update jsonnet
        env:
          SCOPE: ${{ matrix.pkg-name }}
          XLA_VER: "1.12"
          PR_NUMBER: ${{ github.event.pull_request.number }}
          SHA: ${{ github.event.pull_request.head.sha }}
        # Fills the `{...}` placeholders of the per-package jsonnet template
        # in place, using the environment variables defined above.
        run: |
          import os

          fname = f'dockers/base-xla/tpu_workflow_{os.getenv("SCOPE")}.jsonnet'
          with open(fname) as fo:
              data = fo.read()
          data = data.replace('{PYTORCH_VERSION}', os.getenv("XLA_VER"))
          data = data.replace('{PYTHON_VERSION}', os.getenv("PYTHON_VER"))
          data = data.replace('{PR_NUMBER}', os.getenv("PR_NUMBER"))
          data = data.replace('{SHA}', os.getenv("SHA"))
          with open(fname, "w") as fw:
              fw.write(data)
        shell: python

      - name: Show jsonnet
        run: cat dockers/base-xla/tpu_workflow_${{ matrix.pkg-name }}.jsonnet

      - uses: google-github-actions/auth@v1
        with:
          credentials_json: ${{ secrets.GKE_SA_KEY_BASE64 }}

      # Fetch cluster credentials so the `kubectl` calls below can reach GKE.
      - uses: google-github-actions/get-gke-credentials@v1
        with:
          cluster_name: ${{ env.GKE_CLUSTER }}
          location: ${{ env.GKE_ZONE }}

      - name: Deploy cluster
        run: |
          export PATH=$PATH:$HOME/go/bin

          # Render the job manifest and submit it; `kubectl create` prints
          # "job.batch/<name> created", which is trimmed down to the bare name.
          job_name=$(jsonnet -J ml-testing-accelerators/ dockers/base-xla/tpu_workflow_${{ matrix.pkg-name }}.jsonnet | kubectl create -f -)
          job_name=${job_name#job.batch/}
          job_name=${job_name% created}

          # Look up the pod spawned by the job through its controller-uid label.
          controller_uid=$(kubectl get job "$job_name" -o "jsonpath={.metadata.labels.controller-uid}")
          pod_name=$(kubectl get po -l "controller-uid=$controller_uid" --no-headers -o custom-columns=NAME:.metadata.name)
          echo "GKE pod name: $pod_name"

          echo "Waiting on kubernetes job: $job_name"
          # Check on the job periodically. Set the status code depending on
          # what happened to the job in Kubernetes.
          status_code=2
          printf "Waiting for job to finish: "
          while true; do
            if kubectl get jobs "$job_name" -o jsonpath='Failed:{.status.failed}' | grep "Failed:1"; then
              status_code=1
              break
            elif kubectl get jobs "$job_name" -o jsonpath='Succeeded:{.status.succeeded}' | grep "Succeeded:1"; then
              status_code=0
              break
            else
              printf "."
            fi
            sleep 5
          done
          echo "Done waiting. Job status code: $status_code"

          kubectl logs -f "$pod_name" --container=train > /tmp/full_output.txt

          # NOTE(review): the marker pattern was empty in both `grep` and
          # `csplit`; restored to the XML declaration that starts a coverage
          # report. Confirm it matches what the job actually prints.
          xml_marker='<?xml version="1.0" ?>'
          if grep -q "$xml_marker" /tmp/full_output.txt; then
            # successful run. split the output into logs + coverage report
            csplit /tmp/full_output.txt "/$xml_marker/"
            cat xx00  # test logs
            mv xx01 coverage.xml
          else
            # failed run, print everything
            cat /tmp/full_output.txt
          fi
          exit $status_code
        shell: bash

      - name: Upload coverage to Codecov
        uses: codecov/codecov-action@v3
        # Best-effort upload: a Codecov failure must not fail the job.
        continue-on-error: true
        with:
          token: ${{ secrets.CODECOV_TOKEN }}
          file: coverage.xml
          flags: tpu,pytest,python${{ env.PYTHON_VER }}
          name: TPU-coverage
          fail_ci_if_error: false