lightning/tests/trainer/loops/test_all.py


move batch to device before sending it to hooks (#7378) * update train step * test * x * limits * val * typeo * x * x * step * min gpus * run all loops * x * limit test * profiler * clean up accelerator code * move files * rename * move tests * changelog * reorder callbacks and model hooks * add test description * replace unneccessary method * fix chlog * adjust batch_to_device for DP Plugin * update tests for dataloader idx * unused imports * hook change * switch None * clear memory * change to None * None * None * memory savings * remove redundant todo * hack * cheat * Revert "cheat" This reverts commit a8433bd0b4bd35f218993335f7d4ff18977ae423. * Revert "hack" This reverts commit 43a6d1edeb62a15ac69ef69ef2352581ba1947a5. * update new epoch loop * remove from old loop code * update chlog * update hook test * changelog * teardown * integrate changes in new eval loop * fix hook calls * add prediction step * bad merge * Revert "bad merge" This reverts commit 488080863cf012dcf04446be3b7d973b7340687e. * fix train batch hook test * rm -rf _notebooks * update chlog * release memory * fix type * notebooks mess * debug * Revert "debug" This reverts commit eec4ee2f77b5eb39965211a250598ed5d2320e88. * teardown * fix teardown bug * debug * x * debug * Revert "debug" This reverts commit a6e61019462b80d09d31b65bed289fa6e4dd15f6. Revert "debug" This reverts commit 5ddeaec06911e96730aade1be6ee71d097b46b9a. debug debug Revert "debug" This reverts commit 605be746f7daedf265b2c05a1c153ce543394435. Revert "Revert "debug"" This reverts commit a7612d5410409ed886cfb609457349ecf44cbfa8. debug x x x s tol x tol * Fix changelog Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
2021-07-05 08:31:39 +00:00
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import Callback, Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf


# Callback whose batch hooks assert that the batch already lives on the LightningModule's device.
# ``*args`` absorbs the remaining positional hook arguments (e.g. the batch index).
class BatchHookObserverCallback(Callback):
    def on_train_batch_start(self, trainer, pl_module, batch, *args):
        assert batch.device == pl_module.device

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, *args):
        assert batch.device == pl_module.device

    def on_validation_batch_start(self, trainer, pl_module, batch, *args):
        assert batch.device == pl_module.device

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, *args):
        assert batch.device == pl_module.device

    def on_test_batch_start(self, trainer, pl_module, batch, *args):
        assert batch.device == pl_module.device

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, *args):
        assert batch.device == pl_module.device

    def on_predict_batch_start(self, trainer, pl_module, batch, *args):
        assert batch.device == pl_module.device

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, *args):
        assert batch.device == pl_module.device


# The same device checks, performed from the LightningModule's own batch hooks.
class BatchHookObserverModel(BoringModel):
    def on_train_batch_start(self, batch, *args):
        assert batch.device == self.device

    def on_train_batch_end(self, outputs, batch, *args):
        assert batch.device == self.device

    def on_validation_batch_start(self, batch, *args):
        assert batch.device == self.device

    def on_validation_batch_end(self, outputs, batch, *args):
        assert batch.device == self.device

    def on_test_batch_start(self, batch, *args):
        assert batch.device == self.device

    def on_test_batch_end(self, outputs, batch, *args):
        assert batch.device == self.device

    def on_predict_batch_start(self, batch, *args):
        assert batch.device == self.device

    def on_predict_batch_end(self, outputs, batch, *args):
        assert batch.device == self.device


@RunIf(min_gpus=1)
def test_callback_batch_on_device(tmpdir):
    """Test that the batch passed to the on_*_batch_start/end hooks is already on the LightningModule's device."""

    batch_callback = BatchHookObserverCallback()
    model = BatchHookObserverModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        limit_train_batches=1,
        limit_val_batches=1,
        limit_test_batches=1,
        limit_predict_batches=1,
        gpus=1,
        callbacks=[batch_callback],
    )
    # Exercise every loop so that each of the eight batch hooks runs at least once.
    trainer.fit(model)
    trainer.validate(model)
    trainer.test(model)
    trainer.predict(model)
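

# A minimal, framework-free sketch of the ordering the test above guards: the batch
# is moved to the target device *before* any on_*_batch_start/end hook runs, so every
# hook receives the already-moved batch. This is an illustration under that assumption,
# not Lightning's actual loop code. All names below (_SketchBatch, _sketch_to_device,
# _sketch_run_batch) are hypothetical; the leading underscores keep pytest from
# collecting them as tests.
from dataclasses import dataclass
from typing import Callable, List


@dataclass(frozen=True)
class _SketchBatch:
    # Only the attribute the hooks inspect; a real batch would hold tensors.
    device: str


def _sketch_to_device(batch: _SketchBatch, device: str) -> _SketchBatch:
    # Hypothetical stand-in for a real transfer such as ``tensor.to(device)``;
    # like that call, it returns a new object rather than mutating the input.
    return _SketchBatch(device=device)


def _sketch_run_batch(
    batch: _SketchBatch,
    device: str,
    hooks: List[Callable[[_SketchBatch], None]],
) -> _SketchBatch:
    # Transfer first, then hand the moved batch to every hook in order.
    # Calling the hooks before the transfer would let them see the host-side batch,
    # which is exactly what the device assertions above would catch.
    batch = _sketch_to_device(batch, device)
    for hook in hooks:
        hook(batch)
    return batch


# Usage (hypothetical): two recording hooks both observe the moved batch.
#   seen = []
#   _sketch_run_batch(_SketchBatch("cpu"), "cuda:0", [lambda b: seen.append(b.device)] * 2)
#   seen == ["cuda:0", "cuda:0"]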