Add some python wrapper for TF_ApiDefMap.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 11 May 2018 21:24:47 +0000 (14:24 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 11 May 2018 21:27:53 +0000 (14:27 -0700)
PiperOrigin-RevId: 196308677

tensorflow/python/BUILD
tensorflow/python/framework/c_api_util.py
tensorflow/python/framework/c_api_util_test.py [new file with mode: 0644]

index cc96d5a..ea11b70 100644 (file)
@@ -627,6 +627,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":pywrap_tensorflow",
+        "//tensorflow/core:protos_all_py",
     ],
 )
 
@@ -3972,6 +3973,18 @@ cuda_py_test(
 )
 
 py_test(
+    name = "c_api_util_test",
+    size = "small",
+    srcs = ["framework/c_api_util_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":c_api_util",
+        ":framework_test_lib",
+        ":platform_test",
+    ],
+)
+
+py_test(
     name = "graph_util_test",
     size = "small",
     srcs = ["framework/graph_util_test.py"],
index 7bbe318..aff289f 100644 (file)
@@ -19,6 +19,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.core.framework import api_def_pb2
+from tensorflow.core.framework import op_def_pb2
 from tensorflow.python import pywrap_tensorflow as c_api
 from tensorflow.python.util import compat
 from tensorflow.python.util import tf_contextlib
@@ -89,6 +91,50 @@ class ScopedTFFunction(object):
       c_api.TF_DeleteFunction(self.func)
 
 
+class ApiDefMap(object):
+  """Wrapper around Tf_ApiDefMap that handles querying and deletion.
+
+  The OpDef protos are also stored in this class so that they could
+  be queried by op name.
+  """
+
+  def __init__(self):
+    op_def_proto = op_def_pb2.OpList()
+    buf = c_api.TF_GetAllOpList()
+    try:
+      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
+      self._api_def_map = c_api.TF_NewApiDefMap(buf)
+    finally:
+      c_api.TF_DeleteBuffer(buf)
+
+    self._op_per_name = {}
+    for op in op_def_proto.op:
+      self._op_per_name[op.name] = op
+
+  def __del__(self):
+    # Note: when we're destructing the global context (i.e when the process is
+    # terminating) we can have already deleted other modules.
+    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
+      c_api.TF_DeleteApiDefMap(self._api_def_map)
+
+  def put_api_def(self, text):
+    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))
+
+  def get_api_def(self, op_name):
+    api_def_proto = api_def_pb2.ApiDef()
+    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
+    try:
+      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
+    finally:
+      c_api.TF_DeleteBuffer(buf)
+    return api_def_proto
+
+  def get_op_def(self, op_name):
+    if op_name in self._op_per_name:
+      return self._op_per_name[op_name]
+    raise ValueError("No entry found for " + op_name + ".")
+
+
 @tf_contextlib.contextmanager
 def tf_buffer(data=None):
   """Context manager that creates and deletes TF_Buffer.
diff --git a/tensorflow/python/framework/c_api_util_test.py b/tensorflow/python/framework/c_api_util_test.py
new file mode 100644 (file)
index 0000000..e0bc9ee
--- /dev/null
@@ -0,0 +1,55 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for c_api utils."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class ApiDefMapTest(test_util.TensorFlowTestCase):
+
+  def testApiDefMapGet(self):
+    api_def_map = c_api_util.ApiDefMap()
+    op_def = api_def_map.get_op_def("Add")
+    self.assertEqual(op_def.name, "Add")
+    api_def = api_def_map.get_api_def("Add")
+    self.assertEqual(api_def.graph_op_name, "Add")
+
+  def testApiDefMapPutThenGet(self):
+    api_def_map = c_api_util.ApiDefMap()
+    api_def_text = """
+op {
+  graph_op_name: "Add"
+  summary: "Returns x + y element-wise."
+  description: <<END
+*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+END
+}
+"""
+    api_def_map.put_api_def(api_def_text)
+    api_def = api_def_map.get_api_def("Add")
+    self.assertEqual(api_def.graph_op_name, "Add")
+    self.assertEqual(api_def.summary, "Returns x + y element-wise.")
+
+
+if __name__ == "__main__":
+  googletest.main()
+