12猫分类学习笔记三(预测篇)

训练的代码集中在cat_test.py里。编译执行该文件即可进行预测。

基本上就是调用model.eval,然后进行格式整理,导出对应格式的csv即可。

cat_test.py 源代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch
from torch import nn,optim
import torchvision.transforms as transforms
import torchvision
import pandas as pd
import os
from PIL import Image
# 种类数
num_classes = 12 # 分类别数量
# 模型文件名称
cat_model_name = "cat_model.pth"

# 定义transform
test_transform = transforms.Compose([
transforms.Resize((256, 256)), # 缩放到指定大小
transforms.CenterCrop(196),
transforms.ToTensor(), # 转换为张量
transforms.Normalize(mean =[0.4848, 0.4435, 0.4023], std=[0.2744, 0.2688, 0.2757]), # 归一化
])


class myRes(nn.Module):
def __init__(self):
super(myRes, self).__init__()

self.resnet = torchvision.models.resnet50(pretrained=False)
self.add_module('add_Linear', nn.Linear(1000, num_classes))

def forward(self, x):
x = self.resnet(x)
return x


def GetFiles(file_dir, file_type, IsCurrent=False):
'''
功能:获取指定文件路径&文件类型下的所有文件名
传入:
file_dir 文件路径,
file_type 文件类型,
IsCurrent 是否只获取当前文件路径下的文件,默认False
返回:含文件名的列表
'''
file_list = []
for parent, dirnames, filenames in os.walk(file_dir):
for filename in filenames:
if filename.endswith(('.%s' % file_type)): # 判断文件类型
file_list.append(filename)
if IsCurrent == True:
break
return file_list


if __name__ == '__main__':

# 模型加载
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = myRes().to(device)
model.load_state_dict(torch.load(cat_model_name))
# 切换至预测模式
model.eval()
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11']

labels = []

names = GetFiles('data/cat_12_test','jpg')
error_names = []

for i in names:
img_path = 'data/cat_12_test/{}'.format(i)
try:
img = Image.open(img_path)
img_ = test_transform(img).unsqueeze(0)

img_ = img_.to(device)
outputs = model(img_)

_, indices = torch.max(outputs, 1)
percentage = torch.nn.functional.softmax(outputs, dim=1)[0] * 100
perc = percentage[int(indices)].item()
result = class_names[indices]
labels.append(result)
except:
labels.append('5')
error_names.append(img_path)


df1 = pd.DataFrame(names)
df2 = pd.DataFrame(labels)
df = pd.concat([df1, df2],axis=1,join='outer')
df.to_csv('result.csv', header=False, index=False)
print("Done!")