@@ -2,6 +2,7 @@
 import copy
 import itertools
 import logging
+from collections import defaultdict
 from enum import Enum
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
 import torch
@@ -138,7 +139,7 @@ def get_default_optimizer_params(
     bias_lr_factor: Optional[float] = 1.0,
     weight_decay_bias: Optional[float] = None,
     overrides: Optional[Dict[str, Dict[str, float]]] = None,
-):
+) -> List[Dict[str, Any]]:
     """
     Get default param list for optimizer, with support for a few types of
     overrides. If no overrides needed, this is equivalent to `model.parameters()`.
@@ -214,7 +215,39 @@ def get_default_optimizer_params(
                 hyperparams["weight_decay"] = weight_decay_norm
             hyperparams.update(overrides.get(module_param_name, {}))
             params.append({"params": [value], **hyperparams})
-    return params
+    return reduce_param_groups(params)
+
+
+def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    # Transform parameter groups into per-parameter structure.
+    # Later items in `params` can overwrite parameters set in previous items.
+    ret = defaultdict(dict)
+    for item in params:
+        assert "params" in item
+        cur_params = {x: y for x, y in item.items() if x != "params"}
+        for param in item["params"]:
+            ret[param].update({"params": [param], **cur_params})
+    return list(ret.values())
+
+
+def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    # Reorganize the parameter groups and merge duplicated groups.
+    # The number of parameter groups needs to be as small as possible in order
+    # to efficiently use the PyTorch multi-tensor optimizer. Therefore instead
+    # of using a parameter_group per single parameter, we reorganize the
+    # parameter groups and merge duplicated groups. This approach speeds
+    # up multi-tensor optimizer significantly.
+    params = _expand_param_groups(params)
+    groups = defaultdict(list)  # re-group all parameter groups by their hyperparams
+    for item in params:
+        cur_params = tuple((x, y) for x, y in item.items() if x != "params")
+        groups[cur_params].extend(item["params"])
+    ret = []
+    for param_keys, param_values in groups.items():
+        cur = {kv[0]: kv[1] for kv in param_keys}
+        cur["params"] = param_values
+        ret.append(cur)
+    return ret
 
 
 def build_lr_scheduler(
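Note (not part of the commit): below is a minimal sketch of how the new reduce_param_groups helper merges per-parameter groups that share identical hyperparameters. It assumes the helper can be imported from the module this diff edits; the import path detectron2.solver.build, the tensor names, and the hyperparameter values are assumptions made for illustration only.

import torch

# Assumed import path for the helper added in this diff; adjust to the actual module.
from detectron2.solver.build import reduce_param_groups

# Hypothetical parameters standing in for a model's weights and a bias.
w1 = torch.nn.Parameter(torch.zeros(3, 3))
w2 = torch.nn.Parameter(torch.zeros(3, 3))
b = torch.nn.Parameter(torch.zeros(3))

# One group per parameter, roughly as get_default_optimizer_params builds them.
groups = [
    {"params": [w1], "lr": 0.1, "weight_decay": 1e-4},
    {"params": [w2], "lr": 0.1, "weight_decay": 1e-4},  # same hyperparams as w1
    {"params": [b], "lr": 0.2, "weight_decay": 0.0},    # bias-specific settings
]

merged = reduce_param_groups(groups)
# _expand_param_groups first gives every parameter its own group; groups whose
# non-"params" entries are identical are then merged: [w1, w2] end up in one
# group, b keeps its own, so the optimizer sees 2 groups instead of 3.
print(len(merged))  # 2

Keeping the group count small lets PyTorch's multi-tensor optimizer kernels operate on larger parameter lists per step, which is the motivation stated in the comment inside reduce_param_groups.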