removed decorators (#1079)
This commit is contained in:
parent 2bc01a00e8
commit 3d18099262

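The change is mechanical and identical in every hunk below: the `@data_loader` / `@pl.data_loader` line above each dataloader hook is deleted (along with the `data_loader` imports), leaving a plain method that Lightning calls directly. A minimal sketch of the first hunk's hook in the post-commit style; as in the docs snippet it comes from, the enclosing LightningModule is elided, and `download=True` plus the `ToTensor` transform are assumptions added here so the example is complete:

import os

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# before this commit:
#     @pl.data_loader
#     def train_dataloader(self):
#         ...
# after it, the decorator line is gone and the hook is a plain method:


def train_dataloader(self):
    dataset = MNIST(os.getcwd(), train=True, download=True,
                    transform=transforms.ToTensor())
    return DataLoader(dataset, batch_size=32)
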
@@ -101,7 +101,6 @@ train_dataloader (and val, train) code as follows.

 import torch_xla.core.xla_model as xm

-@pl.data_loader
 def train_dataloader(self):
     dataset = MNIST(
         os.getcwd(),

@@ -167,7 +167,6 @@ class GAN(LightningModule):
         opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr, betas=(b1, b2))
         return [opt_g, opt_d], []

-    @data_loader
     def train_dataloader(self):
         transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Normalize([0.5], [0.5])])

@@ -20,7 +20,6 @@ import torchvision.transforms as transforms

 import pytorch_lightning as pl
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.core import data_loader

 # pull out resnet names from torchvision models
 MODEL_NAMES = sorted(

@@ -132,7 +131,6 @@ class ImageNetLightningModel(LightningModule):
         scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
         return [optimizer], [scheduler]

-    @data_loader
     def train_dataloader(self):
         normalize = transforms.Normalize(
             mean=[0.485, 0.456, 0.406],

@@ -163,7 +161,6 @@ class ImageNetLightningModel(LightningModule):
         )
         return train_loader

-    @data_loader
     def val_dataloader(self):
         normalize = transforms.Normalize(
             mean=[0.485, 0.456, 0.406],

@@ -11,7 +11,6 @@ import torch
 import torch.distributed as dist
 from torch.optim import Adam

-from pytorch_lightning.core.decorators import data_loader
 from pytorch_lightning.core.grads import GradInformation
 from pytorch_lightning.core.hooks import ModelHooks
 from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv

@@ -1139,7 +1138,6 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
         """
         return None

-    @data_loader
     def tng_dataloader(self):  # todo: remove in v1.0.0
         """Implement a PyTorch DataLoader.

@@ -1239,7 +1237,6 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):

         .. code-block:: python

-            @pl.data_loader
             def val_dataloader(self):
                 transform = transforms.Compose([transforms.ToTensor(),
                                                 transforms.Normalize((0.5,), (1.0,))])

@@ -1254,7 +1251,6 @@ class LightningModule(ABC, GradInformation, ModelIO, ModelHooks):
                 return loader

             # can also return multiple dataloaders
-            @pl.data_loader
             def val_dataloader(self):
                 return [loader_a, loader_b, ..., loader_n]

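Aside: the docstring above also covers returning multiple dataloaders; after this commit that is just an undecorated method returning a list. A runnable toy sketch, where the `TensorDataset` loaders are invented stand-ins for the docstring's `loader_a` / `loader_b`:

import torch
from torch.utils.data import DataLoader, TensorDataset


class Demo:
    def __init__(self):
        # stand-ins for loader_a / loader_b in the docstring
        self.loader_a = DataLoader(TensorDataset(torch.randn(64, 3)), batch_size=8)
        self.loader_b = DataLoader(TensorDataset(torch.randn(64, 3)), batch_size=8)

    def val_dataloader(self):
        # no @pl.data_loader needed; just return the list
        return [self.loader_a, self.loader_b]
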
@@ -41,14 +41,11 @@ class CoolModel(pl.LightningModule):
     def configure_optimizers(self):
         return [torch.optim.Adam(self.parameters(), lr=0.02)]

-    @pl.data_loader
     def train_dataloader(self):
         return DataLoader(MNIST('path/to/save', train=True), batch_size=32)

-    @pl.data_loader
     def val_dataloader(self):
         return DataLoader(MNIST('path/to/save', train=False), batch_size=32)

-    @pl.data_loader
     def test_dataloader(self):
         return DataLoader(MNIST('path/to/save', train=False), batch_size=32)
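
Once this last hunk is applied, the quickstart `CoolModel` is left with plain undecorated hooks; for reference, a sketch of the result (the imports are assumed, since the hunk does not show them):

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST


class CoolModel(pl.LightningModule):
    def configure_optimizers(self):
        return [torch.optim.Adam(self.parameters(), lr=0.02)]

    # the @pl.data_loader decorators are gone; Lightning calls these directly
    def train_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=True), batch_size=32)

    def val_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=False), batch_size=32)

    def test_dataloader(self):
        return DataLoader(MNIST('path/to/save', train=False), batch_size=32)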