Skip to content

Commit e98998d

Browse files
ppwwyyxx and facebook-github-bot
authored and committed
Add reduce_param_groups to D2
Summary: This utility function was added in D30272112 and is useful to all D2 (11528ce) users as well. Differential Revision: D31833523. fbshipit-source-id: 0adfc612adb8b448fa7f3dbec1b1278c309554c5
1 parent b101248 commit e98998d

File tree

2 files changed

+101
-2
lines changed

2 files changed

+101
-2
lines changed

detectron2/solver/build.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import copy
33
import itertools
44
import logging
5+
from collections import defaultdict
56
from enum import Enum
67
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
78
import torch
@@ -138,7 +139,7 @@ def get_default_optimizer_params(
138139
bias_lr_factor: Optional[float] = 1.0,
139140
weight_decay_bias: Optional[float] = None,
140141
overrides: Optional[Dict[str, Dict[str, float]]] = None,
141-
):
142+
) -> List[Dict[str, Any]]:
142143
"""
143144
Get default param list for optimizer, with support for a few types of
144145
overrides. If no overrides needed, this is equivalent to `model.parameters()`.
@@ -214,7 +215,39 @@ def get_default_optimizer_params(
214215
hyperparams["weight_decay"] = weight_decay_norm
215216
hyperparams.update(overrides.get(module_param_name, {}))
216217
params.append({"params": [value], **hyperparams})
217-
return params
218+
return reduce_param_groups(params)
219+
220+
221+
def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
222+
# Transform parameter groups into per-parameter structure.
223+
# Later items in `params` can overwrite parameters set in previous items.
224+
ret = defaultdict(dict)
225+
for item in params:
226+
assert "params" in item
227+
cur_params = {x: y for x, y in item.items() if x != "params"}
228+
for param in item["params"]:
229+
ret[param].update({"params": [param], **cur_params})
230+
return list(ret.values())
231+
232+
233+
def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Reorganize the parameter groups and merge duplicated groups.

    The number of parameter groups needs to be as small as possible in order
    to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
    of using a parameter group per single parameter, the groups are first
    expanded per parameter and then all parameters that share exactly the
    same hyperparameters are merged into one group.

    Args:
        params: parameter groups; each must contain a "params" key. The
            hyperparameter values must be hashable because they are used to
            build the grouping key.

    Returns:
        The merged groups, in order of first appearance. Each group holds its
        hyperparameters followed by a "params" list of the merged parameters.
    """
    params = _expand_param_groups(params)
    # Maps a hyperparameter signature to the merged group being built for it.
    groups: Dict[Any, Dict[str, Any]] = {}
    for item in params:
        hyperparams = {k: v for k, v in item.items() if k != "params"}
        # The signature must not depend on key order: per-parameter dicts that
        # were built through overrides can hold the same settings in a
        # different insertion order and still have to land in the same group.
        signature = frozenset(hyperparams.items())
        group = groups.get(signature)
        if group is None:
            group = groups[signature] = {**hyperparams, "params": []}
        group["params"].extend(item["params"])
    return list(groups.values())
218251

219252

220253
def build_lr_scheduler(

tests/test_solver.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import unittest
2+
3+
from detectron2.solver.build import _expand_param_groups, reduce_param_groups
4+
5+
6+
class TestOptimizer(unittest.TestCase):
7+
def testExpandParamsGroups(self):
8+
params = [
9+
{
10+
"params": ["p1", "p2", "p3", "p4"],
11+
"lr": 1.0,
12+
"weight_decay": 3.0,
13+
},
14+
{
15+
"params": ["p2", "p3", "p5"],
16+
"lr": 2.0,
17+
"momentum": 2.0,
18+
},
19+
{
20+
"params": ["p1"],
21+
"weight_decay": 4.0,
22+
},
23+
]
24+
out = _expand_param_groups(params)
25+
gt = [
26+
dict(params=["p1"], lr=1.0, weight_decay=4.0), # noqa
27+
dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa
28+
dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa
29+
dict(params=["p4"], lr=1.0, weight_decay=3.0), # noqa
30+
dict(params=["p5"], lr=2.0, momentum=2.0), # noqa
31+
]
32+
self.assertEqual(out, gt)
33+
34+
def testReduceParamGroups(self):
35+
params = [
36+
dict(params=["p1"], lr=1.0, weight_decay=4.0), # noqa
37+
dict(params=["p2", "p6"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa
38+
dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0), # noqa
39+
dict(params=["p4"], lr=1.0, weight_decay=3.0), # noqa
40+
dict(params=["p5"], lr=2.0, momentum=2.0), # noqa
41+
]
42+
gt_groups = [
43+
{
44+
"lr": 1.0,
45+
"weight_decay": 4.0,
46+
"params": ["p1"],
47+
},
48+
{
49+
"lr": 2.0,
50+
"weight_decay": 3.0,
51+
"momentum": 2.0,
52+
"params": ["p2", "p6", "p3"],
53+
},
54+
{
55+
"lr": 1.0,
56+
"weight_decay": 3.0,
57+
"params": ["p4"],
58+
},
59+
{
60+
"lr": 2.0,
61+
"momentum": 2.0,
62+
"params": ["p5"],
63+
},
64+
]
65+
out = reduce_param_groups(params)
66+
self.assertEqual(out, gt_groups)

0 commit comments

Comments
 (0)