From 74d714f1597240df6019990984ba6ef6f5a1f80a Mon Sep 17 00:00:00 2001 From: William Falcon Date: Wed, 24 Jul 2019 15:33:08 -0400 Subject: [PATCH] removed opt check --- pytorch_lightning/root_module/optimization.py | 22 ------------------- pytorch_lightning/root_module/root_module.py | 3 +-- 2 files changed, 1 insertion(+), 24 deletions(-) delete mode 100644 pytorch_lightning/root_module/optimization.py diff --git a/pytorch_lightning/root_module/optimization.py b/pytorch_lightning/root_module/optimization.py deleted file mode 100644 index 3172e1a1e6..0000000000 --- a/pytorch_lightning/root_module/optimization.py +++ /dev/null @@ -1,22 +0,0 @@ -from torch import nn -from torch import optim - - -class OptimizerConfig(nn.Module): - - def choose_optimizer(self, optimizer, params, optimizer_params, opt_name_key): - if optimizer == 'adam': - optimizer = optim.Adam(params, **optimizer_params) - if optimizer == 'sparse_adam': - optimizer = optim.SparseAdam(params, **optimizer_params) - if optimizer == 'sgd': - optimizer = optim.SGD(params, **optimizer_params) - if optimizer == 'adadelta': - optimizer = optim.Adadelta(params, **optimizer_params) - - # transfer opt state if loaded - if opt_name_key in self.loaded_optimizer_states_dict: - state = self.loaded_optimizer_states_dict[opt_name_key] - optimizer.load_state_dict(state) - - return optimizer diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index e2013d7a8e..2345997b00 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -5,11 +5,10 @@ import math from pytorch_lightning.root_module.memory import ModelSummary from pytorch_lightning.root_module.grads import GradInformation from pytorch_lightning.root_module.model_saving import ModelIO, load_hparams_from_tags_csv -from pytorch_lightning.root_module.optimization import OptimizerConfig from pytorch_lightning.root_module.hooks import ModelHooks -class LightningModule(GradInformation, ModelIO, OptimizerConfig, ModelHooks): +class LightningModule(GradInformation, ModelIO, ModelHooks): def __init__(self, hparams): super(LightningModule, self).__init__()