[bugfix] Add set_default_tensor_type to torch.DoubleTensor with precision=64 (#7108)
* update * Update pytorch_lightning/plugins/precision/double.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Update pytorch_lightning/plugins/precision/double.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * Update pytorch_lightning/plugins/precision/double.py Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com> * resolve tests Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
ca21da4f3b
commit
013756404b
|
@ -242,7 +242,7 @@ class Accelerator:
|
|||
|
||||
args[0] = batch
|
||||
|
||||
with self.precision_plugin.predict_context(), self.training_type_plugin.predict_context():
|
||||
with self.precision_plugin.predict_step_context(), self.training_type_plugin.predict_step_context():
|
||||
return self.training_type_plugin.predict_step(*args)
|
||||
|
||||
def training_step_end(self, output: STEP_OUTPUT) -> STEP_OUTPUT:
|
||||
|
|
|
@ -41,6 +41,6 @@ class Plugin(ABC):
|
|||
yield
|
||||
|
||||
@contextlib.contextmanager
|
||||
def predict_context(self) -> Generator:
|
||||
def predict_step_context(self) -> Generator:
|
||||
"""A contextmanager for the predict step"""
|
||||
yield
|
||||
|
|
|
@ -11,8 +11,9 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from functools import wraps
|
||||
from typing import Any, List, Tuple
|
||||
from typing import Any, Generator, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -90,3 +91,18 @@ class DoublePrecisionPlugin(PrecisionPlugin):
|
|||
def post_dispatch(self) -> None:
|
||||
while len(self.patches) > 0:
|
||||
self.patches.pop().teardown()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def tensor_type_context(self) -> Generator:
|
||||
"""
|
||||
A context manager to change the default tensor type.
|
||||
See: :meth:`torch.set_default_tensor_type`
|
||||
"""
|
||||
torch.set_default_tensor_type(torch.DoubleTensor)
|
||||
yield
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
|
||||
train_step_context = tensor_type_context
|
||||
val_step_context = tensor_type_context
|
||||
test_step_context = tensor_type_context
|
||||
predict_step_context = tensor_type_context
|
||||
|
|
|
@ -115,7 +115,7 @@ class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
|
|||
yield
|
||||
|
||||
@contextmanager
|
||||
def predict_context(self) -> Generator[None, None, None]:
|
||||
def predict_step_context(self) -> Generator[None, None, None]:
|
||||
"""Enable autocast context"""
|
||||
with torch.cuda.amp.autocast():
|
||||
yield
|
||||
|
|
|
@ -37,25 +37,37 @@ class DoublePrecisionBoringModel(BoringModel):
|
|||
|
||||
def training_step(self, batch, batch_idx):
|
||||
float_data, int_data = batch
|
||||
assert torch.tensor([0.]).dtype == torch.float64
|
||||
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
|
||||
assert float_data.dtype == torch.float64
|
||||
output = self(float_data)
|
||||
loss = self.loss(batch, output)
|
||||
return {"loss": loss}
|
||||
|
||||
def training_epoch_end(self, outputs) -> None:
|
||||
assert torch.tensor([0.]).dtype == torch.float32
|
||||
return super().training_epoch_end(outputs)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
assert batch.dtype == torch.float64
|
||||
assert torch.tensor([0.]).dtype == torch.float64
|
||||
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
|
||||
output = self(batch)
|
||||
loss = self.loss(batch, output)
|
||||
return {"x": loss}
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
assert batch.dtype == torch.float64
|
||||
assert torch.tensor([0.]).dtype == torch.float64
|
||||
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
|
||||
output = self(batch)
|
||||
loss = self.loss(batch, output)
|
||||
return {"y": loss}
|
||||
|
||||
def predict_step(self, batch, batch_idx, dataloader_idx=None):
|
||||
assert batch.dtype == torch.float64
|
||||
assert torch.tensor([0.]).dtype == torch.float64
|
||||
assert torch.tensor([0.], dtype=torch.float16).dtype == torch.float16
|
||||
return self(batch)
|
||||
|
||||
def on_fit_start(self):
|
||||
|
|
Loading…
Reference in New Issue