-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplot_result.py
179 lines (134 loc) · 8.38 KB
/
plot_result.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import json
import random
import matplotlib.pyplot as plt
import numpy as np
def plot_raw_result(result_path='stats/niid_50round_pro07.txt',
                    title='ddqn4 niid choose', *, client_ylim=(0, 10)):
    """Plot train/valid accuracy and per-round client count from a JSON log.

    The file at *result_path* is parsed with ``json.load`` and must provide the
    keys ``'train_accuracy'`` and ``'valid_accuracy'``. Each is a list of
    per-round records indexed as follows:

    * ``record[0]`` is the x value (round);
    * ``record[2]`` is the accuracy;
    * for validation records, ``record[3]`` is plotted as the client count.

    The accuracies go on the left y-axis. The client count goes on a second
    y-axis that shares the round axis.

    Args:
        result_path: JSON results file to read.
        title: title of the figure.
        client_ylim: ``(bottom, top)`` limits of the client-count axis.
    """
    with open(result_path, 'r', encoding='utf-8') as infile:
        data = json.load(infile)
    train_records = data['train_accuracy']
    valid_records = data['valid_accuracy']
    print(len(train_records))
    print(len(valid_records))

    train_rounds = [rec[0] for rec in train_records]
    train_acc = [rec[2] for rec in train_records]
    valid_rounds = [rec[0] for rec in valid_records]
    valid_acc = [rec[2] for rec in valid_records]
    client_num = [rec[3] for rec in valid_records]

    _, ax_acc = plt.subplots()
    ax_num = ax_acc.twinx()  # second y-axis sharing the x (round) axis
    ax_acc.plot(train_rounds, train_acc, color='r', label='train_acc')
    ax_acc.plot(valid_rounds, valid_acc, color='b', label='valid_acc')
    ax_acc.set(xlabel='round', ylabel='accuracy', title=title)
    ax_acc.legend(loc=2)
    ax_num.plot(valid_rounds, client_num, color='g', label='client_num')
    ax_num.set_ylim(*client_ylim)
    ax_num.set(ylabel='update_num')
    ax_num.legend(loc=7)
    plt.show()
def _global_acc_path(split, run):
    """Return the global-model / global-data accuracy file of one run.

    Args:
        split: data-partition directory; this script uses ``'iid'``,
            ``'niid'`` and ``'nniid'``.
        run: experiment directory beneath it, e.g. ``'50round_all'``,
            ``'50round_drop3'`` or ``'final_pro07'``.
    """
    return './result/mnist/{}/{}/global_model_global_data_acc.txt'.format(
        split, run)


def _client_acc_path(split, run, client):
    """Return the global-model accuracy file on client *client*'s local data."""
    return ('./result/mnist/{}/{}/'
            'client_{}_global_model_local_data_acc.txt').format(
                split, run, client)


# (accuracy file, legend label) pairs drawn by plot_raw_acc() by default.
# Other result runs this script has referred to can be passed via ``curves``:
#   global accuracy (_global_acc_path):
#     iid/{50round_all, 50round_drop3}
#     niid/{50round_all, 50round_drop3, 50round_pro07}
#     nniid/{50round_all, 50round_drop3, 50round_drop5, 50round_pro07,
#            50round_pro08, 50round_pro09, final_pro07, final_pro09}
#   per-client accuracy (_client_acc_path):
#     nniid/{50round_drop3, 50round_pro08, 50round_pro09}, clients 0-2
#     client 0 of nniid/{50round_all, final_pro07, final_pro09}
#     and of iid/50round_all
_DEFAULT_ACC_CURVES = (
    (_global_acc_path('iid', '50round_all'), 'iid_all_acc'),
    (_global_acc_path('nniid', '50round_all'), 'tpu_non-iid_all_alive_acc'),
    (_global_acc_path('nniid', 'final_pro07'),
     'tpu_non-iid_probability_07_acc'),
    (_global_acc_path('nniid', 'final_pro09'),
     'tpu_non-iid_probability_09_acc'),
)


def plot_raw_acc(curves=_DEFAULT_ACC_CURVES, *, ylim=(0.4, 1)):
    """Plot per-round accuracy curves of several result runs in one figure.

    Only the files listed in *curves* are read, so result files that are not
    plotted do not need to exist.

    Args:
        curves: iterable of ``(path, label)`` pairs. Each file is read with
            ``np.loadtxt`` and drawn as one solid line labelled *label*.
            Its x values are the 0-based row indices (plotted as rounds).
        ylim: ``(bottom, top)`` limits of the accuracy axis.
    """
    for path, label in curves:
        acc = np.loadtxt(path)
        # Each series gets its own x range, so runs of different lengths can
        # share the figure without a shape mismatch in plt.plot.
        rounds = np.arange(len(acc))
        plt.plot(rounds, acc, '-', linewidth=1.5, label=label)
    plt.ylim(*ylim)
    plt.ylabel('Accuracy')
    plt.xlabel('Round')
    plt.legend()
    plt.show()
if __name__ == '__main__':
    # Draw the accuracy comparison figure; call plot_raw_result() instead
    # to plot the JSON training log.
    plot_raw_acc()