Skip to content

Commit 80d5093

Browse files
authored
Add model test with explicit call() (#151)
* Test a model with an explicit call to call(). * Add TODO. * Make it look like the example from TF docs. * More tests.
1 parent 7461b27 commit 80d5093

File tree

8 files changed

+172
-1
lines changed

8 files changed

+172
-1
lines changed

edu.cuny.hunter.hybridize.tests/resources/HybridizeFunction/testModel/in/A.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, **kwargs):
1616
self.dropout = tf.keras.layers.Dropout(0.2)
1717
self.dense_2 = tf.keras.layers.Dense(10)
1818

19-
def call(self, x):
19+
def __call__(self, x):
2020
x = self.flatten(x)
2121

2222
for layer in self.my_layers:
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import tensorflow as tf
2+
3+
# Create an override model to classify pictures
4+
5+
class SequentialModel(tf.keras.Model):
6+
def __init__(self, **kwargs):
7+
super(SequentialModel, self).__init__(**kwargs)
8+
9+
self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
10+
11+
# Add a lot of small layers
12+
num_layers = 100
13+
self.my_layers = [tf.keras.layers.Dense(64, activation="relu")
14+
for n in range(num_layers)]
15+
16+
self.dropout = tf.keras.layers.Dropout(0.2)
17+
self.dense_2 = tf.keras.layers.Dense(10)
18+
19+
def call(self, x):
20+
x = self.flatten(x)
21+
22+
for layer in self.my_layers:
23+
x = layer(x)
24+
25+
x = self.dropout(x)
26+
x = self.dense_2(x)
27+
28+
return x
29+
30+
if __name__ == '__main__':
31+
input_data = tf.random.uniform([20, 28, 28])
32+
print("Input:")
33+
print(type(input_data))
34+
print(input_data)
35+
36+
model = SequentialModel()
37+
result = model(input_data)
38+
39+
print("Output:")
40+
print(type(input_data))
41+
print(result)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tensorflow==2.9.3
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import tensorflow as tf
2+
3+
# Create an override model to classify pictures
4+
5+
class SequentialModel(tf.keras.Model):
6+
def __init__(self, **kwargs):
7+
super(SequentialModel, self).__init__(**kwargs)
8+
9+
self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
10+
11+
# Add a lot of small layers
12+
num_layers = 100
13+
self.my_layers = [tf.keras.layers.Dense(64, activation="relu")
14+
for n in range(num_layers)]
15+
16+
self.dropout = tf.keras.layers.Dropout(0.2)
17+
self.dense_2 = tf.keras.layers.Dense(10)
18+
19+
def call(self, x):
20+
x = self.flatten(x)
21+
22+
for layer in self.my_layers:
23+
x = layer(x)
24+
25+
x = self.dropout(x)
26+
x = self.dense_2(x)
27+
28+
return x
29+
30+
if __name__ == '__main__':
31+
input_data = tf.random.uniform([20, 28, 28])
32+
print("Input:")
33+
print(type(input_data))
34+
print(input_data)
35+
36+
model = SequentialModel()
37+
result = model.call(input_data)
38+
39+
print("Output:")
40+
print(type(input_data))
41+
print(result)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tensorflow==2.9.3
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import tensorflow as tf
2+
3+
# Create an override model to classify pictures
4+
5+
class SequentialModel(tf.keras.Model):
6+
def __init__(self, **kwargs):
7+
super(SequentialModel, self).__init__(**kwargs)
8+
9+
self.flatten = tf.keras.layers.Flatten(input_shape=(28, 28))
10+
11+
# Add a lot of small layers
12+
num_layers = 100
13+
self.my_layers = [tf.keras.layers.Dense(64, activation="relu")
14+
for n in range(num_layers)]
15+
16+
self.dropout = tf.keras.layers.Dropout(0.2)
17+
self.dense_2 = tf.keras.layers.Dense(10)
18+
19+
def __call__(self, x):
20+
x = self.flatten(x)
21+
22+
for layer in self.my_layers:
23+
x = layer(x)
24+
25+
x = self.dropout(x)
26+
x = self.dense_2(x)
27+
28+
return x
29+
30+
if __name__ == '__main__':
31+
input_data = tf.random.uniform([20, 28, 28])
32+
print("Input:")
33+
print(type(input_data))
34+
print(input_data)
35+
36+
model = SequentialModel()
37+
result = model.__call__(input_data)
38+
39+
print("Output:")
40+
print(type(input_data))
41+
print(result)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tensorflow==2.9.3

edu.cuny.hunter.hybridize.tests/test cases/edu/cuny/hunter/hybridize/tests/HybridizeFunctionRefactoringTest.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,4 +1608,49 @@ public void testModel() throws Exception {
16081608
// no hybrids.
16091609
assertTrue(functions.stream().map(Function::isHybrid).allMatch(b -> b == false));
16101610
}
1611+
1612+
/**
1613+
* Test a model. No tf.function in this one. Use call instead of __call__. Ariadne doesn't support __call__.
1614+
* See https://github.com/wala/ML/issues/24.
1615+
*/
1616+
@Test
1617+
public void testModel2() throws Exception {
1618+
Set<Function> functions = this.getFunctions();
1619+
assertNotNull(functions);
1620+
1621+
LOG.info("Found functions: " + functions.size());
1622+
1623+
// no hybrids.
1624+
assertTrue(functions.stream().map(Function::isHybrid).allMatch(b -> b == false));
1625+
}
1626+
1627+
/**
1628+
* Test a model. No tf.function in this one. Explicit call method.
1629+
*/
1630+
@Test
1631+
public void testModel3() throws Exception {
1632+
Set<Function> functions = this.getFunctions();
1633+
assertNotNull(functions);
1634+
1635+
LOG.info("Found functions: " + functions.size());
1636+
1637+
// no hybrids.
1638+
assertTrue(functions.stream().map(Function::isHybrid).allMatch(b -> b == false));
1639+
}
1640+
1641+
/**
1642+
* Test a model. No tf.function in this one. Explicit call method.
1643+
*/
1644+
@Test
1645+
public void testModel4() throws Exception {
1646+
Set<Function> functions = this.getFunctions();
1647+
assertNotNull(functions);
1648+
1649+
LOG.info("Found functions: " + functions.size());
1650+
1651+
// no hybrids.
1652+
assertTrue(functions.stream().map(Function::isHybrid).allMatch(b -> b == false));
1653+
}
1654+
1655+
// TODO: Test models that have tf.functions.
16111656
}

0 commit comments

Comments
 (0)