srcs_version = "PY2AND3",
deps = [
":pywrap_tensorflow",
+ "//tensorflow/core:protos_all_py",
],
)
)
py_test(
+ name = "c_api_util_test",
+ size = "small",
+ srcs = ["framework/c_api_util_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":c_api_util",
+ ":framework_test_lib",
+ ":platform_test",
+ ],
+)
+
+py_test(
name = "graph_util_test",
size = "small",
srcs = ["framework/graph_util_test.py"],
from __future__ import division
from __future__ import print_function
+from tensorflow.core.framework import api_def_pb2
+from tensorflow.core.framework import op_def_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib
c_api.TF_DeleteFunction(self.func)
+class ApiDefMap(object):
+ """Wrapper around Tf_ApiDefMap that handles querying and deletion.
+
+ The OpDef protos are also stored in this class so that they could
+ be queried by op name.
+ """
+
+ def __init__(self):
+ op_def_proto = op_def_pb2.OpList()
+ buf = c_api.TF_GetAllOpList()
+ try:
+ op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
+ self._api_def_map = c_api.TF_NewApiDefMap(buf)
+ finally:
+ c_api.TF_DeleteBuffer(buf)
+
+ self._op_per_name = {}
+ for op in op_def_proto.op:
+ self._op_per_name[op.name] = op
+
+ def __del__(self):
+ # Note: when we're destructing the global context (i.e when the process is
+ # terminating) we can have already deleted other modules.
+ if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
+ c_api.TF_DeleteApiDefMap(self._api_def_map)
+
+ def put_api_def(self, text):
+ c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))
+
+ def get_api_def(self, op_name):
+ api_def_proto = api_def_pb2.ApiDef()
+ buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
+ try:
+ api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
+ finally:
+ c_api.TF_DeleteBuffer(buf)
+ return api_def_proto
+
+ def get_op_def(self, op_name):
+ if op_name in self._op_per_name:
+ return self._op_per_name[op_name]
+ raise ValueError("No entry found for " + op_name + ".")
+
+
@tf_contextlib.contextmanager
def tf_buffer(data=None):
"""Context manager that creates and deletes TF_Buffer.
--- /dev/null
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for c_api utils."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import c_api_util
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+
+
+class ApiDefMapTest(test_util.TensorFlowTestCase):
+
+ def testApiDefMapGet(self):
+ api_def_map = c_api_util.ApiDefMap()
+ op_def = api_def_map.get_op_def("Add")
+ self.assertEqual(op_def.name, "Add")
+ api_def = api_def_map.get_api_def("Add")
+ self.assertEqual(api_def.graph_op_name, "Add")
+
+ def testApiDefMapPutThenGet(self):
+ api_def_map = c_api_util.ApiDefMap()
+ api_def_text = """
+op {
+ graph_op_name: "Add"
+ summary: "Returns x + y element-wise."
+ description: <<END
+*NOTE*: `Add` supports broadcasting. `AddN` does not. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+END
+}
+"""
+ api_def_map.put_api_def(api_def_text)
+ api_def = api_def_map.get_api_def("Add")
+ self.assertEqual(api_def.graph_op_name, "Add")
+ self.assertEqual(api_def.summary, "Returns x + y element-wise.")
+
+
+if __name__ == "__main__":
+ googletest.main()
+