# Paper_testForCIfar10.py
# Evaluate a trained SequentialDecisionTreeForRDNet checkpoint on the CIFAR-10
# validation loader and save per-sample and confusion diagnostic plots.
import torch
from Paper_Tree import SequentialDecisionTree,SequentialDecisionTreeForRDNet
from Paper_global_vars import global_vars
from Paper_DataSetCIFAR import create_train_loader,create_valid_loader
import os
import shutil
import numpy as np
import random
import matplotlib.pyplot as plt
from collections import defaultdict
import torch.nn.functional as F
from multiprocessing import freeze_support
# Class names indexed by integer label id (CLASS_LABELS[i] is the name of label i).
CLASS_LABELS = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# Misclassified samples drawn per saved analysis figure (one per grid row).
ROWS_PER_PAGE = 3
# Colours cycled over the named confusion slices; "Others" always gets PIE_OTHERS_COLOR.
BASE_COLORS = ('#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
               '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf')
PIE_OTHERS_COLOR = '#999999'


def seed_everything(seed=42):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN kernels."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # also seeds every GPU when using multi-GPU
    # Harmless on CPU-only builds, so set unconditionally.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def clear_directory(directory):
    """Ensure *directory* exists and delete everything inside it (best effort).

    Entries that cannot be removed are reported on stdout and skipped rather
    than aborting the run.
    """
    os.makedirs(directory, exist_ok=True)
    for filename in os.listdir(directory):
        file_path = os.path.join(directory, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)
        except OSError as e:
            print(f'Failed to delete {file_path}. Reason: {e}')


def strip_state_dict_prefixes(state_dict, prefixes=('module.', '_orig_mod.')):
    """Return a new dict whose keys have wrapper prefixes removed.

    Prefixes are stripped repeatedly and in any order, so 'module._orig_mod.w',
    'module.w', '_orig_mod.w' and 'w' all become 'w'. Key order is preserved
    and values are not copied. The input dict is not modified.
    """
    cleaned = {}
    for key, value in state_dict.items():
        changed = True
        while changed:
            changed = False
            for prefix in prefixes:
                if key.startswith(prefix):
                    key = key[len(prefix):]
                    changed = True
        cleaned[key] = value
    return cleaned


def normalize_image(img):
    """Min-max scale a numpy array into [0, 1] for display.

    A constant array (max == min) maps to all zeros instead of NaN.
    """
    lo, hi = img.min(), img.max()
    span = hi - lo
    if span == 0:
        return np.zeros_like(img)
    return (img - lo) / span


