Skip to content

Commit dc72b8b

Browse files
kcacademicpivovarit
authored andcommitted
Adding source code for the tutorial tracked under BAEL-2759 (eugenp#6533)
1 parent b3fc270 commit dc72b8b

File tree

8 files changed

+158
-0
lines changed

8 files changed

+158
-0
lines changed

pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -526,6 +526,9 @@
526526
<module>rxjava</module>
527527
<module>rxjava-2</module>
528528
<module>software-security/sql-injection-samples</module>
529+
530+
<module>tensorflow-java</module>
531+
529532
</modules>
530533

531534
</profile>
@@ -742,6 +745,8 @@
742745
<module>xml</module>
743746
<module>xmlunit-2</module>
744747
<module>xstream</module>
748+
749+
<module>tensorflow-java</module>
745750

746751
</modules>
747752

tensorflow-java/.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
/.settings
2+
/model
3+
/target
4+
.classpath
5+
.project
6+
.springBeans

tensorflow-java/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
## Relevant articles:
2+
3+
- [TensorFlow for Java](https://www.baeldung.com/xxxx)

tensorflow-java/pom.xml

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<project xmlns="http://maven.apache.org/POM/4.0.0"
3+
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
5+
<modelVersion>4.0.0</modelVersion>
6+
<groupId>com.baeldung</groupId>
7+
<artifactId>tensorflow-java</artifactId>
8+
<version>1.0-SNAPSHOT</version>
9+
<packaging>jar</packaging>
10+
<url>http://maven.apache.org</url>
11+
12+
<parent>
13+
<groupId>com.baeldung</groupId>
14+
<artifactId>parent-modules</artifactId>
15+
<version>1.0.0-SNAPSHOT</version>
16+
</parent>
17+
18+
<properties>
19+
<java.version>1.8</java.version>
20+
<tensorflow.version>1.12.0</tensorflow.version>
21+
<junit.jupiter.version>5.4.0</junit.jupiter.version>
22+
</properties>
23+
24+
<dependencies>
25+
<dependency>
26+
<groupId>org.tensorflow</groupId>
27+
<artifactId>tensorflow</artifactId>
28+
<version>${tensorflow.version}</version>
29+
</dependency>
30+
<dependency>
31+
<groupId>org.junit.jupiter</groupId>
32+
<artifactId>junit-jupiter-api</artifactId>
33+
<version>${junit.jupiter.version}</version>
34+
<scope>test</scope>
35+
</dependency>
36+
<dependency>
37+
<groupId>org.junit.jupiter</groupId>
38+
<artifactId>junit-jupiter-engine</artifactId>
39+
<version>${junit.jupiter.version}</version>
40+
<scope>test</scope>
41+
</dependency>
42+
</dependencies>
43+
44+
<build>
45+
<plugins>
46+
<plugin>
47+
<groupId>org.springframework.boot</groupId>
48+
<artifactId>spring-boot-maven-plugin</artifactId>
49+
</plugin>
50+
</plugins>
51+
</build>
52+
</project>
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package org.baeldung.tensorflow;
2+
3+
import org.tensorflow.DataType;
4+
import org.tensorflow.Graph;
5+
import org.tensorflow.Operation;
6+
import org.tensorflow.Session;
7+
import org.tensorflow.Tensor;
8+
9+
public class TensorflowGraph {
10+
11+
public static Graph createGraph() {
12+
Graph graph = new Graph();
13+
Operation a = graph.opBuilder("Const", "a").setAttr("dtype", DataType.fromClass(Double.class))
14+
.setAttr("value", Tensor.<Double>create(3.0, Double.class)).build();
15+
Operation b = graph.opBuilder("Const", "b").setAttr("dtype", DataType.fromClass(Double.class))
16+
.setAttr("value", Tensor.<Double>create(2.0, Double.class)).build();
17+
Operation x = graph.opBuilder("Placeholder", "x").setAttr("dtype", DataType.fromClass(Double.class)).build();
18+
Operation y = graph.opBuilder("Placeholder", "y").setAttr("dtype", DataType.fromClass(Double.class)).build();
19+
Operation ax = graph.opBuilder("Mul", "ax").addInput(a.output(0)).addInput(x.output(0)).build();
20+
Operation by = graph.opBuilder("Mul", "by").addInput(b.output(0)).addInput(y.output(0)).build();
21+
graph.opBuilder("Add", "z").addInput(ax.output(0)).addInput(by.output(0)).build();
22+
return graph;
23+
}
24+
25+
public static Object runGraph(Graph graph, Double x, Double y) {
26+
Object result;
27+
try (Session sess = new Session(graph)) {
28+
result = sess.runner().fetch("z").feed("x", Tensor.<Double>create(x, Double.class))
29+
.feed("y", Tensor.<Double>create(y, Double.class)).run().get(0).expect(Double.class)
30+
.doubleValue();
31+
}
32+
return result;
33+
}
34+
35+
public static void main(String[] args) {
36+
Graph graph = TensorflowGraph.createGraph();
37+
Object result = TensorflowGraph.runGraph(graph, 3.0, 6.0);
38+
System.out.println(result);
39+
graph.close();
40+
}
41+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package org.baeldung.tensorflow;
2+
3+
import org.tensorflow.SavedModelBundle;
4+
import org.tensorflow.Tensor;
5+
6+
public class TensorflowSavedModel {
7+
8+
public static void main(String[] args) {
9+
SavedModelBundle model = SavedModelBundle.load("./model", "serve");
10+
Tensor<Integer> tensor = model.session().runner().fetch("z").feed("x", Tensor.<Integer>create(3, Integer.class))
11+
.feed("y", Tensor.<Integer>create(3, Integer.class)).run().get(0).expect(Integer.class);
12+
System.out.println(tensor.intValue());
13+
}
14+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import tensorflow as tf
2+
graph = tf.Graph()
3+
builder = tf.saved_model.builder.SavedModelBuilder('./model')
4+
writer = tf.summary.FileWriter('.')
5+
with graph.as_default():
6+
a = tf.constant(2, name='a')
7+
b = tf.constant(3, name='b')
8+
x = tf.placeholder(tf.int32, name='x')
9+
y = tf.placeholder(tf.int32, name='y')
10+
z = tf.math.add(a*x, b*y, name='z')
11+
writer.add_graph(tf.get_default_graph())
12+
writer.flush()
13+
sess = tf.Session()
14+
sess.run(z, feed_dict = {x: 2, y: 3})
15+
builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
16+
builder.save()
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package org.baeldung.tensorflow;
2+
3+
import static org.junit.jupiter.api.Assertions.assertEquals;
4+
5+
import org.junit.Test;
6+
import org.tensorflow.Graph;
7+
8+
public class TensorflowGraphUnitTest {
9+
10+
@Test
11+
public void givenTensorflowGraphWhenRunInSessionReturnsExpectedResult() {
12+
13+
Graph graph = TensorflowGraph.createGraph();
14+
Object result = TensorflowGraph.runGraph(graph, 3.0, 6.0);
15+
assertEquals(21.0, result);
16+
System.out.println(result);
17+
graph.close();
18+
19+
}
20+
21+
}

0 commit comments

Comments
 (0)