diff --git a/README.md b/README.md index 775f6bc7bf..2245eba3aa 100644 --- a/README.md +++ b/README.md @@ -81,36 +81,40 @@ class CoolModel(pl.LightningModule): def forward(self, x): return torch.relu(self.l1(x.view(x.size(0), -1))) - def my_loss(self, y_hat, y): - return F.cross_entropy(y_hat, y) - def training_step(self, batch, batch_nb): + # REQUIRED x, y = batch y_hat = self.forward(x) - return {'loss': self.my_loss(y_hat, y)} + return {'loss': F.cross_entropy(y_hat, y)} def validation_step(self, batch, batch_nb): + # OPTIONAL x, y = batch y_hat = self.forward(x) return {'val_loss': self.my_loss(y_hat, y)} def validation_end(self, outputs): + # OPTIONAL avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() return {'avg_val_loss': avg_loss} def configure_optimizers(self): + # REQUIRED return [torch.optim.Adam(self.parameters(), lr=0.02)] @pl.data_loader def tng_dataloader(self): + # REQUIRED return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @pl.data_loader def val_dataloader(self): + # OPTIONAL return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) @pl.data_loader def test_dataloader(self): + # OPTIONAL return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) ``` diff --git a/docs/LightningModule/RequiredTrainerInterface.md b/docs/LightningModule/RequiredTrainerInterface.md index d0b5b6b3d9..677815b73a 100644 --- a/docs/LightningModule/RequiredTrainerInterface.md +++ b/docs/LightningModule/RequiredTrainerInterface.md @@ -10,16 +10,14 @@ Otherwise, to Define a Lightning Module, implement the following methods: **Required**: - [training_step](RequiredTrainerInterface.md#training_step) -- [validation_step](RequiredTrainerInterface.md#validation_step) -- [validation_end](RequiredTrainerInterface.md#validation_end) - +- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader) - [configure_optimizers](RequiredTrainerInterface.md#configure_optimizers) -- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader) -- [tng_dataloader](RequiredTrainerInterface.md#tng_dataloader) -- [test_dataloader](RequiredTrainerInterface.md#test_dataloader) - **Optional**: +- [validation_step](RequiredTrainerInterface.md#validation_step) +- [validation_end](RequiredTrainerInterface.md#validation_end) +- [val_dataloader](RequiredTrainerInterface.md#val_dataloader) +- [test_dataloader](RequiredTrainerInterface.md#test_dataloader) - [on_save_checkpoint](RequiredTrainerInterface.md#on_save_checkpoint) - [on_load_checkpoint](RequiredTrainerInterface.md#on_load_checkpoint) diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py index de8a0a03e1..3490131d43 100644 --- a/pytorch_lightning/models/trainer.py +++ b/pytorch_lightning/models/trainer.py @@ -13,6 +13,7 @@ from torch.utils.data.distributed import DistributedSampler import torch.multiprocessing as mp import torch.distributed as dist +from pytorch_lightning.root_module.root_module import LightningModule from pytorch_lightning.root_module.memory import get_gpu_memory_map from pytorch_lightning.root_module.model_saving import TrainerIO from pytorch_lightning.pt_overrides.override_data_parallel import ( @@ -312,6 +313,14 @@ class Trainer(TrainerIO): f_op = getattr(model, f_name, None) return callable(f_op) + def __is_overriden(self, f_name): + model = self.__get_model() + super_object = super(model.__class__, model) + + # when code pointers are different, it was overriden + is_overriden = getattr(model, f_name).__code__ is not getattr(super_object, f_name).__code__ + return is_overriden + @property def __tng_tqdm_dic(self): tqdm_dic = { @@ -345,13 +354,13 @@ class Trainer(TrainerIO): self.nb_tng_batches = int(self.nb_tng_batches * self.train_percent_check) # determine number of validation batches - self.nb_val_batches = len(self.val_dataloader) + self.nb_val_batches = len(self.val_dataloader) if self.val_dataloader is not None else 0 self.nb_val_batches = int(self.nb_val_batches * self.val_percent_check) self.nb_val_batches = max(1, self.nb_val_batches) self.nb_val_batches = self.nb_val_batches # determine number of test batches - self.nb_test_batches = len(self.test_dataloader) + self.nb_test_batches = len(self.test_dataloader) if self.test_dataloader is not None else 0 self.nb_test_batches = int(self.nb_test_batches * self.test_percent_check) # determine when to check validation @@ -372,6 +381,10 @@ class Trainer(TrainerIO): :param max_batches: Scalar :return: """ + # skip validation if model has no validation_step defined + if not self.__is_overriden('validation_step'): + return {} + # enable eval mode model.zero_grad() model.eval() @@ -418,11 +431,13 @@ class Trainer(TrainerIO): if self.progress_bar and self.prog_bar is not None: self.prog_bar.update(1) - # give model a chance to do something with the outputs - if self.data_parallel: - val_results = model.module.validation_end(outputs) - else: - val_results = model.validation_end(outputs) + # give model a chance to do something with the outputs (and method defined) + val_results = {} + if self.__is_overriden('validation_end'): + if self.data_parallel: + val_results = model.module.validation_end(outputs) + else: + val_results = model.validation_end(outputs) # enable train mode again model.train() @@ -439,6 +454,7 @@ class Trainer(TrainerIO): :return: """ self.tng_dataloader = model.tng_dataloader + self.test_dataloader = model.test_dataloader self.val_dataloader = model.val_dataloader diff --git a/pytorch_lightning/root_module/root_module.py b/pytorch_lightning/root_module/root_module.py index 700e6db14e..13cf46818f 100644 --- a/pytorch_lightning/root_module/root_module.py +++ b/pytorch_lightning/root_module/root_module.py @@ -36,18 +36,20 @@ class LightningModule(GradInformation, ModelIO, ModelHooks): def validation_step(self, data_batch, batch_nb): """ return whatever outputs will need to be aggregated in validation_end + OPTIONAL :param data_batch: :return: """ - raise NotImplementedError + pass def validation_end(self, outputs): """ Outputs has the appended output after each validation step + OPTIONAL :param outputs: :return: dic_with_metrics for tqdm """ - raise NotImplementedError + pass def training_step(self, data_batch, batch_nb): """ @@ -67,7 +69,7 @@ class LightningModule(GradInformation, ModelIO, ModelHooks): @data_loader def tng_dataloader(self): """ - Implement a function to load an h5py of this data + Implement a PyTorch DataLoader :return: """ raise NotImplementedError @@ -75,18 +77,18 @@ class LightningModule(GradInformation, ModelIO, ModelHooks): @data_loader def test_dataloader(self): """ - Implement a function to load an h5py of this data + Implement a PyTorch DataLoader :return: """ - raise NotImplementedError + return None @data_loader def val_dataloader(self): """ - Implement a function to load an h5py of this data + Implement a PyTorch DataLoader :return: """ - raise NotImplementedError + return None @classmethod def load_from_metrics(cls, weights_path, tags_csv, on_gpu, map_location=None): diff --git a/pytorch_lightning/testing/__init__.py b/pytorch_lightning/testing/__init__.py index e69de29bb2..b3289a1c71 100644 --- a/pytorch_lightning/testing/__init__.py +++ b/pytorch_lightning/testing/__init__.py @@ -0,0 +1,3 @@ +from .lm_test_module import LightningTestModel +from .no_val_end_module import NoValEndTestModel +from .no_val_module import NoValModel diff --git a/pytorch_lightning/testing/no_val_end_module.py b/pytorch_lightning/testing/no_val_end_module.py new file mode 100644 index 0000000000..3b42ab0256 --- /dev/null +++ b/pytorch_lightning/testing/no_val_end_module.py @@ -0,0 +1,247 @@ +import os +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import optim +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import MNIST +from torchvision import transforms +from test_tube import HyperOptArgumentParser + +from pytorch_lightning.root_module.root_module import LightningModule +from pytorch_lightning import data_loader + + +class NoValEndTestModel(LightningModule): + """ + Sample model to show how to define a template + """ + + def __init__(self, hparams, force_remove_distributed_sampler=False): + """ + Pass in parsed HyperOptArgumentParser to the model + :param hparams: + """ + # init superclass + super(NoValEndTestModel, self).__init__() + self.hparams = hparams + + self.batch_size = hparams.batch_size + + # if you specify an example input, the summary will show input/output for each layer + self.example_input_array = torch.rand(5, 28 * 28) + + # remove to test warning for dist sampler + self.force_remove_distributed_sampler = force_remove_distributed_sampler + + # build model + self.__build_model() + + # --------------------- + # MODEL SETUP + # --------------------- + def __build_model(self): + """ + Layout model + :return: + """ + self.c_d1 = nn.Linear(in_features=self.hparams.in_features, + out_features=self.hparams.hidden_dim) + self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim) + self.c_d1_drop = nn.Dropout(self.hparams.drop_prob) + + self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim, + out_features=self.hparams.out_features) + + # --------------------- + # TRAINING + # --------------------- + def forward(self, x): + """ + No special modification required for lightning, define as you normally would + :param x: + :return: + """ + + x = self.c_d1(x) + x = torch.tanh(x) + x = self.c_d1_bn(x) + x = self.c_d1_drop(x) + + x = self.c_d2(x) + logits = F.log_softmax(x, dim=1) + + return logits + + def loss(self, labels, logits): + nll = F.nll_loss(logits, labels) + return nll + + def training_step(self, data_batch, batch_i): + """ + Lightning calls this inside the training loop + :param data_batch: + :return: + """ + # forward pass + x, y = data_batch + x = x.view(x.size(0), -1) + + y_hat = self.forward(x) + + # calculate loss + loss_val = self.loss(y, y_hat) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_val = loss_val.unsqueeze(0) + + # alternate possible outputs to test + if self.trainer.batch_nb % 1 == 0: + output = OrderedDict({ + 'loss': loss_val, + 'prog': {'some_val': loss_val * loss_val} + }) + return output + if self.trainer.batch_nb % 2 == 0: + return loss_val + + def validation_step(self, data_batch, batch_i): + """ + Lightning calls this inside the validation loop + :param data_batch: + :return: + """ + x, y = data_batch + x = x.view(x.size(0), -1) + y_hat = self.forward(x) + + loss_val = self.loss(y, y_hat) + + # acc + labels_hat = torch.argmax(y_hat, dim=1) + val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + val_acc = torch.tensor(val_acc) + + if self.on_gpu: + val_acc = val_acc.cuda(loss_val.device.index) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_val = loss_val.unsqueeze(0) + val_acc = val_acc.unsqueeze(0) + + # alternate possible outputs to test + if batch_i % 1 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + }) + return output + if batch_i % 2 == 0: + return val_acc + + if batch_i % 3 == 0: + output = OrderedDict({ + 'val_loss': loss_val, + 'val_acc': val_acc, + 'test_dic': {'val_loss_a': loss_val} + }) + return output + + def on_tng_metrics(self, logs): + logs['some_tensor_to_test'] = torch.rand(1) + + # --------------------- + # TRAINING SETUP + # --------------------- + def configure_optimizers(self): + """ + return whatever optimizers we want here + :return: list of optimizers + """ + # try no scheduler for this model (testing purposes) + optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + # test returning only 1 list instead of 2 + return [optimizer] + + def __dataloader(self, train): + # init data generators + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root=self.hparams.data_root, train=train, + transform=transform, download=True) + + # when using multi-node we need to add the datasampler + train_sampler = None + batch_size = self.hparams.batch_size + + try: + if self.on_gpu and not self.force_remove_distributed_sampler: + train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank) + batch_size = batch_size // self.trainer.world_size # scale batch size + except Exception: + pass + + should_shuffle = train_sampler is None + loader = DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=should_shuffle, + sampler=train_sampler + ) + + return loader + + @data_loader + def tng_dataloader(self): + return self.__dataloader(train=True) + + @data_loader + def val_dataloader(self): + return self.__dataloader(train=False) + + @data_loader + def test_dataloader(self): + return self.__dataloader(train=False) + + @staticmethod + def add_model_specific_args(parent_parser, root_dir): # pragma: no cover + """ + Parameters you define here will be available to your model through self.hparams + :param parent_parser: + :param root_dir: + :return: + """ + parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser]) + + # param overwrites + # parser.set_defaults(gradient_clip=5.0) + + # network params + parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False) + parser.add_argument('--in_features', default=28 * 28, type=int) + parser.add_argument('--out_features', default=10, type=int) + # use 500 for CPU, 50000 for GPU to see speed difference + parser.add_argument('--hidden_dim', default=50000, type=int) + + # data + parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str) + + # training params (opt) + parser.opt_list('--learning_rate', default=0.001 * 8, type=float, + options=[0.0001, 0.0005, 0.001, 0.005], + tunable=False) + parser.opt_list('--optimizer_name', default='adam', type=str, + options=['adam'], tunable=False) + + # if using 2 nodes with 4 gpus each the batch size here + # (256) will be 256 / (2*8) = 16 per gpu + parser.opt_list('--batch_size', default=256 * 8, type=int, + options=[32, 64, 128, 256], tunable=False, + help='batch size will be divided over all gpus being used across all nodes') + return parser diff --git a/pytorch_lightning/testing/no_val_module.py b/pytorch_lightning/testing/no_val_module.py new file mode 100644 index 0000000000..029bc44769 --- /dev/null +++ b/pytorch_lightning/testing/no_val_module.py @@ -0,0 +1,196 @@ +import os +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import optim +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torchvision.datasets import MNIST +from torchvision import transforms +from test_tube import HyperOptArgumentParser + +from pytorch_lightning.root_module.root_module import LightningModule +from pytorch_lightning import data_loader + + +class NoValModel(LightningModule): + """ + Sample model to show how to define a template + """ + + def __init__(self, hparams, force_remove_distributed_sampler=False): + """ + Pass in parsed HyperOptArgumentParser to the model + :param hparams: + """ + # init superclass + super(NoValModel, self).__init__() + self.hparams = hparams + + self.batch_size = hparams.batch_size + + # if you specify an example input, the summary will show input/output for each layer + self.example_input_array = torch.rand(5, 28 * 28) + + # remove to test warning for dist sampler + self.force_remove_distributed_sampler = force_remove_distributed_sampler + + # build model + self.__build_model() + + # --------------------- + # MODEL SETUP + # --------------------- + def __build_model(self): + """ + Layout model + :return: + """ + self.c_d1 = nn.Linear(in_features=self.hparams.in_features, + out_features=self.hparams.hidden_dim) + self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim) + self.c_d1_drop = nn.Dropout(self.hparams.drop_prob) + + self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim, + out_features=self.hparams.out_features) + + # --------------------- + # TRAINING + # --------------------- + def forward(self, x): + """ + No special modification required for lightning, define as you normally would + :param x: + :return: + """ + + x = self.c_d1(x) + x = torch.tanh(x) + x = self.c_d1_bn(x) + x = self.c_d1_drop(x) + + x = self.c_d2(x) + logits = F.log_softmax(x, dim=1) + + return logits + + def loss(self, labels, logits): + nll = F.nll_loss(logits, labels) + return nll + + def training_step(self, data_batch, batch_i): + """ + Lightning calls this inside the training loop + :param data_batch: + :return: + """ + # forward pass + x, y = data_batch + x = x.view(x.size(0), -1) + + y_hat = self.forward(x) + + # calculate loss + loss_val = self.loss(y, y_hat) + + # in DP mode (default) make sure if result is scalar, there's another dim in the beginning + if self.trainer.use_dp: + loss_val = loss_val.unsqueeze(0) + + # alternate possible outputs to test + if self.trainer.batch_nb % 1 == 0: + output = OrderedDict({ + 'loss': loss_val, + 'prog': {'some_val': loss_val * loss_val} + }) + return output + if self.trainer.batch_nb % 2 == 0: + return loss_val + + def on_tng_metrics(self, logs): + logs['some_tensor_to_test'] = torch.rand(1) + + # --------------------- + # TRAINING SETUP + # --------------------- + def configure_optimizers(self): + """ + return whatever optimizers we want here + :return: list of optimizers + """ + # try no scheduler for this model (testing purposes) + optimizer = optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + + # test returning only 1 list instead of 2 + return [optimizer] + + def __dataloader(self, train): + # init data generators + transform = transforms.Compose([transforms.ToTensor(), + transforms.Normalize((0.5,), (1.0,))]) + dataset = MNIST(root=self.hparams.data_root, train=train, + transform=transform, download=True) + + # when using multi-node we need to add the datasampler + train_sampler = None + batch_size = self.hparams.batch_size + + try: + if self.on_gpu and not self.force_remove_distributed_sampler: + train_sampler = DistributedSampler(dataset, rank=self.trainer.proc_rank) + batch_size = batch_size // self.trainer.world_size # scale batch size + except Exception: + pass + + should_shuffle = train_sampler is None + loader = DataLoader( + dataset=dataset, + batch_size=batch_size, + shuffle=should_shuffle, + sampler=train_sampler + ) + + return loader + + @data_loader + def tng_dataloader(self): + return self.__dataloader(train=True) + + @staticmethod + def add_model_specific_args(parent_parser, root_dir): # pragma: no cover + """ + Parameters you define here will be available to your model through self.hparams + :param parent_parser: + :param root_dir: + :return: + """ + parser = HyperOptArgumentParser(strategy=parent_parser.strategy, parents=[parent_parser]) + + # param overwrites + # parser.set_defaults(gradient_clip=5.0) + + # network params + parser.opt_list('--drop_prob', default=0.2, options=[0.2, 0.5], type=float, tunable=False) + parser.add_argument('--in_features', default=28 * 28, type=int) + parser.add_argument('--out_features', default=10, type=int) + # use 500 for CPU, 50000 for GPU to see speed difference + parser.add_argument('--hidden_dim', default=50000, type=int) + + # data + parser.add_argument('--data_root', default=os.path.join(root_dir, 'mnist'), type=str) + + # training params (opt) + parser.opt_list('--learning_rate', default=0.001 * 8, type=float, + options=[0.0001, 0.0005, 0.001, 0.005], + tunable=False) + parser.opt_list('--optimizer_name', default='adam', type=str, + options=['adam'], tunable=False) + + # if using 2 nodes with 4 gpus each the batch size here + # (256) will be 256 / (2*8) = 16 per gpu + parser.opt_list('--batch_size', default=256 * 8, type=int, + options=[32, 64, 128, 256], tunable=False, + help='batch size will be divided over all gpus being used across all nodes') + return parser diff --git a/tests/test_models.py b/tests/test_models.py index 896eb88490..ac0b68603f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -10,7 +10,7 @@ from test_tube import Experiment, SlurmCluster # sys.path += [os.path.abspath('..'), os.path.abspath('../..')] from pytorch_lightning import Trainer -from pytorch_lightning.testing.lm_test_module import LightningTestModel +from pytorch_lightning.testing import LightningTestModel, NoValEndTestModel, NoValModel from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping from pytorch_lightning.utilities.debugging import MisconfigurationException from pytorch_lightning.root_module import memory @@ -26,6 +26,122 @@ np.random.seed(SEED) # ------------------------------------------------------------------------ # TESTS # ------------------------------------------------------------------------ + +def test_early_stopping_cpu_model(): + """ + Test each of the trainer options + :return: + """ + + stopping = EarlyStopping(monitor='val_loss') + trainer_options = dict( + early_stop_callback=stopping, + gradient_clip=1.0, + overfit_pct=0.20, + track_grad_norm=2, + print_nan_grads=True, + progress_bar=False, + experiment=get_exp(), + train_percent_check=0.1, + val_percent_check=0.1 + ) + + model, hparams = get_model() + run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) + + # test freeze on cpu + model.freeze() + model.unfreeze() + + +def test_no_val_module(): + """ + Tests use case where trainer saves the model, and user loads it from tags independently + :return: + """ + hparams = get_hparams() + model = NoValModel(hparams) + + save_dir = init_save_dir() + + # exp file to get meta + exp = get_exp(False) + exp.argparse(hparams) + exp.save() + + trainer_options = dict( + max_nb_epochs=1, + cluster=SlurmCluster(), + experiment=exp, + checkpoint_callback=ModelCheckpoint(save_dir) + ) + + # fit model + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + # traning complete + assert result == 1, 'amp + ddp model failed to complete' + + # save model + new_weights_path = os.path.join(save_dir, 'save_test.ckpt') + trainer.save_checkpoint(new_weights_path) + + # load new model + tags_path = exp.get_data_path(exp.name, exp.version) + tags_path = os.path.join(tags_path, 'meta_tags.csv') + model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path, + tags_csv=tags_path, on_gpu=False) + model_2.eval() + + # make prediction + clear_save_dir() + + +def test_no_val_end_module(): + """ + Tests use case where trainer saves the model, and user loads it from tags independently + :return: + """ + hparams = get_hparams() + model = NoValEndTestModel(hparams) + + save_dir = init_save_dir() + + # exp file to get meta + exp = get_exp(False) + exp.argparse(hparams) + exp.save() + + trainer_options = dict( + max_nb_epochs=1, + cluster=SlurmCluster(), + experiment=exp, + checkpoint_callback=ModelCheckpoint(save_dir) + ) + + # fit model + trainer = Trainer(**trainer_options) + result = trainer.fit(model) + + # traning complete + assert result == 1, 'amp + ddp model failed to complete' + + # save model + new_weights_path = os.path.join(save_dir, 'save_test.ckpt') + trainer.save_checkpoint(new_weights_path) + + # load new model + tags_path = exp.get_data_path(exp.name, exp.version) + tags_path = os.path.join(tags_path, 'meta_tags.csv') + model_2 = LightningTestModel.load_from_metrics(weights_path=new_weights_path, + tags_csv=tags_path, on_gpu=False) + model_2.eval() + + # make prediction + clear_save_dir() + + def test_simple_cpu(): """ Verify continue training session on CPU @@ -445,33 +561,6 @@ def test_amp_gpu_ddp_slurm_managed(): clear_save_dir() -def test_early_stopping_cpu_model(): - """ - Test each of the trainer options - :return: - """ - - stopping = EarlyStopping() - trainer_options = dict( - early_stop_callback=stopping, - gradient_clip=1.0, - overfit_pct=0.20, - track_grad_norm=2, - print_nan_grads=True, - progress_bar=False, - experiment=get_exp(), - train_percent_check=0.1, - val_percent_check=0.1 - ) - - model, hparams = get_model() - run_gpu_model_test(trainer_options, model, hparams, on_gpu=False) - - # test freeze on cpu - model.freeze() - model.unfreeze() - - def test_cpu_model_with_amp(): """ Make sure model trains on CPU @@ -525,6 +614,7 @@ def test_all_features_cpu_model(): print_nan_grads=True, progress_bar=False, experiment=get_exp(), + accumulate_grad_batches=2, max_nb_epochs=1, train_percent_check=0.4, val_percent_check=0.4