import argparse
from torch.utils.data import Dataset, DataLoader
import json
+import sys

torch.backends.cudnn.benchmark = True
# https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936
# This flag enables the built-in cuDNN auto-tuner, which finds the best algorithm for your hardware.
# If you check with a profiler, several convolution methods (winograd, fft, etc.) are tried on the first iteration and the fastest one is then selected for the device.

-
MODEL_LIST = {
    models.mnasnet: models.mnasnet.__all__[1:],
    models.resnet: models.resnet.__all__[1:],
precisions = ["float", "half", "double"]
# On post-Volta architectures it is possible to use tensor cores at half precision.
# Due to the gradient overflow problem, apex is recommended for practical use.
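# A minimal sketch (not part of this commit, kept commented out): the gradient-overflow
# handling that the apex recommendation above refers to is also available in core PyTorch
# (>= 1.6) via torch.cuda.amp. `model`, `optimizer`, `criterion`, `target` and `rand_loader`
# are assumed to exist as they do in train() below; only the scaler/autocast pattern is shown.
#
# scaler = torch.cuda.amp.GradScaler()
# for img in rand_loader:
#     optimizer.zero_grad()
#     with torch.cuda.amp.autocast():
#         loss = criterion(model(img.to("cuda")), target)
#     scaler.scale(loss).backward()  # scale the loss so fp16 gradients do not underflow
#     scaler.step(optimizer)         # unscale, and skip the step if inf/nan gradients appear
#     scaler.update()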
-device_name = str(torch.cuda.get_device_name(0))
# Training settings
parser = argparse.ArgumentParser(description="PyTorch Benchmarking")
parser.add_argument(
    "--WARM_UP", "-w", type=int, default=5, required=False, help="Num of warm up"
)
+
parser.add_argument(
    "--NUM_TEST", "-n", type=int, default=50, required=False, help="Num of Test"
)
+
parser.add_argument(
    "--BATCH_SIZE", "-b", type=int, default=12, required=False, help="Num of batch size"
)
+
parser.add_argument(
    "--NUM_CLASSES", "-c", type=int, default=1000, required=False, help="Num of class"
)
+
parser.add_argument(
-    "--NUM_GPU", "-g", type=int, default=1, required=False, help="Num of gpus"
+    "--GPU_COUNT", "-g", type=int, default=1, required=False, help="Number of gpus used in test"
)
+
+parser.add_argument(
+    "--GPU_INDEX", "-i", type=int, default=-1, required=False, help="Index for the used gpu"
+)
+
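# Example invocations with the flags defined above (the script name is an assumption,
# substitute the actual file name):
#   python benchmark_models.py -b 16 -g 2    # spread the batch over two GPUs via DataParallel
#   python benchmark_models.py -b 16 -i 1    # run the whole benchmark on GPU index 1 only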
parser.add_argument(
    "--folder",
    "-f",
    required=False,
    help="folder to save results",
)
-args = parser.parse_args()
-args.BATCH_SIZE *= args.NUM_GPU
-

class RandomDataset(Dataset):
    def __init__(self, length):
@@ -73,16 +78,7 @@ def __getitem__(self, index):
    def __len__(self):
        return self.len

-
-rand_loader = DataLoader(
-    dataset=RandomDataset(args.BATCH_SIZE * (args.WARM_UP + args.NUM_TEST)),
-    batch_size=args.BATCH_SIZE,
-    shuffle=False,
-    num_workers=8,
-)
-
-
-def train(precision="single"):
+def train(precision="single", gpu_index=-1):
    """use fake image for training speed test"""
    target = torch.LongTensor(args.BATCH_SIZE).random_(args.NUM_CLASSES).cuda()
    criterion = nn.CrossEntropyLoss()
@@ -91,10 +87,15 @@ def train(precision="single"):
        for model_name in MODEL_LIST[model_type]:
            if model_name[-8:] == '_Weights': continue
            model = getattr(model_type, model_name)()
-            if args.NUM_GPU > 1:
-                model = nn.DataParallel(model, device_ids=range(args.NUM_GPU))
+            if args.GPU_COUNT > 1:
+                model = nn.DataParallel(model, device_ids=range(args.GPU_COUNT))
            model = getattr(model, precision)()
-            model = model.to("cuda")
+            torch_device_name = "cuda"
+            if gpu_index >= 0:
+                torch_device_name = "cuda:" + str(gpu_index)
+            print("torch_device_name: " + torch_device_name)
+            torch_device = torch.device(torch_device_name)
+            model = model.to(torch_device)
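            # Note (illustrative, not part of the commit): an equivalent construction is
            # torch.device("cuda", gpu_index) when gpu_index >= 0, since torch.device also
            # accepts an explicit device index as its second argument.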
            durations = []
            print(f"Benchmarking Training {precision} precision type {model_name}")
            for step, img in enumerate(rand_loader):
@@ -116,18 +117,22 @@ def train(precision="single"):
            benchmark[model_name] = durations
    return benchmark

-
-def inference(precision="float"):
+def inference(precision="float", gpu_index=-1):
    benchmark = {}
    with torch.no_grad():
        for model_type in MODEL_LIST.keys():
            for model_name in MODEL_LIST[model_type]:
                if model_name[-8:] == '_Weights': continue
                model = getattr(model_type, model_name)()
-                if args.NUM_GPU > 1:
-                    model = nn.DataParallel(model, device_ids=range(args.NUM_GPU))
+                if args.GPU_COUNT > 1:
+                    model = nn.DataParallel(model, device_ids=range(args.GPU_COUNT))
                model = getattr(model, precision)()
-                model = model.to("cuda")
+                torch_device_name = "cuda"
+                if gpu_index >= 0:
+                    torch_device_name = "cuda:" + str(gpu_index)
+                print("torch_device_name: " + torch_device_name)
+                torch_device = torch.device(torch_device_name)
+                model = model.to(torch_device)
                model.eval()
                durations = []
                print(
@@ -149,30 +154,62 @@ def inference(precision="float"):
                benchmark[model_name] = durations
    return benchmark

-
f"{platform.uname()}\n{psutil.cpu_freq()}\ncpu_count: {psutil.cpu_count()}\nmemory_available: {psutil.virtual_memory().available}"

-
if __name__ == "__main__":
-    folder_name = args.folder
+    args = parser.parse_args()
+    args.BATCH_SIZE *= args.GPU_COUNT
+
+    print("BATCH_SIZE: " + str(args.BATCH_SIZE))
+    rand_loader = DataLoader(
+        dataset=RandomDataset(args.BATCH_SIZE * (args.WARM_UP + args.NUM_TEST)),
+        batch_size=args.BATCH_SIZE,
+        shuffle=False,
+        num_workers=8,
+    )
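    # Worked example: with the defaults (-b 12, -w 5, -n 50) on a single GPU this is a random
    # dataset of 12 * (5 + 50) = 660 samples, i.e. exactly 55 batches per model and precision:
    # 5 warm-up plus 50 timed iterations.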
+    gpu_count = args.GPU_COUNT
+    gpu_index = args.GPU_INDEX
+
+    print("gpu_index: " + str(gpu_index))
+    print("gpu_count: " + str(gpu_count))
+
+    if gpu_index >= 0:
+        device_name = str(torch.cuda.get_device_name(gpu_index))
+    else:
+        device_name = str(torch.cuda.get_device_name(0))
+    device_name = f"{device_name}"
+    if args.GPU_COUNT > 1:
+        device_name = device_name + str(gpu_count) + "X"
+    device_name = device_name.replace(" ", "_")
+    device_file_name = device_name + "_"
+    print("device_name: " + device_name)
+
+    if gpu_index >= 0:
+        folder_name = args.folder + "/" + str(gpu_index) + "/" + device_name
+    else:
+        folder_name = args.folder + "/" + str(gpu_count) + "X"
+    print("folder_name: " + folder_name)

-    device_name = f"{device_name}_{args.NUM_GPU}_gpus_"
    system_configs = f"{platform.uname()}\n\
                     {psutil.cpu_freq()}\n\
                     cpu_count: {psutil.cpu_count()}\n\
                     memory_available: {psutil.virtual_memory().available}"
    gpu_configs = [
-        torch.cuda.device_count(),
+        gpu_count,
+        torch.__version__,
+        torch.version.hip,
        torch.version.cuda,
        torch.backends.cudnn.version(),
-        torch.cuda.get_device_name(0),
+        device_name,
    ]
    gpu_configs = list(map(str, gpu_configs))
    temp = [
-        "Number of GPUs on current device : ",
-        "CUDA Version : ",
-        "Cudnn Version : ",
-        "Device Name : ",
+        "GPU_Count: ",
+        "Torch_Version: ",
+        "ROCM_Version: ",
+        "CUDA_Version: ",
+        "Cudnn_Version: ",
+        "Device_Name: ",
    ]

    os.makedirs(folder_name, exist_ok=True)
@@ -197,14 +234,14 @@ def inference(precision="float"):
        f.writelines(s + "\n" for s in gpu_configs)

    for precision in precisions:
-        train_result = train(precision)
+        train_result = train(precision, gpu_index)
        train_result_df = pd.DataFrame(train_result)
-        path = f"{folder_name}/{device_name}_{precision}_model_train_benchmark.csv"
+        path = f"{folder_name}/{device_file_name}_{precision}_model_train_benchmark.csv"
        train_result_df.to_csv(path, index=False)

-        inference_result = inference(precision)
+        inference_result = inference(precision, gpu_index)
        inference_result_df = pd.DataFrame(inference_result)
-        path = f"{folder_name}/{device_name}_{precision}_model_inference_benchmark.csv"
+        path = f"{folder_name}/{device_file_name}_{precision}_model_inference_benchmark.csv"
        inference_result_df.to_csv(path, index=False)
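        # Worked example with a hypothetical device name: with the defaults, one GPU and a card
        # reported as "NVIDIA GeForce RTX 3090", device_file_name becomes
        # "NVIDIA_GeForce_RTX_3090_" and the training CSV is written to
        #   <folder>/1X/NVIDIA_GeForce_RTX_3090__float_model_train_benchmark.csv
        # (double underscore because device_file_name already ends in "_").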
    now = datetime.datetime.now()