pytorch 使用resnet预训练模型提取特征

使用pytorch自带的预训练模型提取特征

1
2
3
4
5
6
7
import torch
import torch.nn as nn
import torchvision

resnet = torchvision.models.resnet50(pretrained=True)
resnet.fc = nn.Linear(2048, 2048, bias=False)
torch.nn.init.eye(resnet.fc.weight)

使用pytorch自带的预训练模型来训练自己的分类器

1
2
3
4
5
6
import torch
import torch.nn as nn
import torchvision

resnet = torchvision.models.resnet50(pretrained=True)
resnet.fc = nn.Linear(resnet.fc.in_features , num) # num是分类类别数