deps = [
":c_api",
":c_api_internal",
+ # TODO(b/74620627): remove when _USE_C_SHAPES is removed
+ "//tensorflow/python:cpp_shape_inference_proto_cc",
],
)
#include "tensorflow/c/python_api.h"
#include "tensorflow/c/c_api_internal.h"
+#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
namespace tensorflow {
session->extend_before_run = false;
}
+std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+ Node* node = &output.oper->node;
+ CppShapeInferenceResult::HandleData handle_data;
+ handle_data.set_is_set(true);
+ {
+ mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(node);
+ CHECK(ic != nullptr);
+ CHECK_LT(output.index, ic->num_outputs());
+ const auto* shapes_and_types =
+ ic->output_handle_shapes_and_types(output.index);
+ if (shapes_and_types == nullptr) return "";
+
+ for (const auto& p : *shapes_and_types) {
+ auto* out_shape_and_type = handle_data.add_shape_and_type();
+ ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
+ out_shape_and_type->set_dtype(p.dtype);
+ }
+ }
+ string result;
+ handle_data.SerializeToString(&result);
+ return result;
+}
+
} // namespace tensorflow
#ifndef TENSORFLOW_C_PYTHON_API_H_
#define TENSORFLOW_C_PYTHON_API_H_
+#include <string>
+
#include "tensorflow/c/c_api.h"
// These functions can be removed without notice. They exist to facilitate some
// the graph after the session has been made aware of them.
void ExtendSession(TF_Session* session, TF_Status* status);
+// Returns the serialized CppShapeInferenceResult::HandleData proto for
+// `output` if its a resource tensor, or otherwise returns the empty string.
+// TODO(b/74620627): remove when _USE_C_SHAPES is removed
+std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_
srcs = ["framework/cpp_shape_inference.proto"],
cc_api_version = 2,
protodeps = tf_additional_all_protos(),
+ # TODO(b/74620627): remove when _USE_C_SHAPES is removed
+ visibility = ["//tensorflow:internal"],
)
py_test(
%unignore TF_TryEvaluateConstant_wrapper;
%noexception TF_TryEvaluateConstant_wrapper;
%unignore ExtendSession;
+%unignore ResourceHandleShapeAndType;
%include "tensorflow/python/client/tf_session_helper.h"
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
import tensorflow.python.ops.nn_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
self.assertEqual(d._input_types, [dtypes.int32_ref, dtypes.int32])
self.assertEqual(d.outputs, [])
+ def testResources(self):
+ # Produce GraphDef containing a ops producing and consuming resources.
+ graph = ops.Graph()
+ with graph.as_default():
+ var = resource_variable_ops.ResourceVariable(1.0)
+ var_assign = var.assign(2.0)
+ # Use an op that requires handle shape to be set.
+ var_shape = resource_variable_ops.variable_shape(var.handle)
+ init = variables.global_variables_initializer()
+ graph_def = graph.as_graph_def()
+
+ # Import the GraphDef.
+ with ops.Graph().as_default():
+ # pylint: disable=unused-variable
+ imported_var, imported_assign, imported_shape, imported_init = (
+ importer.import_graph_def(
+ graph_def,
+ return_elements=[var.name, var_assign.name, var_shape.name,
+ init.name]))
+
+ # Make sure the handle shape is set on the imported variable.
+ new_var_shape = resource_variable_ops.variable_shape(imported_var)
+ # pylint: enable=unused-variable
+
+ # Run the imported graph.
+ # TODO(b/76173421): make this work (currently DCHECKS)
+ # with self.test_session() as sess:
+ # sess.run(imported_init)
+ # self.assertEqual(sess.run(imported_var), 1.0)
+ # self.assertEqual(sess.run(imported_assign), 2.0)
+ # self.assertEqual(list(sess.run(imported_shape)), [])
+ # self.assertEqual(list(sess.run(new_var_shape)), [])
+
def testWhileLoop(self):
# Produce GraphDef containing while loop.
graph = ops.Graph()
from tensorflow.python.eager import core
from tensorflow.python.eager import tape
from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
# Attributes used for C++ shape inference. Not inspected, only forwarded.
# If set, will be a HandleData object from cpp_shape_inference.proto.
+ # TODO(b/74620627): remove when _USE_C_SHAPES is removed
self._handle_data = None
self._id = uid()
shape_vector = [None if d == -1 else d for d in shape_vector]
output.set_shape(tensor_shape.TensorShape(shape_vector))
+ serialized = c_api.ResourceHandleShapeAndType(op._graph._c_graph,
+ output._as_tf_output())
+ if serialized:
+ output._handle_data = (
+ cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
+ compat.as_bytes(serialized)))
+ else:
+ output._handle_data = None
# TODO(skyewm): remove this when _USE_C_API flag is removed.
def _set_shapes_for_outputs(op):