Change the classifier input from 2048 to 1000. (#5232)
* Change the classifier input from 2048 to 1000. * Update docs for Imagenet example Thanks @rohitgr7 * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
d5b367871f
commit
a40e3a325e
|
@ -52,16 +52,22 @@ Example: Imagenet (computer Vision)
|
|||
|
||||
class ImagenetTransferLearning(LightningModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
# init a pretrained resnet
|
||||
num_target_classes = 10
|
||||
self.feature_extractor = models.resnet50(pretrained=True)
|
||||
self.feature_extractor.eval()
|
||||
backbone = models.resnet50(pretrained=True)
|
||||
num_filters = backbone.fc.in_features
|
||||
layers = list(backbone.children())[:-1]
|
||||
self.feature_extractor = torch.nn.Sequential(*layers)
|
||||
|
||||
# use the pretrained model to classify cifar-10 (10 image classes)
|
||||
self.classifier = nn.Linear(2048, num_target_classes)
|
||||
num_target_classes = 10
|
||||
self.classifier = nn.Linear(num_filters, num_target_classes)
|
||||
|
||||
def forward(self, x):
|
||||
representations = self.feature_extractor(x)
|
||||
self.feature_extractor.eval()
|
||||
with torch.no_grad():
|
||||
representations = self.feature_extractor(x).flatten(1)
|
||||
x = self.classifier(representations)
|
||||
...
|
||||
|
||||
|
|
Loading…
Reference in New Issue