Change the classifier input from 2048 to 1000. (#5232)

* Change the classifier input from 2048 to 1000.

* Update docs for Imagenet example

Thanks @rohitgr7

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
LaserBit 2021-01-05 21:09:52 +09:00 committed by GitHub
parent d5b367871f
commit a40e3a325e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 11 additions and 5 deletions

View File

@ -52,16 +52,22 @@ Example: Imagenet (computer Vision)
class ImagenetTransferLearning(LightningModule):
def __init__(self):
super().__init__()
# init a pretrained resnet
num_target_classes = 10
self.feature_extractor = models.resnet50(pretrained=True)
self.feature_extractor.eval()
backbone = models.resnet50(pretrained=True)
num_filters = backbone.fc.in_features
layers = list(backbone.children())[:-1]
self.feature_extractor = torch.nn.Sequential(*layers)
# use the pretrained model to classify cifar-10 (10 image classes)
self.classifier = nn.Linear(2048, num_target_classes)
num_target_classes = 10
self.classifier = nn.Linear(num_filters, num_target_classes)
def forward(self, x):
representations = self.feature_extractor(x)
self.feature_extractor.eval()
with torch.no_grad():
representations = self.feature_extractor(x).flatten(1)
x = self.classifier(representations)
...