# lightning/pytorch_lightning/plugins/precision/double.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps
from typing import Any, List, Sequence, Tuple

import torch
import torch.nn as nn
from torch.optim import Optimizer

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.apply_func import apply_to_collection


class _DoublePrecisionPatch:
    """Class to handle patching of methods in the ``LightningModule`` and subsequent teardown."""

    def __init__(self, model: nn.Module, method_name: str, old_method: Any) -> None:
        self.model = model
        self.method_name = method_name
        self.old_method = old_method

    def teardown(self) -> None:
        setattr(self.model, self.method_name, self.old_method)

    @staticmethod
    def _to_double_precision(data: torch.Tensor) -> torch.Tensor:
        # Only floating point tensors are cast; integer tensors (e.g. class labels) pass through unchanged.
        if data.is_floating_point():
            return data.double()
        return data

    @staticmethod
    def _move_float_tensors_to_double(collection: Any) -> Any:
        return apply_to_collection(collection, torch.Tensor, function=_DoublePrecisionPatch._to_double_precision)

    @classmethod
    def patch(cls, model: nn.Module, method_name: str) -> '_DoublePrecisionPatch':
        old_method = getattr(model, method_name)

        @wraps(old_method)
        def new_method(*args: Any, **kwargs: Any) -> Any:
            return old_method(
                *_DoublePrecisionPatch._move_float_tensors_to_double(args),
                **_DoublePrecisionPatch._move_float_tensors_to_double(kwargs)
            )

        # Only replace the attribute when it is callable; otherwise the original attribute is left untouched.
        setattr(model, method_name, new_method if callable(old_method) else old_method)
        return cls(model, method_name, old_method)
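
# Illustrative sketch (not part of this module): how ``_DoublePrecisionPatch.patch`` is
# expected to behave on any ``nn.Module``. The names below (``layer``, ``patch_handle``)
# are hypothetical, and the snippet is kept as comments only.
#
#     layer = nn.Linear(4, 4).double()
#     patch_handle = _DoublePrecisionPatch.patch(layer, 'forward')
#     out = layer(torch.rand(2, 4))   # float32 input is cast to float64 before the wrapped call
#     patch_handle.teardown()         # restores the original ``forward``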


class DoublePrecisionPlugin(PrecisionPlugin):
    """Plugin for training with double (``torch.float64``) precision."""

    precision: int = 64

    def __init__(self) -> None:
        super().__init__()
        self.patches: List[_DoublePrecisionPatch] = []

    def connect(
        self,
        model: nn.Module,
        optimizers: Sequence[Optimizer],
        lr_schedulers: Sequence[Any],
    ) -> Tuple[nn.Module, Sequence[Optimizer], Sequence[Any]]:
        """Converts the model to double precision and wraps the `training_step`, `validation_step`, `test_step`,
        `predict_step`, and `forward` methods to convert incoming floating point data to double. Does not alter
        `optimizers` or `lr_schedulers`."""
        model = model.to(dtype=torch.float64)
        if isinstance(model, LightningModule):
            self.patches.append(_DoublePrecisionPatch.patch(model, 'training_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'validation_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'test_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'predict_step'))
            self.patches.append(_DoublePrecisionPatch.patch(model, 'forward'))
        return super().connect(model, optimizers, lr_schedulers)

    def post_dispatch(self) -> None:
        while len(self.patches) > 0:
            self.patches.pop().teardown()
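
# Usage sketch (illustrative, not part of this module): the plugin is normally selected by
# the Trainer when double precision is requested; ``MyLightningModule`` below is a
# hypothetical user model.
#
#     from pytorch_lightning import Trainer
#
#     trainer = Trainer(precision=64)
#     # or, passing the plugin explicitly:
#     trainer = Trainer(plugins=[DoublePrecisionPlugin()])
#     trainer.fit(MyLightningModule())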