def analysis_page_number(sample_count, rows_per_page=ROWS_PER_PAGE):
    """1-based page index holding the *sample_count*-th plotted sample.

    Equals sample_count // rows_per_page for full pages, so existing file names
    are unchanged; a partial last page gets the next, unused index.
    """
    return -(-sample_count // rows_per_page)


def new_analysis_figure(rows_per_page=ROWS_PER_PAGE):
    """Create a rows_per_page x 3 figure: image | class probabilities | node probabilities."""
    fig = plt.figure(figsize=(30, 15))
    gs = fig.add_gridspec(rows_per_page, 3, width_ratios=[1, 2, 2])
    axes = [[fig.add_subplot(gs[i, j]) for j in range(3)] for i in range(rows_per_page)]
    fig.tight_layout(pad=5.0)
    return fig, axes


def save_analysis_page(fig, save_dir, sample_count, rows_per_page=ROWS_PER_PAGE):
    """Write *fig* as combined_analysis_<page>.png into *save_dir*, then close it."""
    page = analysis_page_number(sample_count, rows_per_page)
    fig.savefig(f'{save_dir}/combined_analysis_{page}.png', bbox_inches='tight', dpi=300)
    plt.close(fig)


def plot_misclassified_sample(row_axes, image, predicted_probs, node_probs, class_labels, title):
    """Draw one misclassified sample into a row of three axes.

    Args:
        row_axes: [image axis, class-probability axis, node-probability axis].
        image: HWC numpy array; min-max scaled here for display.
        predicted_probs: 1-D numpy array of the model's per-class outputs
            (drawn on a [0, 1] axis, so presumably probabilities — TODO confirm).
        node_probs: one 1-D numpy array per tree node; lengths may differ and
            missing entries are drawn as 0.
        class_labels: class names, one bar per name.
        title: text for the image axis.
    """
    img_ax, prob_ax, node_ax = row_axes

    img_ax.imshow(normalize_image(image))
    img_ax.set_title(title, fontsize=20)
    img_ax.axis('off')

    prob_ax.bar(range(len(class_labels)), predicted_probs)
    prob_ax.set_title('Probability Distribution', fontsize=20)
    prob_ax.set_ylim(0, 1)
    prob_ax.tick_params(axis='y', labelsize=16)
    prob_ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)  # Remove x-axis ticks

    node_ax.clear()  # Clear the existing subplot
    num_outputs = max((len(p) for p in node_probs), default=0)
    if num_outputs:
        # Grouped bars: one group per node, one bar per output index within the group.
        width = 0.8 / num_outputs
        positions = range(len(node_probs))
        for i in range(num_outputs):
            heights = [p[i] if i < len(p) else 0 for p in node_probs]
            node_ax.bar([pos + i * width for pos in positions], heights, width, label=f'Output {i+1}')
        node_ax.legend(fontsize=12, loc='upper right')
    node_ax.set_title('Node Probabilities', fontsize=20)
    node_ax.set_ylim(0, 1)
    node_ax.tick_params(axis='y', labelsize=16)
    node_ax.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)  # Remove x-axis ticks


def build_pie_slices(confusion_counts, threshold=3):
    """Turn {(true_name, pred_name): count} into pie slice sizes and labels.

    Pairs whose share of all misclassifications is at least *threshold*
    percent get their own slice, largest first. The rest are merged into a
    trailing "Others" slice.

    Returns:
        (sizes, labels, other_percentage); ([], [], 0) when there are no confusions.
    """
    total = sum(confusion_counts.values())
    if total == 0:
        return [], [], 0
    ranked = sorted(
        ((pair, count / total * 100) for pair, count in confusion_counts.items()),
        key=lambda item: item[1],
        reverse=True,
    )
    sizes, labels = [], []
    other_percentage = 0
    for (true_name, pred_name), percentage in ranked:
        if percentage >= threshold:
            sizes.append(percentage)
            labels.append(f"{true_name} → {pred_name}")
        else:
            other_percentage += percentage
    if other_percentage > 0:
        sizes.append(other_percentage)
        labels.append("Others")
    return sizes, labels, other_percentage


def pie_colors(num_slices, has_others, base_colors=BASE_COLORS):
    """One colour per slice: cycle *base_colors* over named slices, grey for "Others".

    Always returns exactly *num_slices* colours. Matplotlib's pie cycles a
    short colour list, which would repeat the first colour on the last slice.
    """
    named = num_slices - 1 if has_others else num_slices
    colors = [base_colors[i % len(base_colors)] for i in range(named)]
    if has_others:
        colors.append(PIE_OTHERS_COLOR)
    return colors


