Fix `accumulated_grad_batches` typehint (#9071)
* Fix `accumulated_grad_batches` typehint
This commit is contained in:
parent
1a2468f530
commit
376734a1e2
|
@ -11,7 +11,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Dict, List, Union
|
from typing import Dict, Union
|
||||||
|
|
||||||
from pytorch_lightning.callbacks import GradientAccumulationScheduler
|
from pytorch_lightning.callbacks import GradientAccumulationScheduler
|
||||||
from pytorch_lightning.utilities import GradClipAlgorithmType
|
from pytorch_lightning.utilities import GradClipAlgorithmType
|
||||||
|
@ -27,7 +27,7 @@ class TrainingTricksConnector:
|
||||||
gradient_clip_val: float,
|
gradient_clip_val: float,
|
||||||
gradient_clip_algorithm: str,
|
gradient_clip_algorithm: str,
|
||||||
track_grad_norm: Union[int, float, str],
|
track_grad_norm: Union[int, float, str],
|
||||||
accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
|
accumulate_grad_batches: Union[int, Dict[int, int]],
|
||||||
terminate_on_nan: bool,
|
terminate_on_nan: bool,
|
||||||
):
|
):
|
||||||
|
|
||||||
|
@ -48,7 +48,7 @@ class TrainingTricksConnector:
|
||||||
self.trainer.accumulate_grad_batches = accumulate_grad_batches
|
self.trainer.accumulate_grad_batches = accumulate_grad_batches
|
||||||
self.configure_accumulated_gradients(accumulate_grad_batches)
|
self.configure_accumulated_gradients(accumulate_grad_batches)
|
||||||
|
|
||||||
def configure_accumulated_gradients(self, accumulate_grad_batches):
|
def configure_accumulated_gradients(self, accumulate_grad_batches: Union[int, Dict[int, int]]) -> None:
|
||||||
if isinstance(accumulate_grad_batches, dict):
|
if isinstance(accumulate_grad_batches, dict):
|
||||||
self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
|
self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
|
||||||
elif isinstance(accumulate_grad_batches, int):
|
elif isinstance(accumulate_grad_batches, int):
|
||||||
|
|
|
@ -123,7 +123,7 @@ class Trainer(
|
||||||
track_grad_norm: Union[int, float, str] = -1,
|
track_grad_norm: Union[int, float, str] = -1,
|
||||||
check_val_every_n_epoch: int = 1,
|
check_val_every_n_epoch: int = 1,
|
||||||
fast_dev_run: Union[int, bool] = False,
|
fast_dev_run: Union[int, bool] = False,
|
||||||
accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
|
accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
|
||||||
max_epochs: Optional[int] = None,
|
max_epochs: Optional[int] = None,
|
||||||
min_epochs: Optional[int] = None,
|
min_epochs: Optional[int] = None,
|
||||||
max_steps: Optional[int] = None,
|
max_steps: Optional[int] = None,
|
||||||
|
|
Loading…
Reference in New Issue