diff --git a/examples/mnist_savedmodel.rs b/examples/mnist_savedmodel.rs index 2943583..f7be7cf 100644 --- a/examples/mnist_savedmodel.rs +++ b/examples/mnist_savedmodel.rs @@ -37,10 +37,15 @@ fn main() -> Result<(), Box> { // Load the saved model exported by regression_savedmodel.py. let mut graph = Graph::new(); - let session = - SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?.session; - let op_x = graph.operation_by_name_required("serving_default_sequential_input")?; - let op_predict = graph.operation_by_name_required("StatefulPartitionedCall")?; + let bundle = + SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?; + let session = &bundle.session; + + let signature = bundle.meta_graph_def().get_signature("serving_default")?; + let input_info = signature.get_input("input")?; + let op_x = graph.operation_by_name_required(&input_info.name().name)?; + let output_info = signature.get_output("output")?; + let op_predict = graph.operation_by_name_required(&output_info.name().name)?; // Train the model (e.g. for fine tuning). let mut args = SessionRunArgs::new(); diff --git a/examples/mnist_savedmodel/expected_values.txt b/examples/mnist_savedmodel/expected_values.txt index 14815df..c985c30 100644 --- a/examples/mnist_savedmodel/expected_values.txt +++ b/examples/mnist_savedmodel/expected_values.txt @@ -1 +1 @@ -4.0663176e-06, 1.4199884e-07, 9.556003e-05, 0.00065914105, 2.260991e-07, 4.076631e-06, 2.5459945e-09, 0.99904054, 1.5654963e-05, 0.00018059688 +3.112342e-05, 8.721303e-08, 0.0005018024, 0.0003709061, 1.6482764e-08, 1.8595395e-06, 1.3620006e-09, 0.999046, 5.4331244e-06, 4.2815118e-05 diff --git a/examples/mnist_savedmodel/mnist_savedmodel.py b/examples/mnist_savedmodel/mnist_savedmodel.py index 2fb107f..4225a3f 100644 --- a/examples/mnist_savedmodel/mnist_savedmodel.py +++ b/examples/mnist_savedmodel/mnist_savedmodel.py @@ -31,7 +31,11 @@ model.fit(x_train, y_train, epochs=1) # convert output type through softmax so that it can be interpreted as probability -probability_model = tf.keras.Sequential([model, tf.keras.layers.Softmax(name="output")]) +inputs = tf.keras.Input((28, 28), name="input", dtype=tf.float32) +x = model(inputs) +outputs = tf.keras.layers.Softmax(name="output")(x) + +probability_model = tf.keras.Model(inputs=inputs, outputs=outputs) # dump expected values to compare Rust's outputs with open("examples/mnist_savedmodel/expected_values.txt", "w") as f: @@ -45,6 +49,6 @@ logdir = "logs/mnist_savedmodel" writer = tf.summary.create_file_writer(logdir) tf.summary.trace_on() -values = probability_model(x_test[:1, :, :]) +values = probability_model.predict(x_test[:1, :, :]) with writer.as_default(): tf.summary.trace_export("Default", step=0) diff --git a/examples/mnist_savedmodel/saved_model.pb b/examples/mnist_savedmodel/saved_model.pb index 0a72d10..89b5b60 100644 Binary files a/examples/mnist_savedmodel/saved_model.pb and b/examples/mnist_savedmodel/saved_model.pb differ diff --git a/examples/mnist_savedmodel/variables/variables.data-00000-of-00001 b/examples/mnist_savedmodel/variables/variables.data-00000-of-00001 index 5db930e..a510262 100644 Binary files a/examples/mnist_savedmodel/variables/variables.data-00000-of-00001 and b/examples/mnist_savedmodel/variables/variables.data-00000-of-00001 differ diff --git a/examples/mnist_savedmodel/variables/variables.index b/examples/mnist_savedmodel/variables/variables.index index 18d93a2..6644524 100644 Binary files a/examples/mnist_savedmodel/variables/variables.index and b/examples/mnist_savedmodel/variables/variables.index differ diff --git a/examples/regression_savedmodel.rs b/examples/regression_savedmodel.rs index 39365a7..de787e2 100644 --- a/examples/regression_savedmodel.rs +++ b/examples/regression_savedmodel.rs @@ -41,19 +41,33 @@ fn main() -> Result<(), Box> { // Load the saved model exported by regression_savedmodel.py. let mut graph = Graph::new(); - let session = - SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?.session; - let op_x = graph.operation_by_name_required("train_x")?; - let op_y = graph.operation_by_name_required("train_y")?; - let op_train = graph.operation_by_name_required("StatefulPartitionedCall")?; - let op_w = graph.operation_by_name_required("StatefulPartitionedCall_1")?; - let op_b = graph.operation_by_name_required("StatefulPartitionedCall_1")?; + let bundle = + SavedModelBundle::load(&SessionOptions::new(), &["serve"], &mut graph, export_dir)?; + let session = &bundle.session; + + let train_signature = bundle.meta_graph_def().get_signature("train")?; + let x_info = train_signature.get_input("x")?; + let y_info = train_signature.get_input("y")?; + let train_info = train_signature.get_output("train")?; + let op_x = graph.operation_by_name_required(&x_info.name().name)?; + let op_y = graph.operation_by_name_required(&y_info.name().name)?; + let op_train = graph.operation_by_name_required(&train_info.name().name)?; + let w_info = bundle + .meta_graph_def() + .get_signature("w")? + .get_output("output")?; + let op_w = graph.operation_by_name_required(&w_info.name().name)?; + let b_info = bundle + .meta_graph_def() + .get_signature("b")? + .get_output("output")?; + let op_b = graph.operation_by_name_required(&b_info.name().name)?; // Train the model (e.g. for fine tuning). let mut train_step = SessionRunArgs::new(); train_step.add_feed(&op_x, 0, &x); train_step.add_feed(&op_y, 0, &y); - train_step.request_fetch(&op_train, 0); + train_step.add_target(&op_train); for _ in 0..steps { session.run(&mut train_step)?; } @@ -61,7 +75,7 @@ fn main() -> Result<(), Box> { // Grab the data out of the session. let mut output_step = SessionRunArgs::new(); let w_ix = output_step.request_fetch(&op_w, 0); - let b_ix = output_step.request_fetch(&op_b, 1); + let b_ix = output_step.request_fetch(&op_b, 0); session.run(&mut output_step)?; // Check our results. diff --git a/examples/regression_savedmodel/regression_savedmodel.py b/examples/regression_savedmodel/regression_savedmodel.py index 8f9d169..f61d52d 100644 --- a/examples/regression_savedmodel/regression_savedmodel.py +++ b/examples/regression_savedmodel/regression_savedmodel.py @@ -14,8 +14,12 @@ def __call__(self, x): return y_hat @tf.function - def get_weights(self): - return self.w, self.b + def get_w(self): + return {"output": self.w} + + @tf.function + def get_b(self): + return {"output": self.b} @tf.function def train(self, x, y): @@ -23,10 +27,8 @@ def train(self, x, y): y_hat = self(x) loss = tf.reduce_mean(tf.square(y_hat - y)) grads = tape.gradient(loss, self.trainable_variables) - _ = self.optimizer.apply_gradients( - zip(grads, self.trainable_variables), name="train" - ) - return loss + _ = self.optimizer.apply_gradients(zip(grads, self.trainable_variables)) + return {"train": loss} model = LinearRegresstion() @@ -34,10 +36,11 @@ def train(self, x, y): x = tf.TensorSpec([None], tf.float32, name="x") y = tf.TensorSpec([None], tf.float32, name="y") train = model.train.get_concrete_function(x, y) -weights = model.get_weights.get_concrete_function() +w = model.get_w.get_concrete_function() +b = model.get_b.get_concrete_function() directory = "examples/regression_savedmodel" -signatures = {"train": train, "weights": weights} +signatures = {"train": train, "w": w, "b": b} tf.saved_model.save(model, directory, signatures=signatures) # export graph info to TensorBoard @@ -45,4 +48,5 @@ def train(self, x, y): writer = tf.summary.create_file_writer(logdir) with writer.as_default(): tf.summary.graph(train.graph) - tf.summary.graph(weights.graph) + tf.summary.graph(w.graph) + tf.summary.graph(b.graph) diff --git a/examples/regression_savedmodel/saved_model.pb b/examples/regression_savedmodel/saved_model.pb index 6cf4816..f2db6a6 100644 Binary files a/examples/regression_savedmodel/saved_model.pb and b/examples/regression_savedmodel/saved_model.pb differ diff --git a/examples/regression_savedmodel/variables/variables.data-00000-of-00001 b/examples/regression_savedmodel/variables/variables.data-00000-of-00001 index 6802057..7b50dfb 100644 Binary files a/examples/regression_savedmodel/variables/variables.data-00000-of-00001 and b/examples/regression_savedmodel/variables/variables.data-00000-of-00001 differ diff --git a/examples/regression_savedmodel/variables/variables.index b/examples/regression_savedmodel/variables/variables.index index 5c6e67f..03a65d8 100644 Binary files a/examples/regression_savedmodel/variables/variables.index and b/examples/regression_savedmodel/variables/variables.index differ