Fix mypy checks for double precision plugin (#7151)
This commit is contained in:
parent
d6470bf193
commit
3f1a08ab00
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
from typing import Any, Generator, List, Tuple
|
||||
|
||||
|
@ -92,8 +92,8 @@ class DoublePrecisionPlugin(PrecisionPlugin):
|
|||
while len(self.patches) > 0:
|
||||
self.patches.pop().teardown()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def tensor_type_context(self) -> Generator:
|
||||
@contextmanager
|
||||
def train_step_context(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
A context manager to change the default tensor type.
|
||||
See: :meth:`torch.set_default_tensor_type`
|
||||
|
@ -102,7 +102,32 @@ class DoublePrecisionPlugin(PrecisionPlugin):
|
|||
yield
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
|
||||
train_step_context = tensor_type_context
|
||||
val_step_context = tensor_type_context
|
||||
test_step_context = tensor_type_context
|
||||
predict_step_context = tensor_type_context
|
||||
@contextmanager
|
||||
def val_step_context(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
A context manager to change the default tensor type.
|
||||
See: :meth:`torch.set_default_tensor_type`
|
||||
"""
|
||||
torch.set_default_tensor_type(torch.DoubleTensor)
|
||||
yield
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
|
||||
@contextmanager
|
||||
def test_step_context(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
A context manager to change the default tensor type.
|
||||
See: :meth:`torch.set_default_tensor_type`
|
||||
"""
|
||||
torch.set_default_tensor_type(torch.DoubleTensor)
|
||||
yield
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
|
||||
@contextmanager
|
||||
def predict_step_context(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
A context manager to change the default tensor type.
|
||||
See: :meth:`torch.set_default_tensor_type`
|
||||
"""
|
||||
torch.set_default_tensor_type(torch.DoubleTensor)
|
||||
yield
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
|
|
Loading…
Reference in New Issue