Skip to content

Commit 501222d

Browse files
committed
update pytest
1 parent b19b606 commit 501222d

File tree

1 file changed

+30
-2
lines changed

1 file changed

+30
-2
lines changed

tests/test_zoo.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from compressai.zoo import (
4141
bmshj2018_factorized,
42+
bmshj2018_factorized_relu,
4243
bmshj2018_hyperprior,
4344
cheng2020_anchor,
4445
cheng2020_attn,
@@ -98,6 +99,33 @@ def test_pretrained(self, metric):
9899
assert net.state_dict()["g_a.6.weight"].size(0) == 320
99100

100101

102+
class TestBmshj2018FactorizedReLU:
103+
def test_params(self):
104+
for i in range(1, 6):
105+
net = bmshj2018_factorized_relu(i, metric="mse")
106+
assert isinstance(net, FactorizedPrior)
107+
assert net.state_dict()["g_a.0.weight"].size(0) == 128
108+
assert net.state_dict()["g_a.6.weight"].size(0) == 192
109+
110+
for i in range(6, 9):
111+
net = bmshj2018_factorized_relu(i, metric="mse")
112+
assert isinstance(net, FactorizedPrior)
113+
assert net.state_dict()["g_a.0.weight"].size(0) == 192
114+
115+
def test_invalid_params(self):
116+
with pytest.raises(ValueError):
117+
bmshj2018_factorized_relu(-1)
118+
119+
with pytest.raises(ValueError):
120+
bmshj2018_factorized_relu(10)
121+
122+
with pytest.raises(ValueError):
123+
bmshj2018_factorized_relu(10, metric="ssim")
124+
125+
with pytest.raises(ValueError):
126+
bmshj2018_factorized_relu(1, metric="ssim")
127+
128+
101129
class TestBmshj2018Hyperprior:
102130
def test_params(self):
103131
for i in range(1, 6):
@@ -131,12 +159,12 @@ def test_invalid_params(self):
131159
def test_pretrained(self, metric):
132160
# test we can load the correct models from the urls
133161
for i in range(1, 6):
134-
net = bmshj2018_factorized(i, metric=metric, pretrained=True)
162+
net = bmshj2018_hyperprior(i, metric=metric, pretrained=True)
135163
assert net.state_dict()["g_a.0.weight"].size(0) == 128
136164
assert net.state_dict()["g_a.6.weight"].size(0) == 192
137165

138166
for i in range(6, 9):
139-
net = bmshj2018_factorized(i, metric=metric, pretrained=True)
167+
net = bmshj2018_hyperprior(i, metric=metric, pretrained=True)
140168
assert net.state_dict()["g_a.0.weight"].size(0) == 192
141169
assert net.state_dict()["g_a.6.weight"].size(0) == 320
142170

0 commit comments

Comments
 (0)