removed opt check

William Falcon 2019-07-24 15:33:08 -04:00
parent 79c0054c38
commit 74d714f159
2 changed files with 1 addition and 24 deletions

pytorch_lightning/root_module/optimization.py (deleted)
@@ -1,22 +0,0 @@
-from torch import nn
-from torch import optim
-
-
-class OptimizerConfig(nn.Module):
-
-    def choose_optimizer(self, optimizer, params, optimizer_params, opt_name_key):
-        if optimizer == 'adam':
-            optimizer = optim.Adam(params, **optimizer_params)
-        if optimizer == 'sparse_adam':
-            optimizer = optim.SparseAdam(params, **optimizer_params)
-        if optimizer == 'sgd':
-            optimizer = optim.SGD(params, **optimizer_params)
-        if optimizer == 'adadelta':
-            optimizer = optim.Adadelta(params, **optimizer_params)
-
-        # transfer opt state if loaded
-        if opt_name_key in self.loaded_optimizer_states_dict:
-            state = self.loaded_optimizer_states_dict[opt_name_key]
-            optimizer.load_state_dict(state)
-
-        return optimizer

pytorch_lightning/root_module/root_module.py
@@ -5,11 +5,10 @@ import math
 from pytorch_lightning.root_module.memory import ModelSummary
 from pytorch_lightning.root_module.grads import GradInformation
 from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv
-from pytorch_lightning.root_module.optimization import OptimizerConfig
 from pytorch_lightning.root_module.hooks import ModelHooks
 
 
-class LightningModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks):
+class LightningModule(GradInformation, ModelIO, ModelHooks):
 
     def __init__(self, hparams):
         super(LightningModule, self).__init__()
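
With the OptimizerConfig mixin gone, nothing routes optimizer construction through choose_optimizer anymore; a LightningModule builds its optimizer directly in configure_optimizers. A minimal sketch of the resulting pattern, assuming the API at this commit (the CoolModel class and the optimizer/learning_rate hyperparameter names are illustrative, not taken from this commit):

from torch import nn, optim
from pytorch_lightning.root_module.root_module import LightningModule


class CoolModel(LightningModule):
    # Hypothetical module; only the optimizer selection matters here.

    def __init__(self, hparams):
        super(CoolModel, self).__init__()
        self.hparams = hparams
        self.l1 = nn.Linear(28 * 28, 10)

    def forward(self, x):
        return self.l1(x.view(x.size(0), -1))

    def configure_optimizers(self):
        # Select the optimizer inline instead of via the removed
        # OptimizerConfig.choose_optimizer helper; Lightning of this
        # era expects a list of optimizers back.
        if self.hparams.optimizer == 'sgd':
            return [optim.SGD(self.parameters(), lr=self.hparams.learning_rate)]
        return [optim.Adam(self.parameters(), lr=self.hparams.learning_rate)]

The removed helper also transferred previously loaded optimizer state (loaded_optimizer_states_dict) into the new optimizer; that logic disappears with the mixin.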