Update accelerator.py (#7318)
This commit is contained in:
parent
badd0bba30
commit
39274273a4
|
@ -13,7 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import contextlib
|
import contextlib
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
|
from typing import Any, Callable, DefaultDict, Dict, Generator, Iterable, List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
@ -114,7 +114,7 @@ class Accelerator:
|
||||||
def _move_optimizer_state(self) -> None:
|
def _move_optimizer_state(self) -> None:
|
||||||
""" Moves the state of the optimizers to the GPU if needed. """
|
""" Moves the state of the optimizers to the GPU if needed. """
|
||||||
for opt in self.optimizers:
|
for opt in self.optimizers:
|
||||||
state = defaultdict(dict)
|
state: DefaultDict = defaultdict(dict)
|
||||||
for p, v in opt.state.items():
|
for p, v in opt.state.items():
|
||||||
state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
|
state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, self.root_device)
|
||||||
opt.state = state
|
opt.state = state
|
||||||
|
|
Loading…
Reference in New Issue