2021-08-26 08:36:21 +00:00
|
|
|
# Copyright The PyTorch Lightning team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
from abc import ABC, abstractmethod
|
2021-09-14 13:48:27 +00:00
|
|
|
from dataclasses import dataclass
|
2021-09-15 12:18:19 +00:00
|
|
|
from typing import Any, Dict, Generic, Optional, TypeVar
|
2021-08-26 08:36:21 +00:00
|
|
|
|
2021-09-15 12:18:19 +00:00
|
|
|
from torch import Tensor
|
|
|
|
|
|
|
|
from pytorch_lightning.utilities import rank_zero_deprecation
|
|
|
|
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
2021-09-08 10:24:57 +00:00
|
|
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
2021-09-10 11:40:20 +00:00
|
|
|
|
2021-09-14 13:48:27 +00:00
|
|
|
# Type variable for the value produced by a closure's ``closure`` method and cached by ``AbstractClosure``.
T = TypeVar("T")
|
2021-09-10 11:40:20 +00:00
|
|
|
|
|
|
|
|
2021-09-14 13:48:27 +00:00
|
|
|
@dataclass
class OutputResult:
    """Base container for the outputs produced by an optimization step.

    Declares no fields of its own. ``asdict`` raises ``NotImplementedError`` here, so concrete results are
    presumably expected to override it.
    """

    @staticmethod
    def _check_extra_detach_deprecation(extra: Dict[str, Any]) -> Dict[str, Any]:
        """Detach tensors in ``extra`` that still have a ``grad_fn``, warning about the deprecated behaviour.

        Args:
            extra: the extra values returned alongside the loss.

        Returns:
            The result of ``apply_to_collection`` over ``extra``. Tensors with a ``grad_fn`` are replaced by their
            ``.detach()``-ed version and every other tensor is returned unchanged.
        """
        # TODO: remove with the deprecation removal in v1.6
        # this is only here to avoid duplication

        def _detach_with_warning(tensor: Tensor) -> Tensor:
            # Tensors that are not part of an autograd graph need no handling.
            if tensor.grad_fn is None:
                return tensor
            rank_zero_deprecation(
                f"One of the returned values {set(extra.keys())} has a `grad_fn`. We will detach it automatically"
                " but this behaviour will change in v1.6. Please detach it manually:"
                " `return {'loss': ..., 'something': something.detach()}`"
            )
            return tensor.detach()

        # presumably applies the helper to every Tensor nested inside ``extra`` -- see ``apply_to_collection``
        return apply_to_collection(extra, Tensor, _detach_with_warning)

    def asdict(self) -> Dict[str, Any]:
        """Return this result as a dictionary.

        Raises:
            NotImplementedError: always, at this level.
        """
        raise NotImplementedError
|
2021-08-26 08:36:21 +00:00
|
|
|
|
|
|
|
|
2021-09-14 13:48:27 +00:00
|
|
|
class AbstractClosure(ABC, Generic[T]):
    """Abstract base class for optimizer closures in Lightning.

    A closure captures variables from an enclosing scope and runs a computation on them without receiving them as
    explicit arguments. It can therefore be handed to another object, which later invokes it like a plain function
    without supplying any inputs.

    Instances of this class are callable. Each call runs ``closure`` and caches its output until it is retrieved
    through ``consume_result``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Output of the most recent ``closure`` call; ``None`` means "not executed yet" or "already consumed".
        self._result: Optional[T] = None

    def consume_result(self) -> T:
        """Return the output cached by the most recent call and clear the cache.

        After this call the closure no longer holds a reference to the output. The caller must keep its own
        reference for as long as it needs the value.

        Raises:
            MisconfigurationException: if nothing is cached, i.e. the closure was not called since the last consume.
        """
        cached = self._result
        if cached is None:
            raise MisconfigurationException(
                "The closure hasn't been executed."
                " HINT: did you call `optimizer_closure()` in your `optimizer_step` hook? It could also happen because"
                " the `optimizer.step(optimizer_closure)` call did not execute it internally."
            )
        # drop our own reference so the memory can be released once the caller is done with it
        self._result = None
        return cached

    @abstractmethod
    def closure(self, *args: Any, **kwargs: Any) -> T:
        """Implements the behavior of the closure once it is getting called."""
        ...

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Run ``closure`` with the given arguments, cache its output, and return this closure object.

        The output itself is not returned here; retrieve it with ``consume_result``.
        """
        output = self.closure(*args, **kwargs)
        self._result = output
        return self
|