default sched (#6062)
This commit is contained in:
parent
8f82823a08
commit
5d6a091531
|
@ -26,6 +26,7 @@ from pytorch_lightning.core.lightning import LightningModule
|
|||
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
|
||||
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
|
||||
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
|
||||
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
|
||||
from pytorch_lightning.utilities import AMPType
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
from pytorch_lightning.utilities.distributed import rank_zero_info, rank_zero_only
|
||||
|
@ -240,16 +241,8 @@ class DeepSpeedPlugin(DDPPlugin):
|
|||
)
|
||||
|
||||
def configure_scheduler(self, lr_scheduler):
|
||||
# this duplicates the defaults from init_optimizers
|
||||
scheduler = {
|
||||
'scheduler': lr_scheduler,
|
||||
'name': None, # no custom name
|
||||
'interval': 'epoch', # after epoch is over
|
||||
'frequency': 1, # every epoch/batch
|
||||
'reduce_on_plateau': False, # most often not ReduceLROnPlateau scheduler
|
||||
'monitor': None, # value to monitor for ReduceLROnPlateau
|
||||
'strict': True, # enforce that the monitor exists for ReduceLROnPlateau
|
||||
}
|
||||
scheduler = _get_default_scheduler_config()
|
||||
scheduler["scheduler"] = lr_scheduler
|
||||
return [scheduler]
|
||||
|
||||
@property
|
||||
|
|
Loading…
Reference in New Issue