Fix mypy checks for double precision plugin (#7151)
This commit is contained in:
parent
d6470bf193
commit
3f1a08ab00
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import contextlib
|
from contextlib import contextmanager
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Any, Generator, List, Tuple
|
from typing import Any, Generator, List, Tuple
|
||||||
|
|
||||||
|
@ -92,8 +92,8 @@ class DoublePrecisionPlugin(PrecisionPlugin):
|
||||||
while len(self.patches) > 0:
|
while len(self.patches) > 0:
|
||||||
self.patches.pop().teardown()
|
self.patches.pop().teardown()
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextmanager
|
||||||
def tensor_type_context(self) -> Generator:
|
def train_step_context(self) -> Generator[None, None, None]:
|
||||||
"""
|
"""
|
||||||
A context manager to change the default tensor type.
|
A context manager to change the default tensor type.
|
||||||
See: :meth:`torch.set_default_tensor_type`
|
See: :meth:`torch.set_default_tensor_type`
|
||||||
|
@ -102,7 +102,32 @@ class DoublePrecisionPlugin(PrecisionPlugin):
|
||||||
yield
|
yield
|
||||||
torch.set_default_tensor_type(torch.FloatTensor)
|
torch.set_default_tensor_type(torch.FloatTensor)
|
||||||
|
|
||||||
train_step_context = tensor_type_context
|
@contextmanager
|
||||||
val_step_context = tensor_type_context
|
def val_step_context(self) -> Generator[None, None, None]:
|
||||||
test_step_context = tensor_type_context
|
"""
|
||||||
predict_step_context = tensor_type_context
|
A context manager to change the default tensor type.
|
||||||
|
See: :meth:`torch.set_default_tensor_type`
|
||||||
|
"""
|
||||||
|
torch.set_default_tensor_type(torch.DoubleTensor)
|
||||||
|
yield
|
||||||
|
torch.set_default_tensor_type(torch.FloatTensor)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def test_step_context(self) -> Generator[None, None, None]:
|
||||||
|
"""
|
||||||
|
A context manager to change the default tensor type.
|
||||||
|
See: :meth:`torch.set_default_tensor_type`
|
||||||
|
"""
|
||||||
|
torch.set_default_tensor_type(torch.DoubleTensor)
|
||||||
|
yield
|
||||||
|
torch.set_default_tensor_type(torch.FloatTensor)
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def predict_step_context(self) -> Generator[None, None, None]:
|
||||||
|
"""
|
||||||
|
A context manager to change the default tensor type.
|
||||||
|
See: :meth:`torch.set_default_tensor_type`
|
||||||
|
"""
|
||||||
|
torch.set_default_tensor_type(torch.DoubleTensor)
|
||||||
|
yield
|
||||||
|
torch.set_default_tensor_type(torch.FloatTensor)
|
||||||
|
|
Loading…
Reference in New Issue