1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
| import torch from torch import nn,optim import torchvision.transforms as transforms import torchvision import pandas as pd import os from PIL import Image
# Number of target classes; also the output width of myRes.add_Linear.
num_classes = 12

# File name of the checkpoint that is loaded into the model at start-up.
cat_model_name = "cat_model.pth"

# Preprocessing applied to each test image before it is fed to the network:
# resize to 256x256, keep the central 196x196 crop, convert to a tensor and
# normalise every channel with the mean / std listed below.
test_transform = transforms.Compose([
    transforms.Resize(size=(256, 256)),
    transforms.CenterCrop(size=196),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.4848, 0.4435, 0.4023],
        std=[0.2744, 0.2688, 0.2757],
    ),
])
class myRes(nn.Module):
    """ResNet-50 based network used by the classifier script.

    Sub-modules:
        resnet: a torchvision ResNet-50 created with ``pretrained=False``.
        add_Linear: ``nn.Linear(1000, num_classes)``.
    """

    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50(pretrained=False)
        # Assigning an nn.Module attribute registers it under that name, so
        # the state_dict keys stay 'add_Linear.weight' / 'add_Linear.bias'.
        # NOTE(review): forward() never calls this layer, so the network
        # output is exactly what self.resnet returns. Confirm against the
        # training script before wiring it into forward().
        self.add_Linear = nn.Linear(1000, num_classes)

    def forward(self, x):
        """Return ``self.resnet(x)``; ``add_Linear`` is not applied."""
        return self.resnet(x)
def GetFiles(file_dir, file_type, IsCurrent=False):
    """Collect the names of all files with a given extension under a directory.

    Args:
        file_dir: Directory to search.
        file_type: File extension to match, without the leading dot
            (for example ``'jpg'``).
        IsCurrent: When ``True`` only ``file_dir`` itself is scanned and
            sub-directories are skipped. Defaults to ``False``.

    Returns:
        A list of matching file names. Only the bare name is returned, not
        the directory the file was found in.
    """
    suffix = '.%s' % file_type
    matches = []
    for _root, _subdirs, entries in os.walk(file_dir):
        matches.extend(entry for entry in entries if entry.endswith(suffix))
        # os.walk yields file_dir first, so stopping here limits the scan to
        # the top-level directory.
        if IsCurrent == True:
            break
    return matches
if __name__ == '__main__':
    # Batch inference: classify every .jpg found under the test directory and
    # write one "<file name>,<label>" row per image to result.csv.
    test_dir = 'data/cat_12_test'
    # Label written for images that cannot be classified, so that every file
    # name still gets a row in the output.
    fallback_label = '5'

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = myRes().to(device)
    # map_location lets a checkpoint that was saved on a GPU load on a
    # CPU-only host, which the device fallback above explicitly allows for.
    model.load_state_dict(torch.load(cat_model_name, map_location=device))
    model.eval()
    class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11']

    names = GetFiles(test_dir, 'jpg')
    labels = []
    error_names = []

    # Inference only: no_grad avoids building autograd graphs for each image.
    with torch.no_grad():
        for name in names:
            img_path = '{}/{}'.format(test_dir, name)
            try:
                # convert('RGB') makes grayscale / RGBA / palette images match
                # the 3-channel Normalize in test_transform; the with-block
                # closes the underlying file once the tensor has been built.
                with Image.open(img_path) as img:
                    batch = test_transform(img.convert('RGB')).unsqueeze(0)
                outputs = model(batch.to(device))
                # Logit i is looked up as class_names[i], so only the first
                # len(class_names) logits have a label. Restricting the argmax
                # to them keeps the index in range even when the network
                # emits more outputs than there are classes.
                _, indices = torch.max(outputs[:, :len(class_names)], 1)
                labels.append(class_names[int(indices)])
            except Exception as exc:
                # Best effort: report the failure and keep names and labels
                # aligned instead of aborting the whole run.
                print('Failed to classify {}: {!r}'.format(img_path, exc))
                labels.append(fallback_label)
                error_names.append(img_path)

    df = pd.DataFrame({'name': names, 'label': labels})
    df.to_csv('result.csv', header=False, index=False)

    if error_names:
        print('{} image(s) fell back to label {}.'.format(
            len(error_names), fallback_label))
    print("Done!")
|