Make _USE_C_API = True and _USE_C_SHAPES = False work with handle data, take 2.
authorSkye Wanderman-Milne <skyewm@google.com>
Tue, 27 Mar 2018 22:07:05 +0000 (15:07 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 27 Mar 2018 22:09:19 +0000 (15:09 -0700)
This change makes _set_shapes_for_outputs_c_api fetch and set
Tensor._handle_data. This is necessary for running the
Python shape inference code on resource tensors.

PiperOrigin-RevId: 190681459

tensorflow/c/BUILD
tensorflow/c/python_api.cc
tensorflow/c/python_api.h
tensorflow/python/BUILD
tensorflow/python/client/tf_session.i
tensorflow/python/framework/importer_test.py
tensorflow/python/framework/ops.py

index 7f03e40..249135f 100644 (file)
@@ -283,6 +283,8 @@ tf_cuda_library(
     deps = [
         ":c_api",
         ":c_api_internal",
+        # TODO(b/74620627): remove when _USE_C_SHAPES is removed
+        "//tensorflow/python:cpp_shape_inference_proto_cc",
     ],
 )
 
index cd60453..9315599 100644 (file)
@@ -16,6 +16,7 @@ limitations under the License.
 #include "tensorflow/c/python_api.h"
 
 #include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
 
 namespace tensorflow {
 
@@ -109,4 +110,29 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
   session->extend_before_run = false;
 }
 
+std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+  Node* node = &output.oper->node;
+  CppShapeInferenceResult::HandleData handle_data;
+  handle_data.set_is_set(true);
+  {
+    mutex_lock l(graph->mu);
+    tensorflow::shape_inference::InferenceContext* ic =
+        graph->refiner.GetContext(node);
+    CHECK(ic != nullptr);
+    CHECK_LT(output.index, ic->num_outputs());
+    const auto* shapes_and_types =
+        ic->output_handle_shapes_and_types(output.index);
+    if (shapes_and_types == nullptr) return "";
+
+    for (const auto& p : *shapes_and_types) {
+      auto* out_shape_and_type = handle_data.add_shape_and_type();
+      ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
+      out_shape_and_type->set_dtype(p.dtype);
+    }
+  }
+  string result;
+  handle_data.SerializeToString(&result);
+  return result;
+}
+
 }  // namespace tensorflow
index 13b680b..2d4c8cd 100644 (file)
@@ -16,6 +16,8 @@ limitations under the License.
 #ifndef TENSORFLOW_C_PYTHON_API_H_
 #define TENSORFLOW_C_PYTHON_API_H_
 
+#include <string>
+
 #include "tensorflow/c/c_api.h"
 
 // These functions can be removed without notice. They exist to facilitate some
@@ -51,6 +53,11 @@ void SetRequireShapeInferenceFns(TF_Graph* graph, bool require);
 // the graph after the session has been made aware of them.
 void ExtendSession(TF_Session* session, TF_Status* status);
 
+// Returns the serialized CppShapeInferenceResult::HandleData proto for
+// `output` if its a resource tensor, or otherwise returns the empty string.
+// TODO(b/74620627): remove when _USE_C_SHAPES is removed
+std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_C_PYTHON_API_H_
index 20d7e81..4f61c01 100644 (file)
@@ -3144,6 +3144,8 @@ tf_proto_library(
     srcs = ["framework/cpp_shape_inference.proto"],
     cc_api_version = 2,
     protodeps = tf_additional_all_protos(),
+    # TODO(b/74620627): remove when _USE_C_SHAPES is removed
+    visibility = ["//tensorflow:internal"],
 )
 
 py_test(
index e88fc0c..70a3d03 100644 (file)
@@ -723,6 +723,7 @@ def TF_Reset(target, containers=None, config=None):
 %unignore TF_TryEvaluateConstant_wrapper;
 %noexception TF_TryEvaluateConstant_wrapper;
 %unignore ExtendSession;
+%unignore ResourceHandleShapeAndType;
 
 %include "tensorflow/python/client/tf_session_helper.h"
 
index 6593b17..369669c 100644 (file)
@@ -39,6 +39,7 @@ from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
 from tensorflow.python.platform import test
@@ -356,6 +357,39 @@ class ImportGraphDefTest(test.TestCase):
       self.assertEqual(d._input_types, [dtypes.int32_ref, dtypes.int32])
       self.assertEqual(d.outputs, [])
 
+  def testResources(self):
+    # Produce GraphDef containing a ops producing and consuming resources.
+    graph = ops.Graph()
+    with graph.as_default():
+      var = resource_variable_ops.ResourceVariable(1.0)
+      var_assign = var.assign(2.0)
+      # Use an op that requires handle shape to be set.
+      var_shape = resource_variable_ops.variable_shape(var.handle)
+      init = variables.global_variables_initializer()
+    graph_def = graph.as_graph_def()
+
+    # Import the GraphDef.
+    with ops.Graph().as_default():
+      # pylint: disable=unused-variable
+      imported_var, imported_assign, imported_shape, imported_init = (
+          importer.import_graph_def(
+              graph_def,
+              return_elements=[var.name, var_assign.name, var_shape.name,
+                               init.name]))
+
+      # Make sure the handle shape is set on the imported variable.
+      new_var_shape = resource_variable_ops.variable_shape(imported_var)
+      # pylint: enable=unused-variable
+
+      # Run the imported graph.
+      # TODO(b/76173421): make this work (currently DCHECKS)
+      # with self.test_session() as sess:
+      #   sess.run(imported_init)
+      #   self.assertEqual(sess.run(imported_var), 1.0)
+      #   self.assertEqual(sess.run(imported_assign), 2.0)
+      #   self.assertEqual(list(sess.run(imported_shape)), [])
+      #   self.assertEqual(list(sess.run(new_var_shape)), [])
+
   def testWhileLoop(self):
     # Produce GraphDef containing while loop.
     graph = ops.Graph()
index 25a951a..4b0f3f3 100644 (file)
@@ -42,6 +42,7 @@ from tensorflow.python.eager import context
 from tensorflow.python.eager import core
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -295,6 +296,7 @@ class Tensor(_TensorLike):
 
     # Attributes used for C++ shape inference. Not inspected, only forwarded.
     # If set, will be a HandleData object from cpp_shape_inference.proto.
+    # TODO(b/74620627): remove when _USE_C_SHAPES is removed
     self._handle_data = None
     self._id = uid()
 
@@ -2472,6 +2474,14 @@ def _set_shapes_for_outputs_c_api(op):
       shape_vector = [None if d == -1 else d for d in shape_vector]
       output.set_shape(tensor_shape.TensorShape(shape_vector))
 
+    serialized = c_api.ResourceHandleShapeAndType(op._graph._c_graph,
+                                                  output._as_tf_output())
+    if serialized:
+      output._handle_data = (
+          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
+              compat.as_bytes(serialized)))
+    else:
+      output._handle_data = None
 
 # TODO(skyewm): remove this when _USE_C_API flag is removed.
 def _set_shapes_for_outputs(op):