# containing the wrappers.
add_custom_command(
OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION}
- COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION}
+ COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION}
DEPENDS ${tf_python_op_lib_name}_gen_python
)
srcs = ["api_test.cc"],
data = [
":base_api_def",
+ ":python_api_def",
],
deps = [
":excluded_ops_lib",
namespace {
constexpr char kDefaultApiDefDir[] =
"tensorflow/core/api_def/base_api";
+constexpr char kPythonApiDefDir[] =
+ "tensorflow/core/api_def/python_api";
constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt";
-} // namespace
// Reads golden ApiDef files and returns a map from file name to ApiDef file
// contents.
}
}
-class ApiTest : public ::testing::Test {
+void TestAllApiDefsHaveCorrespondingOp(
+ const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
+ std::unordered_set<string> op_names;
+ for (const auto& op : ops.op()) {
+ op_names.insert(op.name());
+ }
+ for (const auto& name_and_api_def : api_defs_map) {
+ ASSERT_TRUE(op_names.find(name_and_api_def.first) != op_names.end())
+ << name_and_api_def.first << " op has ApiDef but missing from ops. "
+ << "Does api_def_" << name_and_api_def.first << " need to be deleted?";
+ }
+}
+
+void TestAllApiDefInputArgsAreValid(
+ const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
+ for (const auto& op : ops.op()) {
+ const auto api_def_iter = api_defs_map.find(op.name());
+ if (api_def_iter == api_defs_map.end()) {
+ continue;
+ }
+ const auto& api_def = api_def_iter->second;
+ for (const auto& api_def_arg : api_def.in_arg()) {
+ bool found_arg = false;
+ for (const auto& op_arg : op.input_arg()) {
+ if (api_def_arg.name() == op_arg.name()) {
+ found_arg = true;
+ break;
+ }
+ }
+ ASSERT_TRUE(found_arg)
+ << "Input argument " << api_def_arg.name()
+ << " (overwritten in api_def_" << op.name()
+ << ".pbtxt) is not defined in OpDef for " << op.name();
+ }
+ }
+}
+
+void TestAllApiDefOutputArgsAreValid(
+ const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
+ for (const auto& op : ops.op()) {
+ const auto api_def_iter = api_defs_map.find(op.name());
+ if (api_def_iter == api_defs_map.end()) {
+ continue;
+ }
+ const auto& api_def = api_def_iter->second;
+ for (const auto& api_def_arg : api_def.out_arg()) {
+ bool found_arg = false;
+ for (const auto& op_arg : op.output_arg()) {
+ if (api_def_arg.name() == op_arg.name()) {
+ found_arg = true;
+ break;
+ }
+ }
+ ASSERT_TRUE(found_arg)
+ << "Output argument " << api_def_arg.name()
+ << " (overwritten in api_def_" << op.name()
+ << ".pbtxt) is not defined in OpDef for " << op.name();
+ }
+ }
+}
+
+void TestAllApiDefAttributeNamesAreValid(
+ const OpList& ops, const std::unordered_map<string, ApiDef>& api_defs_map) {
+ for (const auto& op : ops.op()) {
+ const auto api_def_iter = api_defs_map.find(op.name());
+ if (api_def_iter == api_defs_map.end()) {
+ continue;
+ }
+ const auto& api_def = api_def_iter->second;
+ for (const auto& api_def_attr : api_def.attr()) {
+ bool found_attr = false;
+ for (const auto& op_attr : op.attr()) {
+ if (api_def_attr.name() == op_attr.name()) {
+ found_attr = true;
+ }
+ }
+ ASSERT_TRUE(found_attr)
+ << "Attribute " << api_def_attr.name() << " (overwritten in api_def_"
+ << op.name() << ".pbtxt) is not defined in OpDef for " << op.name();
+ }
+ }
+}
+} // namespace
+
+class BaseApiTest : public ::testing::Test {
protected:
- ApiTest() {
+ BaseApiTest() {
OpRegistry::Global()->Export(false, &ops_);
const std::vector<string> multi_line_fields = {"description"};
};
// Check that all ops have an ApiDef.
-TEST_F(ApiTest, AllOpsAreInApiDef) {
+TEST_F(BaseApiTest, AllOpsAreInApiDef) {
auto* excluded_ops = GetExcludedOps();
for (const auto& op : ops_.op()) {
if (excluded_ops->find(op.name()) != excluded_ops->end()) {
}
// Check that ApiDefs have a corresponding op.
-TEST_F(ApiTest, AllApiDefsHaveCorrespondingOp) {
- std::unordered_set<string> op_names;
- for (const auto& op : ops_.op()) {
- op_names.insert(op.name());
- }
- for (const auto& name_and_api_def : api_defs_map_) {
- ASSERT_TRUE(op_names.find(name_and_api_def.first) != op_names.end())
- << name_and_api_def.first << " op has ApiDef but missing from ops. "
- << "Does api_def_" << name_and_api_def.first << " need to be deleted?";
- }
+TEST_F(BaseApiTest, AllApiDefsHaveCorrespondingOp) {
+ TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_);
}
string GetOpDefHasDocStringError(const string& op_name) {
// Check that OpDef's do not have descriptions and summaries.
// Descriptions and summaries must be in corresponding ApiDefs.
-TEST_F(ApiTest, OpDefsShouldNotHaveDocs) {
+TEST_F(BaseApiTest, OpDefsShouldNotHaveDocs) {
auto* excluded_ops = GetExcludedOps();
for (const auto& op : ops_.op()) {
if (excluded_ops->find(op.name()) != excluded_ops->end()) {
// Checks that input arg names in an ApiDef match input
// arg names in corresponding OpDef.
-TEST_F(ApiTest, AllApiDefInputArgsAreValid) {
- for (const auto& op : ops_.op()) {
- const auto& api_def = api_defs_map_[op.name()];
- for (const auto& api_def_arg : api_def.in_arg()) {
- bool found_arg = false;
- for (const auto& op_arg : op.input_arg()) {
- if (api_def_arg.name() == op_arg.name()) {
- found_arg = true;
- break;
- }
- }
- ASSERT_TRUE(found_arg)
- << "Input argument " << api_def_arg.name()
- << " (overwritten in api_def_" << op.name()
- << ".pbtxt) is not defined in OpDef for " << op.name();
- }
- }
+TEST_F(BaseApiTest, AllApiDefInputArgsAreValid) {
+ TestAllApiDefInputArgsAreValid(ops_, api_defs_map_);
}
// Checks that output arg names in an ApiDef match output
// arg names in corresponding OpDef.
-TEST_F(ApiTest, AllApiDefOutputArgsAreValid) {
- for (const auto& op : ops_.op()) {
- const auto& api_def = api_defs_map_[op.name()];
- for (const auto& api_def_arg : api_def.out_arg()) {
- bool found_arg = false;
- for (const auto& op_arg : op.output_arg()) {
- if (api_def_arg.name() == op_arg.name()) {
- found_arg = true;
- break;
- }
- }
- ASSERT_TRUE(found_arg)
- << "Output argument " << api_def_arg.name()
- << " (overwritten in api_def_" << op.name()
- << ".pbtxt) is not defined in OpDef for " << op.name();
- }
- }
+TEST_F(BaseApiTest, AllApiDefOutputArgsAreValid) {
+ TestAllApiDefOutputArgsAreValid(ops_, api_defs_map_);
}
// Checks that attribute names in an ApiDef match attribute
// names in corresponding OpDef.
-TEST_F(ApiTest, AllApiDefAttributeNamesAreValid) {
- for (const auto& op : ops_.op()) {
- const auto& api_def = api_defs_map_[op.name()];
- for (const auto& api_def_attr : api_def.attr()) {
- bool found_attr = false;
- for (const auto& op_attr : op.attr()) {
- if (api_def_attr.name() == op_attr.name()) {
- found_attr = true;
- }
- }
- ASSERT_TRUE(found_attr)
- << "Attribute " << api_def_attr.name() << " (overwritten in api_def_"
- << op.name() << ".pbtxt) is not defined in OpDef for " << op.name();
- }
+TEST_F(BaseApiTest, AllApiDefAttributeNamesAreValid) {
+ TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_);
+}
+
+class PythonApiTest : public ::testing::Test {
+ protected:
+ PythonApiTest() {
+ OpRegistry::Global()->Export(false, &ops_);
+ const std::vector<string> multi_line_fields = {"description"};
+
+ Env* env = Env::Default();
+ GetGoldenApiDefs(env, kPythonApiDefDir, &api_defs_map_);
}
+ OpList ops_;
+ std::unordered_map<string, ApiDef> api_defs_map_;
+};
+
+// Check that ApiDefs have a corresponding op.
+TEST_F(PythonApiTest, AllApiDefsHaveCorrespondingOp) {
+ TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_);
}
+
+// Checks that input arg names in an ApiDef match input
+// arg names in corresponding OpDef.
+TEST_F(PythonApiTest, AllApiDefInputArgsAreValid) {
+ TestAllApiDefInputArgsAreValid(ops_, api_defs_map_);
+}
+
+// Checks that output arg names in an ApiDef match output
+// arg names in corresponding OpDef.
+TEST_F(PythonApiTest, AllApiDefOutputArgsAreValid) {
+ TestAllApiDefOutputArgsAreValid(ops_, api_defs_map_);
+}
+
+// Checks that attribute names in an ApiDef match attribute
+// names in corresponding OpDef.
+TEST_F(PythonApiTest, AllApiDefAttributeNamesAreValid) {
+ TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_);
+}
+
} // namespace tensorflow
visibility = ["//tensorflow:__subpackages__"],
)
-filegroup(
- name = "hidden_ops",
- srcs = ["ops/hidden_ops.txt"],
- visibility = ["//tensorflow:__subpackages__"],
-)
-
cuda_py_test(
name = "accumulate_n_benchmark",
size = "large",
bare_op_name = name[:-4] # Strip off the _gen
tf_gen_op_wrapper_py(name=bare_op_name,
out=out,
- hidden_file="ops/hidden_ops.txt",
visibility=visibility,
deps=deps,
require_shape_functions=require_shape_functions,
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
# Import resource_variable_ops for the variables-to-tensor implicit conversion.
def runTestBuildGraphError(self, sess):
# Ensure that errors from building the graph get propagated.
data = array_ops.placeholder(dtypes.float32, shape=[])
- enter_1 = control_flow_ops.enter(data, 'foo_1', False)
- enter_2 = control_flow_ops.enter(data, 'foo_2', False)
+ # pylint: disable=protected-access
+ enter_1 = gen_control_flow_ops._enter(data, 'foo_1', False)
+ enter_2 = gen_control_flow_ops._enter(data, 'foo_2', False)
+ # pylint: enable=protected-access
res = math_ops.add(enter_1, enter_2)
with self.assertRaisesOpError('has inputs from different frames'):
sess.run(res, feed_dict={data: 1.0})
auto out = cleaned_ops.mutable_op();
out->Reserve(ops.op_size());
for (const auto& op_def : ops.op()) {
- bool is_hidden = false;
- for (const string& hidden : hidden_ops) {
- if (op_def.name() == hidden) {
- is_hidden = true;
- break;
+ const auto* api_def = api_defs.GetApiDef(op_def.name());
+
+ if (api_def->visibility() == ApiDef::SKIP) {
+ continue;
+ }
+
+ // An op is hidden if either its ApiDef visibility is HIDDEN
+ // or it is in the hidden_ops list.
+ bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
+ if (!is_hidden) {
+ for (const string& hidden : hidden_ops) {
+ if (op_def.name() == hidden) {
+ is_hidden = true;
+ break;
+ }
}
}
continue;
}
- const auto* api_def = api_defs.GetApiDef(op_def.name());
strings::StrAppend(&result,
GetEagerPythonOp(op_def, *api_def, function_name));
GenPythonOp::~GenPythonOp() {}
string GenPythonOp::Code() {
- if (api_def_.visibility() == ApiDef::SKIP) {
- return "";
- }
// This has all the input args followed by those attrs that don't have
// defaults.
std::vector<ParamNames> params_no_default;
auto out = cleaned_ops.mutable_op();
out->Reserve(ops.op_size());
for (const auto& op_def : ops.op()) {
- bool is_hidden = false;
- for (const string& hidden : hidden_ops) {
- if (op_def.name() == hidden) {
- is_hidden = true;
- break;
+ const auto* api_def = api_defs.GetApiDef(op_def.name());
+
+ if (api_def->visibility() == ApiDef::SKIP) {
+ continue;
+ }
+
+ // An op is hidden if either its ApiDef visibility is HIDDEN
+ // or it is in the hidden_ops list.
+ bool is_hidden = api_def->visibility() == ApiDef::HIDDEN;
+ if (!is_hidden) {
+ for (const string& hidden : hidden_ops) {
+ if (op_def.name() == hidden) {
+ is_hidden = true;
+ break;
+ }
}
}
continue;
}
- const auto* api_def = api_defs.GetApiDef(op_def.name());
strings::StrAppend(&result, GetPythonOp(op_def, *api_def, function_name));
if (!require_shapes) {
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_logging_ops
from tensorflow.python.ops import gen_state_ops
enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
nine = constant_op.constant(9)
- enter_nine = control_flow_ops.enter(nine, "foo_1")
+ enter_nine = gen_control_flow_ops._enter(nine, "foo_1")
op = state_ops.assign(enter_v, enter_nine)
v2 = control_flow_ops.with_dependencies([op], enter_v)
v3 = control_flow_ops.exit(v2)
def testEnterMulExit(self):
with self.test_session():
data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data")
- enter_data = control_flow_ops.enter(data, "foo_1", False)
+ enter_data = gen_control_flow_ops._enter(data, "foo_1", False)
five = constant_op.constant(5)
- enter_five = control_flow_ops.enter(five, "foo_1", False)
+ enter_five = gen_control_flow_ops._enter(five, "foo_1", False)
mul_op = math_ops.multiply(enter_data, enter_five)
exit_op = control_flow_ops.exit(mul_op)
v = variables.Variable([0.0, 0.0], dtype=dtypes.float32)
# If is_constant=True, the shape information should be propagated.
- enter_v_constant = control_flow_ops.enter(v, "frame1", is_constant=True)
+ enter_v_constant = gen_control_flow_ops._enter(
+ v, "frame1", is_constant=True)
self.assertEqual(enter_v_constant.shape, [2])
# Otherwise, the shape should be unknown.
- enter_v_non_constant = control_flow_ops.enter(
+ enter_v_non_constant = gen_control_flow_ops._enter(
v, "frame2", is_constant=False)
self.assertEqual(enter_v_non_constant.shape, None)
false = ops.convert_to_tensor(False)
n = constant_op.constant(10)
- enter_false = control_flow_ops.enter(false, "foo_1", False)
- enter_n = control_flow_ops.enter(n, "foo_1", False)
+ enter_false = gen_control_flow_ops._enter(false, "foo_1", False)
+ enter_n = gen_control_flow_ops._enter(n, "foo_1", False)
merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0]
switch_n = control_flow_ops.switch(merge_n, enter_false)
one = constant_op.constant(1)
n = constant_op.constant(10)
- enter_i = control_flow_ops.enter(zero, "foo", False)
- enter_one = control_flow_ops.enter(one, "foo", True)
- enter_n = control_flow_ops.enter(n, "foo", True)
+ enter_i = gen_control_flow_ops._enter(zero, "foo", False)
+ enter_one = gen_control_flow_ops._enter(one, "foo", True)
+ enter_n = gen_control_flow_ops._enter(n, "foo", True)
with ops.device(test.gpu_device_name()):
merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
one = constant_op.constant(1)
n = constant_op.constant(10)
- enter_i = control_flow_ops.enter(zero, "foo", False)
- enter_one = control_flow_ops.enter(one, "foo", True)
- enter_n = control_flow_ops.enter(n, "foo", True)
+ enter_i = gen_control_flow_ops._enter(zero, "foo", False)
+ enter_one = gen_control_flow_ops._enter(one, "foo", True)
+ enter_n = gen_control_flow_ops._enter(n, "foo", True)
merge_i = control_flow_ops.merge([enter_i, enter_i])[0]
def testDifferentFrame(self):
with self.test_session():
data = array_ops.placeholder(dtypes.float32, shape=[])
- enter_1 = control_flow_ops.enter(data, "foo_1", False)
- enter_2 = control_flow_ops.enter(data, "foo_2", False)
+ enter_1 = gen_control_flow_ops._enter(data, "foo_1", False)
+ enter_2 = gen_control_flow_ops._enter(data, "foo_2", False)
res = math_ops.add(enter_1, enter_2)
with self.assertRaisesOpError("has inputs from different frames"):
res.eval(feed_dict={data: 1.0})
self.assertFalse(control_flow_util.IsSwitch(test_ops.int_output().op))
def testIsLoopEnter(self):
- enter = gen_control_flow_ops.enter(1, frame_name="name").op
+ enter = gen_control_flow_ops._enter(1, frame_name="name").op
self.assertTrue(control_flow_util.IsLoopEnter(enter))
self.assertFalse(control_flow_util.IsLoopConstantEnter(enter))
- ref_enter = gen_control_flow_ops.ref_enter(test_ops.ref_output(),
- frame_name="name").op
+ ref_enter = gen_control_flow_ops._ref_enter(test_ops.ref_output(),
+ frame_name="name").op
self.assertTrue(control_flow_util.IsLoopEnter(ref_enter))
self.assertFalse(control_flow_util.IsLoopConstantEnter(ref_enter))
- const_enter = gen_control_flow_ops.enter(1, frame_name="name",
- is_constant=True).op
+ const_enter = gen_control_flow_ops._enter(1, frame_name="name",
+ is_constant=True).op
self.assertTrue(control_flow_util.IsLoopEnter(const_enter))
self.assertTrue(control_flow_util.IsLoopConstantEnter(const_enter))
data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True)
if isinstance(data, ops.Tensor):
if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access
- result = ref_enter(
+ result = gen_control_flow_ops._ref_enter(
data, frame_name, is_constant, parallel_iterations, name=name)
else:
- result = enter(
+ result = gen_control_flow_ops._enter(
data, frame_name, is_constant, parallel_iterations, name=name)
if use_input_shape:
result.set_shape(data.get_shape())
parallel_iterations=parallel_iterations,
use_input_shape=use_input_shape,
name=name)
- indices = enter(
+ indices = gen_control_flow_ops._enter(
data.indices,
frame_name,
is_constant,
if isinstance(data, ops.IndexedSlices):
dense_shape = data.dense_shape
if dense_shape is not None:
- dense_shape = enter(
+ dense_shape = gen_control_flow_ops._enter(
dense_shape,
frame_name,
is_constant,
dense_shape.set_shape(data.dense_shape.get_shape())
return ops.IndexedSlices(values, indices, dense_shape)
else:
- dense_shape = enter(
+ dense_shape = gen_control_flow_ops._enter(
data.dense_shape,
frame_name,
is_constant,
name = "api_compatibility_test",
srcs = ["api_compatibility_test.py"],
data = [
- ":convert_from_multiline",
- "//tensorflow/core/api_def:base_api_def",
- "//tensorflow/core/api_def:python_api_def",
- "//tensorflow/python:hidden_ops",
"//tensorflow/tools/api/golden:api_golden",
"//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt",
"//tensorflow/tools/api/tests:README.txt",
deps = [
"//tensorflow:tensorflow_py",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
"//tensorflow/tools/api/lib:python_object_to_proto_visitor",
from __future__ import print_function
import argparse
-from collections import defaultdict
import os
import re
-import subprocess
import sys
import unittest
from google.protobuf import text_format
-from tensorflow.core.framework import api_def_pb2
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
_TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt'
_UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt'
-_CONVERT_FROM_MULTILINE_SCRIPT = 'tensorflow/tools/api/tests/convert_from_multiline'
-_BASE_API_DIR = 'tensorflow/core/api_def/base_api'
-_PYTHON_API_DIR = 'tensorflow/core/api_def/python_api'
-_HIDDEN_OPS_FILE = 'tensorflow/python/ops/hidden_ops.txt'
-
def _KeyToFilePath(key):
"""From a given key, construct a filepath."""
return api_object_key
-def _GetSymbol(symbol_id):
- """Get TensorFlow symbol based on the given identifier.
-
- Args:
- symbol_id: Symbol identifier in the form module1.module2. ... .sym.
-
- Returns:
- Symbol corresponding to the given id.
- """
- # Ignore first module which should be tensorflow
- symbol_id_split = symbol_id.split('.')[1:]
- symbol = tf
- for sym in symbol_id_split:
- symbol = getattr(symbol, sym)
- return symbol
-
-
-def _IsGenModule(module_name):
- if not module_name:
- return False
- module_name_split = module_name.split('.')
- return module_name_split[-1].startswith('gen_')
-
-
-def _GetHiddenOps():
- hidden_ops_file = file_io.FileIO(_HIDDEN_OPS_FILE, 'r')
- hidden_ops = set()
- for line in hidden_ops_file:
- line = line.strip()
- if not line:
- continue
- if line[0] == '#': # comment line
- continue
- # If line is of the form "op_name # comment", only keep the op_name.
- line_split = line.split('#')
- hidden_ops.add(line_split[0].strip())
- return hidden_ops
-
-
-def _GetGoldenApiDefs():
- old_api_def_files = file_io.get_matching_files(_GetApiDefFilePath('*'))
- return {file_path: file_io.read_file_to_string(file_path)
- for file_path in old_api_def_files}
-
-
-def _GetApiDefFilePath(graph_op_name):
- return os.path.join(_PYTHON_API_DIR, 'api_def_%s.pbtxt' % graph_op_name)
-
-
class ApiCompatibilityTest(test.TestCase):
def __init__(self, *args, **kwargs):
update_goldens=FLAGS.update_goldens)
-class ApiDefTest(test.TestCase):
-
- def __init__(self, *args, **kwargs):
- super(ApiDefTest, self).__init__(*args, **kwargs)
- self._first_cap_pattern = re.compile('(.)([A-Z][a-z]+)')
- self._all_cap_pattern = re.compile('([a-z0-9])([A-Z])')
-
- def _GenerateLowerCaseOpName(self, op_name):
- lower_case_name = self._first_cap_pattern.sub(r'\1_\2', op_name)
- return self._all_cap_pattern.sub(r'\1_\2', lower_case_name).lower()
-
- def _CreatePythonApiDef(self, base_api_def, endpoint_names):
- """Creates Python ApiDef that overrides base_api_def if needed.
-
- Args:
- base_api_def: (api_def_pb2.ApiDef) base ApiDef instance.
- endpoint_names: List of Python endpoint names.
-
- Returns:
- api_def_pb2.ApiDef instance with overrides for base_api_def
- if module.name endpoint is different from any existing
- endpoints in base_api_def. Otherwise, returns None.
- """
- endpoint_names_set = set(endpoint_names)
-
- # If the only endpoint is equal to graph_op_name then
- # it is equivalent to having no endpoints.
- if (not base_api_def.endpoint and len(endpoint_names) == 1
- and endpoint_names[0] ==
- self._GenerateLowerCaseOpName(base_api_def.graph_op_name)):
- return None
-
- base_endpoint_names_set = {
- self._GenerateLowerCaseOpName(endpoint.name)
- for endpoint in base_api_def.endpoint}
-
- if endpoint_names_set == base_endpoint_names_set:
- return None # All endpoints are the same
-
- api_def = api_def_pb2.ApiDef()
- api_def.graph_op_name = base_api_def.graph_op_name
-
- for endpoint_name in sorted(endpoint_names):
- new_endpoint = api_def.endpoint.add()
- new_endpoint.name = endpoint_name
-
- return api_def
-
- def _GetBaseApiMap(self):
- """Get a map from graph op name to its base ApiDef.
-
- Returns:
- Dictionary mapping graph op name to corresponding ApiDef.
- """
- # Convert base ApiDef in Multiline format to Proto format.
- converted_base_api_dir = os.path.join(
- test.get_temp_dir(), 'temp_base_api_defs')
- subprocess.check_call(
- [os.path.join(resource_loader.get_root_dir_with_all_resources(),
- _CONVERT_FROM_MULTILINE_SCRIPT),
- _BASE_API_DIR, converted_base_api_dir])
-
- name_to_base_api_def = {}
- base_api_files = file_io.get_matching_files(
- os.path.join(converted_base_api_dir, 'api_def_*.pbtxt'))
- for base_api_file in base_api_files:
- if file_io.file_exists(base_api_file):
- api_defs = api_def_pb2.ApiDefs()
- text_format.Merge(
- file_io.read_file_to_string(base_api_file), api_defs)
- for api_def in api_defs.op:
- name_to_base_api_def[api_def.graph_op_name] = api_def
- return name_to_base_api_def
-
- def _AddHiddenOpOverrides(self, name_to_base_api_def, api_def_map):
- """Adds ApiDef overrides to api_def_map for hidden Python ops.
-
- Args:
- name_to_base_api_def: Map from op name to base api_def_pb2.ApiDef.
- api_def_map: Map from file path to api_def_pb2.ApiDefs for Python API
- overrides.
- """
- hidden_ops = _GetHiddenOps()
- for hidden_op in hidden_ops:
- if hidden_op not in name_to_base_api_def:
- logging.warning('Unexpected hidden op name: %s' % hidden_op)
- continue
-
- base_api_def = name_to_base_api_def[hidden_op]
- if base_api_def.visibility != api_def_pb2.ApiDef.HIDDEN:
- api_def = api_def_pb2.ApiDef()
- api_def.graph_op_name = base_api_def.graph_op_name
- api_def.visibility = api_def_pb2.ApiDef.HIDDEN
-
- file_path = _GetApiDefFilePath(base_api_def.graph_op_name)
- api_def_map[file_path].op.extend([api_def])
-
- @unittest.skipUnless(
- sys.version_info.major == 2 and os.uname()[0] == 'Linux',
- 'API compabitility test goldens are generated using python2 on Linux.')
- def testAPIDefCompatibility(self):
- # Get base ApiDef
- name_to_base_api_def = self._GetBaseApiMap()
- snake_to_camel_graph_op_names = {
- self._GenerateLowerCaseOpName(name): name
- for name in name_to_base_api_def.keys()}
- # Extract Python API
- visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
- public_api_visitor = public_api.PublicAPIVisitor(visitor)
- public_api_visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf, public_api_visitor)
- proto_dict = visitor.GetProtos()
-
- # Map from file path to Python ApiDefs.
- new_api_defs_map = defaultdict(api_def_pb2.ApiDefs)
- # We need to override all endpoints even if 1 endpoint differs from base
- # ApiDef. So, we first create a map from an op to all its endpoints.
- op_to_endpoint_name = defaultdict(list)
-
- # Generate map from generated python op to endpoint names.
- for public_module, value in proto_dict.items():
- module_obj = _GetSymbol(public_module)
- for sym in value.tf_module.member_method:
- obj = getattr(module_obj, sym.name)
-
- # Check if object is defined in gen_* module. That is,
- # the object has been generated from OpDef.
- if hasattr(obj, '__module__') and _IsGenModule(obj.__module__):
- if obj.__name__ not in snake_to_camel_graph_op_names:
- # Symbol might be defined only in Python and not generated from
- # C++ api.
- continue
- relative_public_module = public_module[len('tensorflow.'):]
- full_name = (relative_public_module + '.' + sym.name
- if relative_public_module else sym.name)
- op_to_endpoint_name[obj].append(full_name)
-
- # Generate Python ApiDef overrides.
- for op, endpoint_names in op_to_endpoint_name.items():
- graph_op_name = snake_to_camel_graph_op_names[op.__name__]
- api_def = self._CreatePythonApiDef(
- name_to_base_api_def[graph_op_name], endpoint_names)
-
- if api_def:
- file_path = _GetApiDefFilePath(graph_op_name)
- api_defs = new_api_defs_map[file_path]
- api_defs.op.extend([api_def])
-
- self._AddHiddenOpOverrides(name_to_base_api_def, new_api_defs_map)
-
- old_api_defs_map = _GetGoldenApiDefs()
- for file_path, new_api_defs in new_api_defs_map.items():
- # Get new ApiDef string.
- new_api_defs_str = str(new_api_defs)
-
- # Get current ApiDef for the given file.
- old_api_defs_str = (
- old_api_defs_map[file_path] if file_path in old_api_defs_map else '')
-
- if old_api_defs_str == new_api_defs_str:
- continue
-
- if FLAGS.update_goldens:
- logging.info('Updating %s...' % file_path)
- file_io.write_string_to_file(file_path, new_api_defs_str)
- else:
- self.assertMultiLineEqual(
- old_api_defs_str, new_api_defs_str,
- 'To update golden API files, run api_compatibility_test locally '
- 'with --update_goldens=True flag.')
-
- for file_path in set(old_api_defs_map) - set(new_api_defs_map):
- if FLAGS.update_goldens:
- logging.info('Deleting %s...' % file_path)
- file_io.delete_file(file_path)
- else:
- self.fail(
- '%s file is no longer needed and should be removed.'
- 'To update golden API files, run api_compatibility_test locally '
- 'with --update_goldens=True flag.' % file_path)
-
-
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(