lightning/pytorch_lightning/utilities/apply_func.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from collections.abc import Mapping, Sequence
from copy import copy
from functools import partial
from typing import Any, Callable, Optional, Union

import numpy as np
import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _TORCHTEXT_AVAILABLE

if _TORCHTEXT_AVAILABLE:
    from torchtext.data import Batch
else:
    Batch = type(None)


def to_dtype_tensor(value, dtype: torch.dtype = None, device: torch.device = None):
    if device is None:
        raise MisconfigurationException(
            "device (torch.device) should be provided."
        )
    return torch.tensor(value, dtype=dtype, device=device)


def from_numpy(value, device: torch.device = None):
    if device is None:
        raise MisconfigurationException(
            "device (torch.device) should be provided."
        )
    return torch.from_numpy(value).to(device)


CONVERSION_DTYPES = [
    # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
    (bool, partial(to_dtype_tensor, dtype=torch.uint8)),
    (int, partial(to_dtype_tensor, dtype=torch.int)),
    (float, partial(to_dtype_tensor, dtype=torch.float)),
    (np.ndarray, from_numpy),
]
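
# Illustration only: a torch-free sketch of how an ordered table of
# (type, converter) pairs such as CONVERSION_DTYPES could be consumed.
# The code that actually consumes the table is not shown in this part of the
# file, so ``tag_value``, ``EXAMPLE_CONVERSIONS`` and ``convert_first_match``
# are hypothetical names, not part of the library. The sketch shows why the
# ``bool`` entry has to come before ``int`` when matching with ``isinstance``.

```python
from functools import partial


def tag_value(value, label=None):
    # Stand-in for a tensor constructor: records which converter ran.
    return (label, value)


# Hypothetical table shaped like CONVERSION_DTYPES: (source type, converter).
# bool comes before int because bool is a subclass of int in Python.
EXAMPLE_CONVERSIONS = [
    (bool, partial(tag_value, label="uint8")),
    (int, partial(tag_value, label="int")),
    (float, partial(tag_value, label="float")),
]


def convert_first_match(value, table=EXAMPLE_CONVERSIONS):
    # Hypothetical helper: apply the first converter whose type matches.
    for src_type, converter in table:
        if isinstance(value, src_type):
            return converter(value)
    return value  # unsupported types are returned untouched


print(convert_first_match(True))    # ('uint8', True)
print(convert_first_match(3))       # ('int', 3)
print(convert_first_match(2.5))     # ('float', 2.5)
print(convert_first_match("text"))  # text
```

# If the ``int`` entry were listed first, ``True`` would match it and never
# reach the ``bool`` converter.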
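
# Illustration only: a simplified, self-contained stand-in for the
# ``apply_to_collection`` function defined next, to show the effect of
# ``wrong_dtype``. The breaking condition mirrors the one in this file; the
# recursion over dicts, lists and tuples is an assumption, since those
# branches are not fully shown here. ``apply_recursively`` is a hypothetical
# name.

```python
from collections.abc import Mapping


def apply_recursively(data, dtype, function, *args, wrong_dtype=None, **kwargs):
    """Simplified stand-in for illustration; not the library implementation."""
    # Breaking condition: matching type that is not excluded by wrong_dtype.
    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
        return function(data, *args, **kwargs)

    # Assumed recursion: rebuild mappings as plain dicts, and lists and
    # tuples as their own type, element by element.
    if isinstance(data, Mapping):
        return {
            key: apply_recursively(value, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
            for key, value in data.items()
        }
    if isinstance(data, (list, tuple)):
        return type(data)(
            apply_recursively(item, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
            for item in data
        )

    # Anything else is returned unchanged.
    return data


batch = {"x": [1, 2], "y": (3.0, True), "name": "a"}

# Multiply every int by 10, but skip bools (which are also ints in Python).
scaled = apply_recursively(batch, int, lambda v: v * 10, wrong_dtype=bool)
print(scaled)  # {'x': [10, 20], 'y': (3.0, True), 'name': 'a'}
```

# Without ``wrong_dtype=bool`` the ``True`` entry would also be treated as an
# int and multiplied.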


def apply_to_collection(data: Any, dtype: Union[type, tuple], function: Callable, *args,
                        wrong_dtype: Optional[Union[type, tuple]] = None, **kwargs) -> Any:
    """
    Recursively applies a function to all elements of a certain dtype.

    Args:
        data: the collection to apply the function to
        dtype: the given function will be applied to all elements of this dtype
        function: the function to apply
        *args: positional arguments (will be forwarded to calls of ``function``)
        wrong_dtype: the given function won't be applied if this type is specified and the given collection is of
            the :attr:`wrong_dtype` even if it is of type :attr:`dtype`
        **kwargs: keyword arguments (will be forwarded to calls of ``function``)

    Returns:
        the resulting collection
    """
    elem_type = type(data)

    # Breaking condition
    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
        return function(data, *args, **kwargs)

    # Recursively apply to collection items
    if isinstance(data, Mapping):
        return elem_type({k: apply_to_collection(v, dtype, function, *args, **kwargs)
                          for k, v in data.items()})

    if isinstance(data, tuple) and hasattr(data, '_fields'):  # named tuple
        return elem_type(*(apply_to_collection(d, dtype, function, *args, **kwargs) for d in data))

    if isinstance(data, Sequence) and not isinstance(data, str):
        return elem_type([apply_to_collection(d, dtype, function, *args, **kwargs) for d in data])

    # data is neither of dtype, nor a collection
    return data
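The recursion above can be tried out without torch. The following is a minimal sketch, not the library function itself: `apply_recursively`, `Point`, and the sample data are hypothetical names chosen for illustration, and the body only mirrors the branches visible in this listing (mapping, named tuple, other sequence, fallback).

```python
from collections import namedtuple
from collections.abc import Mapping, Sequence


def apply_recursively(data, dtype, function, *args, **kwargs):
    """Simplified stand-in (hypothetical name) for the recursion shown above."""
    elem_type = type(data)

    # Leaf of the requested type: apply the function directly
    if isinstance(data, dtype):
        return function(data, *args, **kwargs)

    # Mappings are rebuilt with the same type and converted values
    if isinstance(data, Mapping):
        return elem_type({k: apply_recursively(v, dtype, function, *args, **kwargs)
                          for k, v in data.items()})

    # Named tuples need positional arguments rather than a single iterable
    if isinstance(data, tuple) and hasattr(data, "_fields"):
        return elem_type(*(apply_recursively(d, dtype, function, *args, **kwargs) for d in data))

    # Other sequences (lists, plain tuples), but never strings
    if isinstance(data, Sequence) and not isinstance(data, str):
        return elem_type([apply_recursively(d, dtype, function, *args, **kwargs) for d in data])

    # Neither of dtype nor a supported collection: return unchanged
    return data


Point = namedtuple("Point", ["x", "y"])
nested = {"a": [1, 2.5, "text"], "b": (3, Point(4, 5))}
doubled = apply_recursively(nested, int, lambda v: v * 2)
```

Only the `int` leaves are doubled; the float, the string, and the container types (including the named tuple) come back as they were.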


class TransferableDataType(ABC):
    """
    A custom type for data that can be moved to a torch device via `.to(...)`.

    Example:

        >>> isinstance(dict, TransferableDataType)
        False
        >>> isinstance(torch.rand(2, 3), TransferableDataType)
        True
        >>> class CustomObject:
        ...     def __init__(self):
        ...         self.x = torch.rand(2, 2)
        ...     def to(self, device):
        ...         self.x = self.x.to(device)
        ...         return self
        >>> isinstance(CustomObject(), TransferableDataType)
        True
    """

    @classmethod
    def __subclasshook__(cls, subclass):
        if cls is TransferableDataType:
            to = getattr(subclass, "to", None)
            return callable(to)
        return NotImplemented
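The `__subclasshook__` pattern used by `TransferableDataType` can be exercised without torch. This is a small sketch under stated assumptions: `HasTo`, `Movable`, and `NotMovable` are hypothetical classes introduced only for this illustration.

```python
from abc import ABC


class HasTo(ABC):
    """Hypothetical ABC: any class with a callable ``to`` attribute counts as a subclass."""

    @classmethod
    def __subclasshook__(cls, subclass):
        if cls is HasTo:
            return callable(getattr(subclass, "to", None))
        return NotImplemented


class Movable:
    def to(self, device):
        return self


class NotMovable:
    to = "not a method"  # present, but not callable


movable_ok = isinstance(Movable(), HasTo)
not_movable_ok = isinstance(NotMovable(), HasTo)
plain_dict_ok = isinstance({}, HasTo)
```

Because the hook inspects the class rather than the instance, a `to` attribute attached to a single object at runtime would not make that object pass the check.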


def move_data_to_device(batch: Any, device: torch.device):
    """
    Transfers a collection of data to the given device. Any object that defines a method
    ``to(device)`` will be moved and all other objects in the collection will be left untouched.
2021-01-26 01:21:00 +00:00
    Args:
        batch: A tensor or collection of tensors or anything that has a method `.to(...)`.
            See :func:`apply_to_collection` for a list of supported collection types.
        device: The device to which the data should be moved

    Return:
        the same collection but with all contained tensors residing on the new device.

    See Also:
        - :meth:`torch.Tensor.to`
        - :class:`torch.device`
    """
Bugfix/torchtext include lengths (#2689) * Test using torchtext.data.Field with include_lengths=True/False * Fix issue that Tensors in a Batch generated by torchtext with torchtext.data.Field configured as include_lengths=True * Add description for fix of issue #2688 * changes to accomodate CodeFactor issues * Another attemt to make last CodeFactor issue pass (it's a false alarm) * temporarly disable test of test_grad_tracking to check if testing will pass * reenable test in test_grad_norm * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Renamed get_torchtext_data_iterator to _get_torchtext_data_iterator as suggested by @borda * Update pytorch_lightning/utilities/apply_func.py Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * adding tests more specific to batch_move_data_to_device with tochtext Batch * added check that Tensors were moved to target device * removed tests using RNN models to be moved into a separate PR * fixing FLAKE8 errors that showed up after merge from master branch modified: tests/base/datamodules.py modified: tests/callbacks/test_model_checkpoint.py * parameterized test to reduce code duplication * Added check only if length tensor exist. Removed left over comments. * rearranged device parameterization and added pytest.param * Try to figure out why only one device is tested on Linux machines * Testing on CPU and GPU devices (GPU test is skip if no cuda device is available. 
* added test for TPU device (experimental) * Adding test parameterization for TPU test (experimental) * change import statement to limit what is imported for a TPU environment * made test work with TPU * Change to trigger CI * Change to trigger CI * uncommented TPU test to check CI * reenabling TPU test * small change to trigger CI build * small change to trigger CI build * small change to trigger CI build * adding tests/utilities/test_apply_func_torchtext.py to CI TPU test * try to make test not skipped on CI with TPU * remove testing on TPU * undo an accidental change to test_tpu.py (file should not have been touched) * small change to trigger CI build * small change to trigger CI build * Update tests/utilities/test_apply_func_torchtext.py * Revert to previous version * Apply suggestions from code review * Change to trigger CI Co-authored-by: Thomas Schaaf <tschaaf@mmm.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Thomas Schaaf <tschaaf@cs.cmu.edu>
2020-07-31 11:53:08 +00:00
    def batch_to(data):
        # try to move torchtext data first
        if _TORCHTEXT_AVAILABLE and isinstance(data, Batch):

            # Shallow copy because each Batch has a reference to Dataset which contains all examples
            device_data = copy(data)
            for field, field_value in data.dataset.fields.items():
                if field_value is None:
                    continue
                device_field = move_data_to_device(getattr(data, field), device)
                setattr(device_data, field, device_field)
            return device_data

        kwargs = dict(non_blocking=True) if isinstance(data, torch.Tensor) else {}
        return data.to(device, **kwargs)

    dtype = (TransferableDataType, Batch) if _TORCHTEXT_AVAILABLE else TransferableDataType
2020-09-11 14:55:58 +00:00
    return apply_to_collection(batch, dtype=dtype, function=batch_to)
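To see the moving behaviour end to end without a GPU or torch, here is a hedged, torch-free sketch. `Transferable`, `FakeTensor`, and `move_to` are hypothetical stand-ins for `TransferableDataType`, `torch.Tensor`, and `move_data_to_device`; the device is just a string, and the torchtext `Batch` branch and the `non_blocking` keyword are left out.

```python
from abc import ABC
from collections.abc import Mapping, Sequence


class Transferable(ABC):
    """Hypothetical duck-typed marker: anything whose class has a callable ``to``."""

    @classmethod
    def __subclasshook__(cls, subclass):
        if cls is Transferable:
            return callable(getattr(subclass, "to", None))
        return NotImplemented


class FakeTensor:
    """Torch-free stand-in that records which device it lives on."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        # Like a tensor move, return a new object instead of mutating in place
        return FakeTensor(device)


def move_to(batch, device):
    """Move every transferable leaf; leave everything else untouched."""
    if isinstance(batch, Transferable):
        return batch.to(device)
    if isinstance(batch, Mapping):
        return type(batch)({k: move_to(v, device) for k, v in batch.items()})
    if isinstance(batch, Sequence) and not isinstance(batch, str):
        return type(batch)([move_to(v, device) for v in batch])
    return batch


batch = {"inputs": [FakeTensor(), FakeTensor()], "label": 3, "name": "sample"}
moved = move_to(batch, "cuda:0")
```

The integer label and the string pass through unchanged, while both `FakeTensor` objects report the new device; the original `batch` is not modified.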


def convert_to_tensors(data, device: Optional[torch.device] = None):
    if device is None:
        raise MisconfigurationException(
            "device (torch.device) should be provided."
        )
    # Apply each (source dtype, conversion function) pair from CONVERSION_DTYPES in turn
    for src_dtype, conversion_func in CONVERSION_DTYPES:
        data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))
    return data
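The loop in `convert_to_tensors` walks a table of `(source type, conversion function)` pairs and binds the device with `functools.partial`. The sketch below illustrates that pattern only: `Converted`, `to_converted`, `CONVERSIONS`, and `convert_all` are hypothetical names, the data is a flat list rather than a nested collection, and a plain `ValueError` stands in for `MisconfigurationException`.

```python
from dataclasses import dataclass
from functools import partial


@dataclass(frozen=True)
class Converted:
    """Hypothetical result type standing in for a tensor on a device."""
    value: float
    device: str


def to_converted(value, device):
    return Converted(float(value), device)


# Hypothetical conversion table: (source type, conversion function) pairs
CONVERSIONS = [(int, to_converted), (float, to_converted)]


def convert_all(values, device=None, conversions=CONVERSIONS):
    if device is None:
        raise ValueError("device should be provided.")
    for src_type, func in conversions:
        # Bind the device once so the converter takes a single argument
        convert = partial(func, device=device)
        values = [convert(v) if isinstance(v, src_type) else v for v in values]
    return values


result = convert_all([1, 2.5, "x"], device="cpu")
```

Each pass only touches values of its own source type, so the outputs of an earlier pass must not match a later source type; here `Converted` is neither `int` nor `float`, which keeps the passes independent.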