removed opt check
This commit is contained in:
parent
79c0054c38
commit
74d714f159
|
@ -1,22 +0,0 @@
|
|||
from torch import nn
|
||||
from torch import optim
|
||||
|
||||
|
||||
class OptimizerConfig(nn.Module):
|
||||
|
||||
def choose_optimizer(self, optimizer, params, optimizer_params, opt_name_key):
|
||||
if optimizer == 'adam':
|
||||
optimizer = optim.Adam(params, **optimizer_params)
|
||||
if optimizer == 'sparse_adam':
|
||||
optimizer = optim.SparseAdam(params, **optimizer_params)
|
||||
if optimizer == 'sgd':
|
||||
optimizer = optim.SGD(params, **optimizer_params)
|
||||
if optimizer == 'adadelta':
|
||||
optimizer = optim.Adadelta(params, **optimizer_params)
|
||||
|
||||
# transfer opt state if loaded
|
||||
if opt_name_key in self.loaded_optimizer_states_dict:
|
||||
state = self.loaded_optimizer_states_dict[opt_name_key]
|
||||
optimizer.load_state_dict(state)
|
||||
|
||||
return optimizer
|
|
@ -5,11 +5,10 @@ import math
|
|||
from pytorch_lightning.root_module.memory import ModelSummary
|
||||
from pytorch_lightning.root_module.grads import GradInformation
|
||||
from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv
|
||||
from pytorch_lightning.root_module.optimization import OptimizerConfig
|
||||
from pytorch_lightning.root_module.hooks import ModelHooks
|
||||
|
||||
|
||||
class LightningModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
|
||||
class LightningModule(GradInformation, ModelIO, ModelHooks):
|
||||
|
||||
def __init__(self, hparams):
|
||||
super(LightningModule, self).__init__()
|
||||
|
|
Loading…
Reference in New Issue