1
+ package org .baeldung .tensorflow ;
2
+
3
+ import org .tensorflow .DataType ;
4
+ import org .tensorflow .Graph ;
5
+ import org .tensorflow .Operation ;
6
+ import org .tensorflow .Session ;
7
+ import org .tensorflow .Tensor ;
8
+
9
+ public class TensorflowGraph {
10
+
11
+ public static Graph createGraph () {
12
+ Graph graph = new Graph ();
13
+ Operation a = graph .opBuilder ("Const" , "a" ).setAttr ("dtype" , DataType .fromClass (Double .class ))
14
+ .setAttr ("value" , Tensor .<Double >create (3.0 , Double .class )).build ();
15
+ Operation b = graph .opBuilder ("Const" , "b" ).setAttr ("dtype" , DataType .fromClass (Double .class ))
16
+ .setAttr ("value" , Tensor .<Double >create (2.0 , Double .class )).build ();
17
+ Operation x = graph .opBuilder ("Placeholder" , "x" ).setAttr ("dtype" , DataType .fromClass (Double .class )).build ();
18
+ Operation y = graph .opBuilder ("Placeholder" , "y" ).setAttr ("dtype" , DataType .fromClass (Double .class )).build ();
19
+ Operation ax = graph .opBuilder ("Mul" , "ax" ).addInput (a .output (0 )).addInput (x .output (0 )).build ();
20
+ Operation by = graph .opBuilder ("Mul" , "by" ).addInput (b .output (0 )).addInput (y .output (0 )).build ();
21
+ graph .opBuilder ("Add" , "z" ).addInput (ax .output (0 )).addInput (by .output (0 )).build ();
22
+ return graph ;
23
+ }
24
+
25
+ public static Object runGraph (Graph graph , Double x , Double y ) {
26
+ Object result ;
27
+ try (Session sess = new Session (graph )) {
28
+ result = sess .runner ().fetch ("z" ).feed ("x" , Tensor .<Double >create (x , Double .class ))
29
+ .feed ("y" , Tensor .<Double >create (y , Double .class )).run ().get (0 ).expect (Double .class )
30
+ .doubleValue ();
31
+ }
32
+ return result ;
33
+ }
34
+
35
+ public static void main (String [] args ) {
36
+ Graph graph = TensorflowGraph .createGraph ();
37
+ Object result = TensorflowGraph .runGraph (graph , 3.0 , 6.0 );
38
+ System .out .println (result );
39
+ graph .close ();
40
+ }
41
+ }
0 commit comments