[bugfix] Add set_default_tensor_type to torch.DoubleTensor with precision=64 (#7108)

* update

* Update pytorch_lightning/plugins/precision/double.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/plugins/precision/double.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* Update pytorch_lightning/plugins/precision/double.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

* resolve tests

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
thomas chaton 2021-04-20 16:25:37 +01:00 committed by GitHub
parent ca21da4f3b
commit 013756404b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 32 additions and 4 deletions

View File

@ -242,7 +242,7 @@ class Accelerator:
args[0] = batch
with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context():
return self.training_type_plugin.predict_step(*args)
def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:

View File

@ -41,6 +41,6 @@ class Plugin(ABC):
yield
@contextlib.contextmanager
def predict_context(self) -> Generator:
def predict_step_context(self) -> Generator:
"""A contextmanager for the predict step"""
yield

View File

@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from functools import wraps
from typing import Any, List, Tuple
from typing import Any, Generator, List, Tuple
import torch
import torch.nn as nn
@ -90,3 +91,18 @@ class DoublePrecisionPlugin(PrecisionPlugin):
def post_dispatch(self) -> None:
while len(self.patches) > 0:
self.patches.pop().teardown()
@contextlib.contextmanager
def tensor_type_context(self) -> Generator:
"""
A context manager to change the default tensor type.
See: :meth:`torch.set_default_tensor_type`
"""
torch.set_default_tensor_type(torch.DoubleTensor)
yield
torch.set_default_tensor_type(torch.FloatTensor)
train_step_context = tensor_type_context
val_step_context = tensor_type_context
test_step_context = tensor_type_context
predict_step_context = tensor_type_context

View File

@ -115,7 +115,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
yield
@contextmanager
def predict_context(self) -> Generator[None, None, None]:
def predict_step_context(self) -> Generator[None, None, None]:
"""Enable autocast context"""
with torch.cuda.amp.autocast():
yield

View File

@ -37,25 +37,37 @@ class DoublePrecisionBoringModel(BoringModel):
def training_step(self, batch, batch_idx):
float_data, int_data = batch
assert torch.tensor([0.]).dtype == torch.float64
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
assert float_data.dtype == torch.float64
output = self(float_data)
loss = self.loss(batch, output)
return {"loss": loss}
def training_epoch_end(self, outputs) -> None:
assert torch.tensor([0.]).dtype == torch.float32
return super().training_epoch_end(outputs)
def validation_step(self, batch, batch_idx):
assert batch.dtype == torch.float64
assert torch.tensor([0.]).dtype == torch.float64
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
output = self(batch)
loss = self.loss(batch, output)
return {"x": loss}
def test_step(self, batch, batch_idx):
assert batch.dtype == torch.float64
assert torch.tensor([0.]).dtype == torch.float64
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
output = self(batch)
loss = self.loss(batch, output)
return {"y": loss}
def predict_step(self, batch, batch_idx, dataloader_idx=None):
assert batch.dtype == torch.float64
assert torch.tensor([0.]).dtype == torch.float64
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
return self(batch)
def on_fit_start(self):