|
33 | 33 | Cheng2020Anchor, |
34 | 34 | Cheng2020Attention, |
35 | 35 | FactorizedPrior, |
| 36 | + FactorizedPriorReLU, |
36 | 37 | JointAutoregressiveHierarchicalPriors, |
37 | 38 | MeanScaleHyperprior, |
38 | 39 | ScaleHyperprior, |
|
42 | 43 |
|
43 | 44 | __all__ = [ |
44 | 45 | "bmshj2018_factorized", |
| 46 | + "bmshj2018_factorized_relu", |
45 | 47 | "bmshj2018_hyperprior", |
46 | 48 | "mbt2018", |
47 | 49 | "mbt2018_mean", |
|
51 | 53 |
|
52 | 54 | model_architectures = { |
53 | 55 | "bmshj2018-factorized": FactorizedPrior, |
| 56 | + "bmshj2018_factorized_relu": FactorizedPriorReLU, |
54 | 57 | "bmshj2018-hyperprior": ScaleHyperprior, |
55 | 58 | "mbt2018-mean": MeanScaleHyperprior, |
56 | 59 | "mbt2018": JointAutoregressiveHierarchicalPriors, |
|
197 | 200 | 7: (192, 320), |
198 | 201 | 8: (192, 320), |
199 | 202 | }, |
| 203 | + "bmshj2018-factorized-relu": { |
| 204 | + 1: (128, 192), |
| 205 | + 2: (128, 192), |
| 206 | + 3: (128, 192), |
| 207 | + 4: (128, 192), |
| 208 | + 5: (128, 192), |
| 209 | + 6: (192, 320), |
| 210 | + 7: (192, 320), |
| 211 | + 8: (192, 320), |
| 212 | + }, |
200 | 213 | "bmshj2018-hyperprior": { |
201 | 214 | 1: (128, 192), |
202 | 215 | 2: (128, 192), |
@@ -298,6 +311,31 @@ def bmshj2018_factorized( |
298 | 311 | ) |
299 | 312 |
|
300 | 313 |
|
| 314 | +def bmshj2018_factorized_relu( |
| 315 | + quality, metric="mse", pretrained=False, progress=True, **kwargs |
| 316 | +): |
| 317 | + r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang, |
| 318 | + N. Johnston: `"Variational Image Compression with a Scale Hyperprior" |
| 319 | + <https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations |
| 320 | + (ICLR), 2018. |
| 321 | + GDN activations are replaced by ReLU |
| 322 | + Args: |
| 323 | + quality (int): Quality levels (1: lowest, highest: 8) |
| 324 | + metric (str): Optimized metric, choose from ('mse', 'ms-ssim') |
| 325 | + pretrained (bool): If True, returns a pre-trained model |
| 326 | + progress (bool): If True, displays a progress bar of the download to stderr |
| 327 | + """ |
| 328 | + if metric not in ("mse", "ms-ssim"): |
| 329 | + raise ValueError(f'Invalid metric "{metric}"') |
| 330 | + |
| 331 | + if quality < 1 or quality > 8: |
| 332 | + raise ValueError(f'Invalid quality "{quality}", should be between (1, 8)') |
| 333 | + |
| 334 | + return _load_model( |
| 335 | + "bmshj2018-factorized", metric, quality, pretrained, progress, **kwargs |
| 336 | + ) |
| 337 | + |
| 338 | + |
301 | 339 | def bmshj2018_hyperprior( |
302 | 340 | quality, metric="mse", pretrained=False, progress=True, **kwargs |
303 | 341 | ): |
|
0 commit comments