
Commit 4fca54f

Wong4j authored and nv-kkudrynski committed
[RN50/Paddle] Fix 2308 compatibility issue
1 parent 0131db6 commit 4fca54f

File tree

4 files changed: +144 −98 lines

PaddlePaddle/Classification/RN50v1.5/Dockerfile

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:23.06-py3
+ARG FROM_IMAGE_NAME=nvcr.io/nvidia/paddlepaddle:23.09-py3
 FROM ${FROM_IMAGE_NAME}
 
 ADD requirements.txt /workspace/

PaddlePaddle/Classification/RN50v1.5/README.md

Lines changed: 2 additions & 2 deletions

@@ -233,7 +233,7 @@ The following section lists the requirements you need to meet to start training
 This repository contains a Dockerfile that extends the CUDA NGC container and encapsulates some dependencies. Aside from these dependencies, ensure you have the following components:
 
 * [NVIDIA Docker](https://github.com/NVIDIA/nvidia-docker)
-* [PaddlePaddle 22.05-py3 NGC container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/paddlepaddle) or newer
+* [PaddlePaddle 23.09-py3 NGC container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/paddlepaddle) or newer
 * Supported GPUs:
   * [NVIDIA Ampere architecture](https://www.nvidia.com/en-us/data-center/nvidia-ampere-gpu-architecture/)
 
@@ -289,7 +289,7 @@ docker build . -t nvidia_resnet50
 
 ### 4. Start an interactive session in the NGC container to run training/inference.
 ```bash
-nvidia-docker run --rm -it -v <path to imagenet>:/imagenet --ipc=host nvidia_resnet50
+nvidia-docker run --rm -it -v <path to imagenet>:/imagenet --ipc=host -e FLAGS_apply_pass_to_program=1 nvidia_resnet50
 ```
 
 ### 5. Start training
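In the updated command, docker's `-e` switch exports `FLAGS_apply_pass_to_program=1` into the container environment so Paddle picks the flag up when it initializes. As a hedged alternative (an illustration of mine, not part of this commit), the same flag can be set from Python before `paddle` is imported:

```python
# Hypothetical alternative to passing -e on the docker command line: export the
# flag from the training script itself, before paddle is imported, so that
# Paddle sees it during initialization.
import os

os.environ.setdefault("FLAGS_apply_pass_to_program", "1")

import paddle  # noqa: E402  (imported only after the environment variable is set)

print(paddle.__version__)
```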

PaddlePaddle/Classification/RN50v1.5/program.py

Lines changed: 83 additions & 66 deletions
@@ -12,26 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import time
 import logging
-
+import time
 from profile import Profiler
+
+import dllogger
+import models
 import numpy as np
-from optimizer import build_optimizer
 from lr_scheduler import build_lr_scheduler
+from optimizer import build_optimizer
 from utils.misc import AverageMeter
 from utils.mode import Mode, RunScope
 from utils.utility import get_num_trainers
-import models
-
-import dllogger
 
 import paddle
 import paddle.nn.functional as F
 from paddle.distributed import fleet
 from paddle.distributed.fleet import DistributedStrategy
-from paddle.static import sparsity
 from paddle.distributed.fleet.meta_optimizers.common import CollectiveHelper
+from paddle.incubate import asp as sparsity
 
 
 def create_feeds(image_shape):
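The functionally significant line in the import hunk above is the sparsity import: the ASP utilities move from `paddle.static.sparsity` to `paddle.incubate.asp`. For a script that still has to run inside older NGC containers, a minimal compatibility shim (my suggestion, not part of this commit) could look like this:

```python
# Hypothetical shim: prefer the new ASP location (Paddle >= 2.5 / NGC 23.09)
# and fall back to the legacy module on older releases where it still exists.
try:
    from paddle.incubate import asp as sparsity
except ImportError:
    from paddle.static import sparsity
```

The rest of the hunk is an import reordering with no behavioral effect.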
@@ -45,11 +44,13 @@ def create_feeds(image_shape):
         key (string): Name of variable to feed.
         Value (tuple): paddle.static.data.
     """
-    feeds = dict()
+    feeds = {}
     feeds['data'] = paddle.static.data(
-        name="data", shape=[None] + image_shape, dtype="float32")
+        name="data", shape=[None] + image_shape, dtype="float32"
+    )
     feeds['label'] = paddle.static.data(
-        name="label", shape=[None, 1], dtype="int64")
+        name="label", shape=[None, 1], dtype="int64"
+    )
 
     return feeds
 
@@ -70,16 +71,15 @@ def create_fetchs(out, feeds, class_num, label_smoothing=0, mode=Mode.TRAIN):
         key (string): Name of variable to fetch.
         Value (tuple): (variable, AverageMeter).
     """
-    fetchs = dict()
+    fetchs = {}
     target = paddle.reshape(feeds['label'], [-1, 1])
 
     if mode == Mode.TRAIN:
         if label_smoothing == 0:
             loss = F.cross_entropy(out, target)
         else:
             label_one_hot = F.one_hot(target, class_num)
-            soft_target = F.label_smooth(
-                label_one_hot, epsilon=label_smoothing)
+            soft_target = F.label_smooth(label_one_hot, epsilon=label_smoothing)
             soft_target = paddle.reshape(soft_target, shape=[-1, class_num])
             log_softmax = -F.log_softmax(out, axis=-1)
             loss = paddle.sum(log_softmax * soft_target, axis=-1)
@@ -94,19 +94,23 @@ def create_fetchs(out, feeds, class_num, label_smoothing=0, mode=Mode.TRAIN):
 
     acc_top1 = paddle.metric.accuracy(input=out, label=target, k=1)
     acc_top5 = paddle.metric.accuracy(input=out, label=target, k=5)
-    metric_dict = dict()
+    metric_dict = {}
     metric_dict["top1"] = acc_top1
     metric_dict["top5"] = acc_top5
 
     for key in metric_dict:
         if mode != Mode.TRAIN and paddle.distributed.get_world_size() > 1:
             paddle.distributed.all_reduce(
-                metric_dict[key], op=paddle.distributed.ReduceOp.SUM)
-            metric_dict[key] = metric_dict[
-                key] / paddle.distributed.get_world_size()
+                metric_dict[key], op=paddle.distributed.ReduceOp.SUM
+            )
+            metric_dict[key] = (
+                metric_dict[key] / paddle.distributed.get_world_size()
+            )
 
-        fetchs[key] = (metric_dict[key], AverageMeter(
-            key, '7.4f', need_avg=True))
+        fetchs[key] = (
+            metric_dict[key],
+            AverageMeter(key, '7.4f', need_avg=True),
+        )
 
     return fetchs
 
@@ -127,13 +131,16 @@ def create_strategy(args, is_train=True):
     exec_strategy = paddle.static.ExecutionStrategy()
 
     exec_strategy.num_threads = 1
-    exec_strategy.num_iteration_per_drop_scope = (10000 if args.amp and
-                                                  args.use_pure_fp16 else 10)
-
-    paddle.set_flags({
-        'FLAGS_cudnn_exhaustive_search': True,
-        'FLAGS_conv_workspace_size_limit': 4096
-    })
+    exec_strategy.num_iteration_per_drop_scope = (
+        10000 if args.amp and args.use_pure_fp16 else 10
+    )
+
+    paddle.set_flags(
+        {
+            'FLAGS_cudnn_exhaustive_search': True,
+            'FLAGS_conv_workspace_size_limit': 4096,
+        }
+    )
 
     if not is_train:
         build_strategy.fix_op_run_order = True
@@ -177,7 +184,7 @@ def dist_optimizer(args, optimizer):
         dist_strategy.amp_configs = {
             "init_loss_scaling": args.scale_loss,
             "use_dynamic_loss_scaling": args.use_dynamic_loss_scaling,
-            "use_pure_fp16": args.use_pure_fp16
+            "use_pure_fp16": args.use_pure_fp16,
         }
 
     dist_strategy.asp = args.asp
@@ -223,14 +230,16 @@ def build(args, main_prog, startup_prog, step_each_epoch, is_train=True):
             input_image_channel=input_image_channel,
             data_format=data_format,
             use_pure_fp16=use_pure_fp16,
-            bn_weight_decay=bn_weight_decay)
+            bn_weight_decay=bn_weight_decay,
+        )
         out = model(feeds["data"])
 
         fetchs = create_fetchs(
-            out, feeds, class_num, args.label_smoothing, mode=mode)
+            out, feeds, class_num, args.label_smoothing, mode=mode
+        )
 
         if args.asp:
-            sparsity.set_excluded_layers(main_prog, [model.fc.weight.name])
+            sparsity.set_excluded_layers(main_program=main_prog, param_names=[model.fc.weight.name])
 
         lr_scheduler = None
         optimizer = None
@@ -244,10 +253,13 @@ def build(args, main_prog, startup_prog, step_each_epoch, is_train=True):
         # This is a workaround to "Communicator of ring id 0 has not been initialized.".
         # Since Paddle's design, the initialization would be done inside train program,
         # eval_only need to manually call initialization.
-        if args.run_scope == RunScope.EVAL_ONLY and \
-                paddle.distributed.get_world_size() > 1:
+        if (
+            args.run_scope == RunScope.EVAL_ONLY
+            and paddle.distributed.get_world_size() > 1
+        ):
             collective_helper = CollectiveHelper(
-                role_maker=fleet.PaddleCloudRoleMaker(is_collective=True))
+                role_maker=fleet.PaddleCloudRoleMaker(is_collective=True)
+            )
             collective_helper.update_startup_program(startup_prog)
 
     return fetchs, lr_scheduler, feeds, optimizer
@@ -270,22 +282,22 @@ def compile_prog(args, program, loss_name=None, is_train=True):
     build_strategy, exec_strategy = create_strategy(args, is_train)
 
     compiled_program = paddle.static.CompiledProgram(
-        program).with_data_parallel(
-            loss_name=loss_name,
-            build_strategy=build_strategy,
-            exec_strategy=exec_strategy)
+        program, build_strategy=build_strategy
+    )
 
     return compiled_program
 
 
-def run(args,
-        dataloader,
-        exe,
-        program,
-        fetchs,
-        epoch,
-        mode=Mode.TRAIN,
-        lr_scheduler=None):
+def run(
+    args,
+    dataloader,
+    exe,
+    program,
+    fetchs,
+    epoch,
+    mode=Mode.TRAIN,
+    lr_scheduler=None,
+):
     """
     Execute program.
 
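The hunk above drops the `with_data_parallel()` chain, which no longer exists in the 23.09 release; `paddle.static.CompiledProgram` now receives the build strategy directly. A self-contained sketch of that pattern on a toy program (my own example under that assumption, not code from this repository):

```python
# Minimal sketch: compile a static program with a BuildStrategy and run it.
import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 8], dtype="float32")
    y = paddle.static.nn.fc(x, size=4)
    loss = paddle.mean(y)

build_strategy = paddle.static.BuildStrategy()
compiled = paddle.static.CompiledProgram(main_prog, build_strategy=build_strategy)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
out, = exe.run(
    compiled,
    feed={"x": np.random.rand(2, 8).astype("float32")},
    fetch_list=[loss],
)
print(out)
```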
@@ -312,11 +324,11 @@ def run(args,
         if fetchs[k][1] is not None:
             metric_dict[k] = fetchs[k][1]
 
-    metric_dict["batch_time"] = AverageMeter(
-        'batch_time', '.5f', postfix=" s,")
+    metric_dict["batch_time"] = AverageMeter('batch_time', '.5f', postfix=" s,")
     metric_dict["data_time"] = AverageMeter('data_time', '.5f', postfix=" s,")
     metric_dict["compute_time"] = AverageMeter(
-        'compute_time', '.5f', postfix=" s,")
+        'compute_time', '.5f', postfix=" s,"
+    )
 
     for m in metric_dict.values():
         m.reset()
@@ -328,8 +340,7 @@ def run(args,
     batch_size = None
     latency = []
 
-    total_benchmark_steps = \
-        args.benchmark_steps + args.benchmark_warmup_steps
+    total_benchmark_steps = args.benchmark_steps + args.benchmark_warmup_steps
 
     dataloader.reset()
     while True:
@@ -361,11 +372,12 @@ def run(args,
         batch_size = batch[0]["data"].shape()[0]
         feed_dict = batch[0]
 
-        with profiler.profile_tag(idx, "Training"
-                                  if mode == Mode.TRAIN else "Evaluation"):
-            results = exe.run(program=program,
-                              feed=feed_dict,
-                              fetch_list=fetch_list)
+        with profiler.profile_tag(
+            idx, "Training" if mode == Mode.TRAIN else "Evaluation"
+        ):
+            results = exe.run(
+                program=program, feed=feed_dict, fetch_list=fetch_list
+            )
 
         for name, m in zip(fetchs.keys(), results):
             if name in metric_dict:
@@ -382,15 +394,16 @@ def run(args,
             tic = time.perf_counter()
 
         if idx % args.print_interval == 0:
-            log_msg = dict()
+            log_msg = {}
             log_msg['loss'] = metric_dict['loss'].val.item()
             log_msg['top1'] = metric_dict['top1'].val.item()
             log_msg['top5'] = metric_dict['top5'].val.item()
             log_msg['data_time'] = metric_dict['data_time'].val
             log_msg['compute_time'] = metric_dict['compute_time'].val
             log_msg['batch_time'] = metric_dict['batch_time'].val
-            log_msg['ips'] = \
+            log_msg['ips'] = (
                 batch_size * num_trainers / metric_dict['batch_time'].val
+            )
             if mode == Mode.TRAIN:
                 log_msg['lr'] = metric_dict['lr'].val
             log_info((epoch, idx), log_msg, mode)
@@ -406,10 +419,10 @@ def run(args,
                 logging.info("Begin benchmark at step %d", idx + 1)
 
             if idx == total_benchmark_steps:
-                benchmark_data = dict()
-                benchmark_data[
-                    'ips'] = batch_size * num_trainers / metric_dict[
-                        'batch_time'].avg
+                benchmark_data = {}
+                benchmark_data['ips'] = (
+                    batch_size * num_trainers / metric_dict['batch_time'].avg
+                )
                 if mode == mode.EVAL:
                     latency = np.array(latency) * 1000
                     quantile = np.quantile(latency, [0.9, 0.95, 0.99])
@@ -422,15 +435,19 @@ def run(args,
                 logging.info("End benchmark at epoch step %d", idx)
                 return benchmark_data
 
-    epoch_data = dict()
+    epoch_data = {}
     epoch_data['loss'] = metric_dict['loss'].avg.item()
     epoch_data['epoch_time'] = metric_dict['batch_time'].total
-    epoch_data['ips'] = batch_size * num_trainers * \
-        metric_dict["batch_time"].count / metric_dict["batch_time"].sum
+    epoch_data['ips'] = (
+        batch_size
+        * num_trainers
+        * metric_dict["batch_time"].count
+        / metric_dict["batch_time"].sum
+    )
     if mode == Mode.EVAL:
         epoch_data['top1'] = metric_dict['top1'].avg.item()
         epoch_data['top5'] = metric_dict['top5'].avg.item()
-    log_info((epoch, ), epoch_data, mode)
+    log_info((epoch,), epoch_data, mode)
 
     return epoch_data
 
@@ -445,7 +462,7 @@ def log_info(step, metrics, mode):
         mode(utils.Mode): Train or eval mode.
     """
     prefix = 'train' if mode == Mode.TRAIN else 'val'
-    dllogger_iter_data = dict()
+    dllogger_iter_data = {}
     for key in metrics:
         dllogger_iter_data[f"{prefix}.{key}"] = metrics[key]
     dllogger.log(step=step, data=dllogger_iter_data)
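For completeness, `log_info()` only emits records if dllogger has been initialized with at least one backend elsewhere in the scripts. A minimal, hedged sketch of such an initialization (file name and step values are placeholders of mine, not taken from this repository):

```python
# Assumed dllogger setup: one human-readable backend and one JSON file backend.
import dllogger
from dllogger import JSONStreamBackend, StdOutBackend, Verbosity

dllogger.init(
    backends=[
        StdOutBackend(Verbosity.DEFAULT),
        JSONStreamBackend(Verbosity.VERBOSE, "train_log.json"),
    ]
)

dllogger.log(step=(0, 0), data={"train.loss": 6.9})
dllogger.flush()
```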
