Fix `accumulate_grad_batches` typehint (#9071)

* Fix `accumulate_grad_batches` typehint
Author: ananthsub
Date: 2021-08-24 10:12:36 -07:00 (committed by GitHub)
Parent: 1a2468f530
Commit: 376734a1e2
2 changed files with 4 additions and 4 deletions

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Union
+from typing import Dict, Union
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.utilities import GradClipAlgorithmType
@@ -27,7 +27,7 @@ class TrainingTricksConnector:
         gradient_clip_val: float,
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
-        accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
+        accumulate_grad_batches: Union[int, Dict[int, int]],
         terminate_on_nan: bool,
     ):
@@ -48,7 +48,7 @@ class TrainingTricksConnector:
         self.trainer.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
-    def configure_accumulated_gradients(self, accumulate_grad_batches):
+    def configure_accumulated_gradients(self, accumulate_grad_batches: Union[int, Dict[int, int]]) -> None:
         if isinstance(accumulate_grad_batches, dict):
             self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
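
As the isinstance checks above show, configure_accumulated_gradients only ever handled int and Dict[int, int]; the List[list] member of the old Union was dead, which is why the hint drops it. The dict form maps an epoch index to the accumulation factor used from that epoch onward, and is handed straight to the GradientAccumulationScheduler callback. A minimal sketch of the equivalent explicit callback (the schedule values here are illustrative only):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import GradientAccumulationScheduler

    # accumulate 8 batches per optimizer step from epoch 0, 4 from epoch 4,
    # and return to no accumulation (factor 1) from epoch 8
    scheduler = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
    trainer = Trainer(callbacks=[scheduler])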

@@ -123,7 +123,7 @@ class Trainer(
         track_grad_norm: Union[int, float, str] = -1,
         check_val_every_n_epoch: int = 1,
         fast_dev_run: Union[int, bool] = False,
-        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
+        accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
         max_epochs: Optional[int] = None,
         min_epochs: Optional[int] = None,
         max_steps: Optional[int] = None,
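
With the corrected hint, the two supported forms of the Trainer argument look like this (a sketch; model and data setup omitted):

    from pytorch_lightning import Trainer

    # int: accumulate the same number of batches in every epoch
    trainer = Trainer(accumulate_grad_batches=4)

    # Dict[int, int]: switch the accumulation factor at the given epochs
    trainer = Trainer(accumulate_grad_batches={0: 8, 4: 2})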