FreezeSavedModel function: Get a frozen GraphDef, inputs, and outputs from a loaded...
authorSuharsh Sivakumar <suharshs@google.com>
Sun, 14 Jan 2018 11:38:40 +0000 (03:38 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Sun, 14 Jan 2018 11:42:18 +0000 (03:42 -0800)
#14567

PiperOrigin-RevId: 181887870

tensorflow/BUILD
tensorflow/cc/tools/BUILD [new file with mode: 0644]
tensorflow/cc/tools/freeze_saved_model.cc [new file with mode: 0644]
tensorflow/cc/tools/freeze_saved_model.h [new file with mode: 0644]
tensorflow/cc/tools/freeze_saved_model_test.cc [new file with mode: 0644]

index ca2d2397bfe99285aa551fe8aa69418d7657b13e..53c632a0a09c9efec6d25a56399e2369c5ca6f0f 100644 (file)
@@ -394,6 +394,7 @@ filegroup(
         "//tensorflow/cc:all_files",
         "//tensorflow/cc/saved_model:all_files",
         "//tensorflow/cc/saved_model/python:all_files",
+        "//tensorflow/cc/tools:all_files",
         "//tensorflow/compiler/aot:all_files",
         "//tensorflow/compiler/aot/tests:all_files",
         "//tensorflow/compiler/jit:all_files",
diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD
new file mode 100644 (file)
index 0000000..0a7c373
--- /dev/null
@@ -0,0 +1,58 @@
+# Description:
+# TensorFlow cc tools.
+
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+    "//tensorflow:tensorflow.bzl",
+    "tf_cc_test",
+)
+
+cc_library(
+    name = "freeze_saved_model",
+    srcs = ["freeze_saved_model.cc"],
+    hdrs = ["freeze_saved_model.h"],
+    deps = [
+        "//tensorflow/cc/saved_model:loader",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+    ],
+)
+
+tf_cc_test(
+    name = "freeze_saved_model_test",
+    srcs = ["freeze_saved_model_test.cc"],
+    deps = [
+        ":freeze_saved_model",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
+# -----------------------------------------------------------------------------
+# Google-internal targets.
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc
new file mode 100644 (file)
index 0000000..ddf372c
--- /dev/null
@@ -0,0 +1,194 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/tools/freeze_saved_model.h"
+
+#include <queue>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Gets tensor names from tensor_info and inserts them into the set of tensor
+// names.
+void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info,
+                                  std::unordered_set<string>* tensor_names) {
+  if (tensor_info.has_coo_sparse()) {
+    // If the tensor is sparse we have to add all three tensors of the sparse
+    // representations.
+    const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse();
+    tensor_names->insert(coo_sparse.values_tensor_name());
+    tensor_names->insert(coo_sparse.indices_tensor_name());
+    tensor_names->insert(coo_sparse.dense_shape_tensor_name());
+  } else {
+    tensor_names->insert(tensor_info.name());
+  }
+}
+
+// Gets the union of all inputs and outputs of all SignatureDefs in the bundle
+void GetSignatureDefsInputsAndOutputs(
+    const SavedModelBundle& saved_model_bundle,
+    std::unordered_set<string>* inputs, std::unordered_set<string>* outputs) {
+  for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) {
+    const SignatureDef& signature_def = sigdef_elem.second;
+    for (auto& input_elem : signature_def.inputs()) {
+      GetTensorNamesFromTensorInfo(input_elem.second, inputs);
+    }
+    for (auto& output_elem : signature_def.outputs()) {
+      GetTensorNamesFromTensorInfo(output_elem.second, outputs);
+    }
+  }
+}
+
+// Gets a map from string node name to NodeDef.
+void GetNodeNameToNodeDefMap(
+    GraphDef* graph_def,
+    std::unordered_map<string, NodeDef*>* name_to_node_map) {
+  for (size_t i = 0; i < graph_def->node_size(); i++) {
+    NodeDef* node = graph_def->mutable_node(i);
+    (*name_to_node_map)[node->name()] = node;
+  }
+}
+
+// Gets the set of node names needed by `outputs` and the corresponding set of
+// variable nodes to convert.
+void GetReachableNodesAndVariables(
+    GraphDef* graph_def, const std::unordered_set<string>& outputs,
+    std::unordered_set<string>* reachable_node_names,
+    std::unordered_set<string>* variable_node_names) {
+  // TODO(suharshs): Add support for ResourceVariables.
+  static const std::unordered_set<string>* kVariableTypes =
+      new std::unordered_set<string>({"Variable", "VariableV2"});
+  // name_to_node_map is needed to get the inputs from the NodeDef corresponding
+  // the a string node name. These inputs are used when doing our backwards
+  // traversal.
+  std::unordered_map<string, NodeDef*> name_to_node_map;
+  GetNodeNameToNodeDefMap(graph_def, &name_to_node_map);
+  std::queue<string> nodes_to_visit;
+  for (const string& tensor_name : outputs) {
+    // We need to strip off the tensor part to get the node name.
+    std::vector<string> tensor_name_parts = str_util::Split(tensor_name, ':');
+    nodes_to_visit.push(tensor_name_parts[0]);
+  }
+  // We do a traversal backwards from the outputs specified in the MetaGraphDef.
+  while (!nodes_to_visit.empty()) {
+    const string node_name = nodes_to_visit.front();
+    nodes_to_visit.pop();
+    if (reachable_node_names->find(node_name) != reachable_node_names->end()) {
+      continue;
+    }
+    reachable_node_names->insert(node_name);
+    NodeDef* node = name_to_node_map[node_name];
+    if (kVariableTypes->find(node->op()) != kVariableTypes->end()) {
+      variable_node_names->insert(node->name());
+    }
+    for (const string& input : node->input()) {
+      nodes_to_visit.push(input);
+    }
+  }
+}
+
+// Gets a map from variable name to variable value.
+Status GetVariableNameToTensorMap(
+    Session* session, std::unordered_set<string> variable_names_set,
+    std::unordered_map<string, Tensor>* variable_name_to_value_map) {
+  if (variable_names_set.empty()) {
+    return Status::OK();
+  }
+  std::vector<string> variable_names;
+  std::vector<string> tensor_names;
+  for (const string& node_name : variable_names_set) {
+    variable_names.push_back(node_name);
+    // We need to run tensors, so append ":0".
+    tensor_names.push_back(node_name + ":0");
+  }
+  std::vector<Tensor> outputs;
+  TF_RETURN_IF_ERROR(
+      session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs));
+  for (size_t i = 0; i < variable_names.size(); i++) {
+    (*variable_name_to_value_map)[variable_names[i]] = outputs[i];
+  }
+  return Status::OK();
+}
+
+// Converts a Variable NodeDef into a Constant NodeDef.
+void ConvertVariableToConstant(const NodeDef& variable_node,
+                               const Tensor& variable_value,
+                               NodeDef* const_node) {
+  const_node->set_name(variable_node.name());
+  const_node->set_op("Const");
+  (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype");
+  variable_value.AsProtoTensorContent(
+      (*const_node->mutable_attr())["value"].mutable_tensor());
+}
+
+// Freezes the subgraph of all nodes needed by `outputs`.
+Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle,
+                      const std::unordered_set<string>& outputs,
+                      GraphDef* frozen_graph_def) {
+  GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def();
+  // Copy versions and library as-is from original graph.
+  *frozen_graph_def->mutable_versions() = graph_def.versions();
+  *frozen_graph_def->mutable_library() = graph_def.library();
+  // If the graph is empty there is nothing left to do.
+  if (graph_def.node_size() == 0) {
+    return Status::OK();
+  }
+  std::unordered_set<string> reachable_node_names;
+  std::unordered_set<string> variable_node_names;
+  GetReachableNodesAndVariables(&graph_def, outputs, &reachable_node_names,
+                                &variable_node_names);
+  std::unordered_map<string, Tensor> variable_to_value_map;
+  TF_RETURN_IF_ERROR(
+      GetVariableNameToTensorMap(saved_model_bundle.session.get(),
+                                 variable_node_names, &variable_to_value_map));
+  // We copy the nodes in the same order they were in the original graph_def.
+  for (const NodeDef& node : graph_def.node()) {
+    if (reachable_node_names.find(node.name()) == reachable_node_names.end()) {
+      continue;
+    }
+    if (variable_node_names.find(node.name()) != variable_node_names.end()) {
+      ConvertVariableToConstant(node, variable_to_value_map[node.name()],
+                                frozen_graph_def->add_node());
+    } else {
+      // If the node isn't a variable, just copy the node as-is.
+      *frozen_graph_def->add_node() = node;
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
+                        GraphDef* frozen_graph_def,
+                        std::unordered_set<string>* inputs,
+                        std::unordered_set<string>* outputs) {
+  GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs);
+  TF_RETURN_IF_ERROR(
+      FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def));
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/cc/tools/freeze_saved_model.h b/tensorflow/cc/tools/freeze_saved_model.h
new file mode 100644 (file)
index 0000000..bd5e051
--- /dev/null
@@ -0,0 +1,43 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
+#define THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
+
+#include <unordered_set>
+
+#include "tensorflow/cc/saved_model/loader.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Returns a frozen GraphDef, input tensors, and output tensors from the loaded
+// SavedModelBundle.
+// `inputs` and `outputs` consist of the union of all inputs and outputs in the
+// SignatureDefs in the SavedModelBundle.
+// FreezeSavedModel sets `frozen_graph_def` to a GraphDef of all nodes needed by
+// `outputs`. All variables in the supplied SavedModelBundle are converted to
+// constants, set to the value of the variables, by running the restored Session
+// in the SavedModelBundle.
+// WARNING: Only the variable checkpoints will be reflected in the frozen
+// graph_def. All saved_model assets will be ignored.
+Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle,
+                        GraphDef* frozen_graph_def,
+                        std::unordered_set<string>* inputs,
+                        std::unordered_set<string>* outputs);
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_
diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc
new file mode 100644 (file)
index 0000000..57244a4
--- /dev/null
@@ -0,0 +1,307 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/tools/freeze_saved_model.h"
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace {
+
+class FreezeTest : public ::testing::Test {
+ protected:
+  void GraphDefEqual(const GraphDef& actual, const GraphDef& expected) {
+    EXPECT_EQ(actual.ShortDebugString(), expected.ShortDebugString());
+  }
+
+  // Builds a SignatureDef with the provided `inputs` and `outputs`.
+  SignatureDef BuildSignatureDef(const std::unordered_set<string>& inputs,
+                                 const std::unordered_set<string>& outputs) {
+    SignatureDef signature_def;
+    for (const string& input : inputs) {
+      (*signature_def.mutable_inputs())[input].set_name(input);
+    }
+    for (const string& output : outputs) {
+      (*signature_def.mutable_outputs())[output].set_name(output);
+    }
+    return signature_def;
+  }
+
+  // Adds `signature_def` to `saved_model_bundle` under `key`.
+  void AddSignatureDefToSavedModelBundle(const SignatureDef& signature_def,
+                                         const string& key,
+                                         SavedModelBundle* saved_model_bundle) {
+    MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def;
+    (*meta_graph_def->mutable_signature_def())[key] = signature_def;
+  }
+
+  // Adds an initialized session to `saved_model_bundle` using `graph_def` and
+  // initializing with `init_node`.
+  Status InitializeSavedModelBundleSession(
+      const GraphDef& graph_def, const string& init_node,
+      SavedModelBundle* saved_model_bundle) {
+    SessionOptions session_options;
+    saved_model_bundle->session.reset(NewSession(session_options));
+    TF_RETURN_IF_ERROR(saved_model_bundle->session->Create(graph_def));
+    if (!init_node.empty()) {
+      std::vector<Tensor> outputs;
+      return saved_model_bundle->session->Run(
+          /* inputs */ {}, /* output_tensors */ {}, {init_node}, &outputs);
+    }
+    return Status::OK();
+  }
+
+  // Adds `graph_def` to `saved_model_bundle` and intializes a session with
+  // `init_node`.
+  Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def,
+                                       const string& init_node,
+                                       SavedModelBundle* saved_model_bundle) {
+    MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def;
+    *meta_graph_def->mutable_graph_def() = graph_def;
+    return InitializeSavedModelBundleSession(graph_def, init_node,
+                                             saved_model_bundle);
+  }
+
+  // Adds `graph_def` and `outputs` as the GraphDef and SignatureDef in
+  // `saved_model_bundle` and initializes a session with `init_node`.
+  Status AddGraphDefWithOutputsToSavedModelBundle(
+      const GraphDef& graph_def, const std::unordered_set<string>& outputs,
+      const string& init_node, SavedModelBundle* saved_model_bundle) {
+    SignatureDef signature_def =
+        BuildSignatureDef(std::unordered_set<string>(), outputs);
+    AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
+                                      saved_model_bundle);
+    return AddGraphDefToSavedModelBundle(graph_def, init_node,
+                                         saved_model_bundle);
+  }
+
+  // Runs and compares the outputs of `tensor_name` on both the
+  // `unfrozen_session` and the `frozen_graph_def.
+  void RunAndCompareFrozenAndUnfrozenGraphs(Session* unfrozen_session,
+                                            const GraphDef& frozen_graph_def,
+                                            const string& tensor_name) {
+    std::vector<Tensor> unfrozen_outputs;
+    TF_ASSERT_OK(unfrozen_session->Run(/* inputs */ {}, {tensor_name},
+                                       /* targets */ {}, &unfrozen_outputs));
+
+    SessionOptions session_options;
+    std::unique_ptr<Session> frozen_session(NewSession(session_options));
+    TF_ASSERT_OK(frozen_session->Create(frozen_graph_def));
+    std::vector<Tensor> frozen_outputs;
+    TF_ASSERT_OK(frozen_session->Run(/* inputs */ {}, {tensor_name},
+                                     /* targets */ {}, &frozen_outputs));
+
+    test::ExpectTensorEqual<float>(unfrozen_outputs[0], frozen_outputs[0]);
+  }
+};
+
+TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) {
+  // Test that inputs and outputs get correctly populated for a single
+  // SignatureDef.
+  SavedModelBundle saved_model_bundle;
+  std::unordered_set<string> expected_inputs = {"input0:0", "input1:0"};
+  std::unordered_set<string> expected_outputs = {"output0:0", "output1:0"};
+  SignatureDef signature_def =
+      BuildSignatureDef(expected_inputs, expected_outputs);
+  AddSignatureDefToSavedModelBundle(signature_def, "signature_def",
+                                    &saved_model_bundle);
+  GraphDef frozen_graph_def;
+  std::unordered_set<string> inputs;
+  std::unordered_set<string> outputs;
+  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+                                &outputs));
+  EXPECT_EQ(expected_inputs, inputs);
+  EXPECT_EQ(expected_outputs, outputs);
+}
+
+TEST_F(FreezeTest, InputsAndOutputsMultipleSignatureDefs) {
+  // Test that inputs and outputs get correctly merged and populated when
+  // multiple SignatureDefs are provided.
+  SavedModelBundle saved_model_bundle;
+  SignatureDef signature_def_0 = BuildSignatureDef({"input0:0"}, {"output0:0"});
+  SignatureDef signature_def_1 = BuildSignatureDef({"input1:0"}, {"output1:0"});
+  AddSignatureDefToSavedModelBundle(signature_def_0, "signature_def_0",
+                                    &saved_model_bundle);
+  AddSignatureDefToSavedModelBundle(signature_def_1, "signature_def_1",
+                                    &saved_model_bundle);
+  GraphDef frozen_graph_def;
+  std::unordered_set<string> inputs;
+  std::unordered_set<string> outputs;
+  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+                                &outputs));
+  std::unordered_set<string> expected_inputs = {"input0:0", "input1:0"};
+  std::unordered_set<string> expected_outputs = {"output0:0", "output1:0"};
+  EXPECT_EQ(expected_inputs, inputs);
+  EXPECT_EQ(expected_outputs, outputs);
+}
+
+TEST_F(FreezeTest, GraphDefVersionsAndLibrary) {
+  // Test that GraphDef versions and library are copied correctly into the
+  // frozen graph.
+  SavedModelBundle saved_model_bundle;
+  GraphDef graph_def;
+  graph_def.mutable_versions()->set_producer(1234);
+  graph_def.mutable_versions()->set_min_consumer(1234);
+  *graph_def.mutable_library()->add_function() = test::function::NonZero();
+  TF_ASSERT_OK(
+      AddGraphDefToSavedModelBundle(graph_def, "", &saved_model_bundle));
+
+  GraphDef frozen_graph_def;
+  std::unordered_set<string> inputs;
+  std::unordered_set<string> outputs;
+  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+                                &outputs));
+
+  GraphDefEqual(frozen_graph_def, graph_def);
+}
+
+TEST_F(FreezeTest, GraphDefWithNoVariables) {
+  // Test freezing a graph with no variables.
+  SavedModelBundle saved_model_bundle;
+  GraphDef graph_def;
+  Scope scope = Scope::NewRootScope();
+  Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+  Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+  Output c = ops::Mul(scope.WithOpName("c"), a, b);
+  TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "",
+                                                        &saved_model_bundle));
+
+  GraphDef frozen_graph_def;
+  std::unordered_set<string> inputs;
+  std::unordered_set<string> outputs;
+  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+                                &outputs));
+
+  GraphDefEqual(frozen_graph_def, graph_def);
+}
+
+TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) {
+  // Test freezing a graph with variables that are not needed by the outputs in
+  // the SignatureDef. The resulting graph shouldn't be frozen, but
+  // non-dependent nodes should be pruned.
+  SavedModelBundle saved_model_bundle;
+  GraphDef graph_def;
+  Scope scope = Scope::NewRootScope();
+  Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+  Output b = ops::Const(scope.WithOpName("b"), 10.0f, {});
+  Output c = ops::Mul(scope.WithOpName("c"), a, b);
+  Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+  Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
+  TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+  // "c" isnt dependent on the variable, so nothing should be frozen.
+  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+      graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
+
+  GraphDef frozen_graph_def;
+  std::unordered_set<string> inputs;
+  std::unordered_set<string> outputs;
+  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+                                &outputs));
+
+  GraphDef expected_graph_def;
+  Scope expected_scope = Scope::NewRootScope();
+  Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {});
+  Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {});
+  Output expected_c =
+      ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b);
+  TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def));
+
+  GraphDefEqual(frozen_graph_def, expected_graph_def);
+
+  RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+                                       frozen_graph_def, "c:0");
+}
+
+TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) {
+  // Test freezing a graph with variables that are needed by outputs in the
+  // SignatureDef. The variables should be frozen.
+  SavedModelBundle saved_model_bundle;
+  GraphDef graph_def;
+  Scope scope = Scope::NewRootScope();
+  Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+  Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+  Output c = ops::Mul(scope.WithOpName("c"), a, var);
+  Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
+  TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+  // "c" isnt dependent on the variable, so nothing should be frozen.
+  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+      graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
+
+  GraphDef frozen_graph_def;
+  std::unordered_set<string> inputs;
+  std::unordered_set<string> outputs;
+  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+                                &outputs));
+
+  // There should be 3 nodes in the resulting graph_def, and none should be
+  // variables.
+  EXPECT_EQ(frozen_graph_def.node_size(), 3);
+  for (const NodeDef& node : frozen_graph_def.node()) {
+    EXPECT_NE(node.op(), "Variable") << node.name();
+    EXPECT_NE(node.op(), "VariableV2") << node.name();
+  }
+
+  RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+                                       frozen_graph_def, "c:0");
+}
+
+TEST_F(FreezeTest, GraphDefWithVariablesNeededAndNotNeededByOutputs) {
+  // Test freezing a graph with some variables that are needed and not needed by
+  // the outputs in the SignatureDef. The resulting graph should only freeze
+  // dependent variables.
+  SavedModelBundle saved_model_bundle;
+  GraphDef graph_def;
+  Scope scope = Scope::NewRootScope();
+  Output a = ops::Const(scope.WithOpName("a"), 10.0f, {});
+  Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT);
+  Output c = ops::Mul(scope.WithOpName("c"), a, var);
+  Output assign = ops::Assign(scope.WithOpName("assign"), var, a);
+  Output var_1 =
+      ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT);
+  Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var, a);
+  TF_ASSERT_OK(scope.ToGraphDef(&graph_def));
+  // "c" isnt dependent on the variable, so nothing should be frozen.
+  TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(
+      graph_def, {"c:0"}, assign.name(), &saved_model_bundle));
+
+  GraphDef frozen_graph_def;
+  std::unordered_set<string> inputs;
+  std::unordered_set<string> outputs;
+  TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs,
+                                &outputs));
+
+  // There should be 3 nodes in the resulting graph_def, and none should be
+  // variables.
+  EXPECT_EQ(frozen_graph_def.node_size(), 3);
+  for (const NodeDef& node : frozen_graph_def.node()) {
+    EXPECT_NE(node.op(), "Variable") << node.name();
+    EXPECT_NE(node.op(), "VariableV2") << node.name();
+  }
+
+  RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(),
+                                       frozen_graph_def, "c:0");
+}
+
+}  // namespace
+}  // namespace tensorflow