Fix `accumulated_grad_batches` typehint (#9071)
* Fix `accumulated_grad_batches` typehint
parent 1a2468f530
commit 376734a1e2
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, List, Union
+from typing import Dict, Union
 
 from pytorch_lightning.callbacks import GradientAccumulationScheduler
 from pytorch_lightning.utilities import GradClipAlgorithmType
@@ -27,7 +27,7 @@ class TrainingTricksConnector:
         gradient_clip_val: float,
         gradient_clip_algorithm: str,
         track_grad_norm: Union[int, float, str],
-        accumulate_grad_batches: Union[int, Dict[int, int], List[list]],
+        accumulate_grad_batches: Union[int, Dict[int, int]],
         terminate_on_nan: bool,
     ):
 
@@ -48,7 +48,7 @@
         self.trainer.accumulate_grad_batches = accumulate_grad_batches
         self.configure_accumulated_gradients(accumulate_grad_batches)
 
-    def configure_accumulated_gradients(self, accumulate_grad_batches):
+    def configure_accumulated_gradients(self, accumulate_grad_batches: Union[int, Dict[int, int]]) -> None:
         if isinstance(accumulate_grad_batches, dict):
             self.trainer.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches)
         elif isinstance(accumulate_grad_batches, int):
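Note: the tightened hint matches what this method actually handles -- it only branches on dict and int, so the old `List[list]` alternative was dead. A minimal sketch of the surviving dict form, assuming the callback's `scheduling` keyword (values illustrative, not from this commit):

# Sketch (not part of this commit): the Dict[int, int] form --
# keys are epoch indices, values are accumulation factors.
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# accumulate 8 batches from epoch 0, 4 from epoch 4, none from epoch 8
accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})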
@@ -123,7 +123,7 @@ class Trainer(
         track_grad_norm: Union[int, float, str] = -1,
         check_val_every_n_epoch: int = 1,
         fast_dev_run: Union[int, bool] = False,
-        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
+        accumulate_grad_batches: Union[int, Dict[int, int]] = 1,
         max_epochs: Optional[int] = None,
         min_epochs: Optional[int] = None,
         max_steps: Optional[int] = None,
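With the `Trainer` signature updated to advertise only the two handled forms, a minimal usage sketch (values illustrative, not from this commit):

# Sketch (not part of this commit): the two forms the new hint permits.
from pytorch_lightning import Trainer

# int form: accumulate gradients over 4 batches for the whole run
trainer = Trainer(accumulate_grad_batches=4)

# dict form: switch the accumulation factor at the given epochs
trainer = Trainer(accumulate_grad_batches={0: 8, 5: 2})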