def save_confusion_pie(confusion_counts, save_dir, threshold=3):
    """Save a donut chart of misclassification pairs to confusion_pie_chart_<threshold>.png.

    Prints a message and returns without plotting when there are no
    misclassifications. Side effect: sets matplotlib's global
    rcParams['font.size'] to 36.
    """
    pie_data, pie_labels, other_percentage = build_pie_slices(confusion_counts, threshold)
    if not pie_data:
        print("No misclassifications; skipping confusion pie chart.")
        return
    colors = pie_colors(len(pie_data), other_percentage > 0)

    plt.rcParams.update({'font.size': 36})  # global default font size (twice matplotlib's usual)
    plt.figure(figsize=(24, 16))  # large canvas to fit the large fonts
    wedges, _texts, autotexts = plt.pie(pie_data, labels=None, autopct='%1.1f%%', startangle=90,
                                        wedgeprops=dict(width=0.6), textprops=dict(color="k"),
                                        colors=colors)
    for autotext in autotexts:
        autotext.set_fontsize(20)  # percentage labels inside the wedges

    # Label each wedge outside the ring, connected by a leader line.
    for wedge, label in zip(wedges, pie_labels):
        ang = (wedge.theta2 + wedge.theta1) / 2
        y = np.sin(np.deg2rad(ang))
        x = np.cos(np.deg2rad(ang))
        side = 1 if x >= 0 else -1  # x == 0 used to raise KeyError via np.sign
        plt.annotate(label, xy=(x, y), xytext=(1.35 * side, 1.4 * y),
                     horizontalalignment="left" if side > 0 else "right",
                     verticalalignment="center",
                     arrowprops=dict(arrowstyle="-", color="0.5",
                                     connectionstyle=f"angle,angleA=0,angleB={ang}"),
                     fontsize=48)
    plt.title(f"Confusion Distribution (>{threshold}%)", fontsize=48)
    plt.axis('equal')
    plt.savefig(f'{save_dir}/confusion_pie_chart_{threshold}.png', bbox_inches='tight', dpi=300)
    plt.close()


def main():
    """Evaluate the checkpoint on the validation loader and write diagnostic plots."""
    seed_everything(42)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    picture_save_path = os.path.join(os.path.dirname(__file__), "Picture_save")
    clear_directory(picture_save_path)

    model = SequentialDecisionTreeForRDNet(isTest=True).to(device)
    # NOTE(review): resolved against the current working directory, unlike
    # Picture_save which lives next to this script.
    model_path = 'checkpoint_epoch_92_acc_0.9900.pth'
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(strip_state_dict_prefixes(checkpoint['model_state']))
    model.eval()

    total_correct = 0
    confusion_dict = defaultdict(int)
    sample_count = 0  # misclassified samples plotted so far
    fig, axes = new_analysis_figure()
    valid_data = create_valid_loader()

    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(valid_data):
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            node_outputs = None  # per-node outputs for this batch, computed on first error
            for idx, true_label in enumerate(target):
                predicted_probs = outputs[idx]
                predicted_label = predicted_probs.argmax().item()
                true_index = true_label.item()
                if predicted_label == true_index:
                    total_correct += 1
                    continue

                true_name = CLASS_LABELS[true_index]
                pred_name = CLASS_LABELS[predicted_label]
                confusion_dict[(true_name, pred_name)] += 1

                if node_outputs is None:
                    # Each node is applied to the full input batch, as before;
                    # done once per batch instead of once per misclassified sample.
                    node_outputs = [node(data) for node in model.nodes]

                plot_misclassified_sample(
                    axes[sample_count % ROWS_PER_PAGE],
                    # permute(1, 2, 0) assumes CHW image tensors — TODO confirm loader layout.
                    data[idx].cpu().permute(1, 2, 0).numpy(),
                    predicted_probs.cpu().numpy(),
                    [out[idx].cpu().numpy() for out in node_outputs],
                    CLASS_LABELS,
                    f'Batch {batch_idx}, Sample {idx}\nTrue: {true_name}, Pred: {pred_name}',
                )
                sample_count += 1
                if sample_count % ROWS_PER_PAGE == 0:
                    save_analysis_page(fig, picture_save_path, sample_count)
                    fig, axes = new_analysis_figure()

    # Flush a partially filled last page; otherwise discard the spare empty figure.
    if sample_count % ROWS_PER_PAGE:
        save_analysis_page(fig, picture_save_path, sample_count)
    else:
        plt.close(fig)

    num_samples = len(valid_data.dataset)
    accuracy = total_correct / num_samples
    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Correct predictions: {total_correct}/{num_samples}")

    save_confusion_pie(confusion_dict, picture_save_path, threshold=3)


if __name__ == '__main__':
    freeze_support()
    main()