From 16594f81af70eb33e9bbda1f1f26082a64919b2d Mon Sep 17 00:00:00 2001 From: Anna R Date: Fri, 2 Feb 2018 10:28:20 -0800 Subject: [PATCH] Remove hidden_ops.txt file. Instead, switch to use visibility attribute in ApiDef proto. PiperOrigin-RevId: 184301076 --- tensorflow/contrib/cmake/tf_python.cmake | 2 +- tensorflow/core/api_def/BUILD | 1 + tensorflow/core/api_def/api_test.cc | 197 +++++++++++------ tensorflow/python/BUILD | 6 - tensorflow/python/build_defs.bzl | 1 - tensorflow/python/client/session_test.py | 7 +- tensorflow/python/eager/python_eager_op_gen.cc | 21 +- tensorflow/python/framework/python_op_gen.cc | 24 ++- .../kernel_tests/control_flow_ops_py_test.py | 32 +-- .../python/kernel_tests/control_flow_util_test.py | 10 +- tensorflow/python/ops/control_flow_ops.py | 10 +- tensorflow/tools/api/tests/BUILD | 5 - .../tools/api/tests/api_compatibility_test.py | 239 --------------------- 13 files changed, 198 insertions(+), 357 deletions(-) diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 8862390..294b9c5 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -307,7 +307,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) # containing the wrappers. add_custom_command( OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION} - COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} + COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} DEPENDS ${tf_python_op_lib_name}_gen_python ) diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD index 81187ff..58dbac4 100644 --- a/tensorflow/core/api_def/BUILD +++ b/tensorflow/core/api_def/BUILD @@ -96,6 +96,7 @@ tf_cc_test( srcs = ["api_test.cc"], data = [ ":base_api_def", + ":python_api_def", ], deps = [ ":excluded_ops_lib", diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc index 112c55c..477a0b6 100644 --- a/tensorflow/core/api_def/api_test.cc +++ b/tensorflow/core/api_def/api_test.cc @@ -41,8 +41,9 @@ namespace tensorflow { namespace { constexpr char kDefaultApiDefDir[] = "tensorflow/core/api_def/base_api"; +constexpr char kPythonApiDefDir[] = + "tensorflow/core/api_def/python_api"; constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt"; -} // namespace // Reads golden ApiDef files and returns a map from file name to ApiDef file // contents. @@ -66,9 +67,93 @@ void GetGoldenApiDefs(Env* env, const string& api_files_dir, } } -class ApiTest : public ::testing::Test { +void TestAllApiDefsHaveCorrespondingOp( + const OpList& ops, const std::unordered_map& api_defs_map) { + std::unordered_set op_names; + for (const auto& op : ops.op()) { + op_names.insert(op.name()); + } + for (const auto& name_and_api_def : api_defs_map) { + ASSERT_TRUE(op_names.find(name_and_api_def.first) != op_names.end()) + << name_and_api_def.first << " op has ApiDef but missing from ops. " + << "Does api_def_" << name_and_api_def.first << " need to be deleted?"; + } +} + +void TestAllApiDefInputArgsAreValid( + const OpList& ops, const std::unordered_map& api_defs_map) { + for (const auto& op : ops.op()) { + const auto api_def_iter = api_defs_map.find(op.name()); + if (api_def_iter == api_defs_map.end()) { + continue; + } + const auto& api_def = api_def_iter->second; + for (const auto& api_def_arg : api_def.in_arg()) { + bool found_arg = false; + for (const auto& op_arg : op.input_arg()) { + if (api_def_arg.name() == op_arg.name()) { + found_arg = true; + break; + } + } + ASSERT_TRUE(found_arg) + << "Input argument " << api_def_arg.name() + << " (overwritten in api_def_" << op.name() + << ".pbtxt) is not defined in OpDef for " << op.name(); + } + } +} + +void TestAllApiDefOutputArgsAreValid( + const OpList& ops, const std::unordered_map& api_defs_map) { + for (const auto& op : ops.op()) { + const auto api_def_iter = api_defs_map.find(op.name()); + if (api_def_iter == api_defs_map.end()) { + continue; + } + const auto& api_def = api_def_iter->second; + for (const auto& api_def_arg : api_def.out_arg()) { + bool found_arg = false; + for (const auto& op_arg : op.output_arg()) { + if (api_def_arg.name() == op_arg.name()) { + found_arg = true; + break; + } + } + ASSERT_TRUE(found_arg) + << "Output argument " << api_def_arg.name() + << " (overwritten in api_def_" << op.name() + << ".pbtxt) is not defined in OpDef for " << op.name(); + } + } +} + +void TestAllApiDefAttributeNamesAreValid( + const OpList& ops, const std::unordered_map& api_defs_map) { + for (const auto& op : ops.op()) { + const auto api_def_iter = api_defs_map.find(op.name()); + if (api_def_iter == api_defs_map.end()) { + continue; + } + const auto& api_def = api_def_iter->second; + for (const auto& api_def_attr : api_def.attr()) { + bool found_attr = false; + for (const auto& op_attr : op.attr()) { + if (api_def_attr.name() == op_attr.name()) { + found_attr = true; + } + } + ASSERT_TRUE(found_attr) + << "Attribute " << api_def_attr.name() << " (overwritten in api_def_" + << op.name() << ".pbtxt) is not defined in OpDef for " << op.name(); + } + } +} +} // namespace + +class BaseApiTest : public ::testing::Test { protected: - ApiTest() { + BaseApiTest() { OpRegistry::Global()->Export(false, &ops_); const std::vector multi_line_fields = {"description"}; @@ -80,7 +165,7 @@ class ApiTest : public ::testing::Test { }; // Check that all ops have an ApiDef. -TEST_F(ApiTest, AllOpsAreInApiDef) { +TEST_F(BaseApiTest, AllOpsAreInApiDef) { auto* excluded_ops = GetExcludedOps(); for (const auto& op : ops_.op()) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { @@ -94,16 +179,8 @@ TEST_F(ApiTest, AllOpsAreInApiDef) { } // Check that ApiDefs have a corresponding op. -TEST_F(ApiTest, AllApiDefsHaveCorrespondingOp) { - std::unordered_set op_names; - for (const auto& op : ops_.op()) { - op_names.insert(op.name()); - } - for (const auto& name_and_api_def : api_defs_map_) { - ASSERT_TRUE(op_names.find(name_and_api_def.first) != op_names.end()) - << name_and_api_def.first << " op has ApiDef but missing from ops. " - << "Does api_def_" << name_and_api_def.first << " need to be deleted?"; - } +TEST_F(BaseApiTest, AllApiDefsHaveCorrespondingOp) { + TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_); } string GetOpDefHasDocStringError(const string& op_name) { @@ -117,7 +194,7 @@ string GetOpDefHasDocStringError(const string& op_name) { // Check that OpDef's do not have descriptions and summaries. // Descriptions and summaries must be in corresponding ApiDefs. -TEST_F(ApiTest, OpDefsShouldNotHaveDocs) { +TEST_F(BaseApiTest, OpDefsShouldNotHaveDocs) { auto* excluded_ops = GetExcludedOps(); for (const auto& op : ops_.op()) { if (excluded_ops->find(op.name()) != excluded_ops->end()) { @@ -143,62 +220,56 @@ TEST_F(ApiTest, OpDefsShouldNotHaveDocs) { // Checks that input arg names in an ApiDef match input // arg names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefInputArgsAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_arg : api_def.in_arg()) { - bool found_arg = false; - for (const auto& op_arg : op.input_arg()) { - if (api_def_arg.name() == op_arg.name()) { - found_arg = true; - break; - } - } - ASSERT_TRUE(found_arg) - << "Input argument " << api_def_arg.name() - << " (overwritten in api_def_" << op.name() - << ".pbtxt) is not defined in OpDef for " << op.name(); - } - } +TEST_F(BaseApiTest, AllApiDefInputArgsAreValid) { + TestAllApiDefInputArgsAreValid(ops_, api_defs_map_); } // Checks that output arg names in an ApiDef match output // arg names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefOutputArgsAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_arg : api_def.out_arg()) { - bool found_arg = false; - for (const auto& op_arg : op.output_arg()) { - if (api_def_arg.name() == op_arg.name()) { - found_arg = true; - break; - } - } - ASSERT_TRUE(found_arg) - << "Output argument " << api_def_arg.name() - << " (overwritten in api_def_" << op.name() - << ".pbtxt) is not defined in OpDef for " << op.name(); - } - } +TEST_F(BaseApiTest, AllApiDefOutputArgsAreValid) { + TestAllApiDefOutputArgsAreValid(ops_, api_defs_map_); } // Checks that attribute names in an ApiDef match attribute // names in corresponding OpDef. -TEST_F(ApiTest, AllApiDefAttributeNamesAreValid) { - for (const auto& op : ops_.op()) { - const auto& api_def = api_defs_map_[op.name()]; - for (const auto& api_def_attr : api_def.attr()) { - bool found_attr = false; - for (const auto& op_attr : op.attr()) { - if (api_def_attr.name() == op_attr.name()) { - found_attr = true; - } - } - ASSERT_TRUE(found_attr) - << "Attribute " << api_def_attr.name() << " (overwritten in api_def_" - << op.name() << ".pbtxt) is not defined in OpDef for " << op.name(); - } +TEST_F(BaseApiTest, AllApiDefAttributeNamesAreValid) { + TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); +} + +class PythonApiTest : public ::testing::Test { + protected: + PythonApiTest() { + OpRegistry::Global()->Export(false, &ops_); + const std::vector multi_line_fields = {"description"}; + + Env* env = Env::Default(); + GetGoldenApiDefs(env, kPythonApiDefDir, &api_defs_map_); } + OpList ops_; + std::unordered_map api_defs_map_; +}; + +// Check that ApiDefs have a corresponding op. +TEST_F(PythonApiTest, AllApiDefsHaveCorrespondingOp) { + TestAllApiDefsHaveCorrespondingOp(ops_, api_defs_map_); } + +// Checks that input arg names in an ApiDef match input +// arg names in corresponding OpDef. +TEST_F(PythonApiTest, AllApiDefInputArgsAreValid) { + TestAllApiDefInputArgsAreValid(ops_, api_defs_map_); +} + +// Checks that output arg names in an ApiDef match output +// arg names in corresponding OpDef. +TEST_F(PythonApiTest, AllApiDefOutputArgsAreValid) { + TestAllApiDefOutputArgsAreValid(ops_, api_defs_map_); +} + +// Checks that attribute names in an ApiDef match attribute +// names in corresponding OpDef. +TEST_F(PythonApiTest, AllApiDefAttributeNamesAreValid) { + TestAllApiDefAttributeNamesAreValid(ops_, api_defs_map_); +} + } // namespace tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 363ff6f..a89142d 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -4230,12 +4230,6 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) -filegroup( - name = "hidden_ops", - srcs = ["ops/hidden_ops.txt"], - visibility = ["//tensorflow:__subpackages__"], -) - cuda_py_test( name = "accumulate_n_benchmark", size = "large", diff --git a/tensorflow/python/build_defs.bzl b/tensorflow/python/build_defs.bzl index 7f29adc..b9056f8 100644 --- a/tensorflow/python/build_defs.bzl +++ b/tensorflow/python/build_defs.bzl @@ -22,7 +22,6 @@ def tf_gen_op_wrapper_private_py(name, out=None, deps=[], bare_op_name = name[:-4] # Strip off the _gen tf_gen_op_wrapper_py(name=bare_op_name, out=out, - hidden_file="ops/hidden_ops.txt", visibility=visibility, deps=deps, require_shape_functions=require_shape_functions, diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index 768a5db..f12c005 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -46,6 +46,7 @@ from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops # Import resource_variable_ops for the variables-to-tensor implicit conversion. @@ -1745,8 +1746,10 @@ class SessionTest(test_util.TensorFlowTestCase): def runTestBuildGraphError(self, sess): # Ensure that errors from building the graph get propagated. data = array_ops.placeholder(dtypes.float32, shape=[]) - enter_1 = control_flow_ops.enter(data, 'foo_1', False) - enter_2 = control_flow_ops.enter(data, 'foo_2', False) + # pylint: disable=protected-access + enter_1 = gen_control_flow_ops._enter(data, 'foo_1', False) + enter_2 = gen_control_flow_ops._enter(data, 'foo_2', False) + # pylint: enable=protected-access res = math_ops.add(enter_1, enter_2) with self.assertRaisesOpError('has inputs from different frames'): sess.run(res, feed_dict={data: 1.0}) diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index 90a8779..0f18f28 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -756,11 +756,21 @@ from tensorflow.python.util.tf_export import tf_export auto out = cleaned_ops.mutable_op(); out->Reserve(ops.op_size()); for (const auto& op_def : ops.op()) { - bool is_hidden = false; - for (const string& hidden : hidden_ops) { - if (op_def.name() == hidden) { - is_hidden = true; - break; + const auto* api_def = api_defs.GetApiDef(op_def.name()); + + if (api_def->visibility() == ApiDef::SKIP) { + continue; + } + + // An op is hidden if either its ApiDef visibility is HIDDEN + // or it is in the hidden_ops list. + bool is_hidden = api_def->visibility() == ApiDef::HIDDEN; + if (!is_hidden) { + for (const string& hidden : hidden_ops) { + if (op_def.name() == hidden) { + is_hidden = true; + break; + } } } @@ -777,7 +787,6 @@ from tensorflow.python.util.tf_export import tf_export continue; } - const auto* api_def = api_defs.GetApiDef(op_def.name()); strings::StrAppend(&result, GetEagerPythonOp(op_def, *api_def, function_name)); diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 65810fa..85cba59 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -476,9 +476,6 @@ GenPythonOp::GenPythonOp(const OpDef& op_def, const ApiDef& api_def, GenPythonOp::~GenPythonOp() {} string GenPythonOp::Code() { - if (api_def_.visibility() == ApiDef::SKIP) { - return ""; - } // This has all the input args followed by those attrs that don't have // defaults. std::vector params_no_default; @@ -805,11 +802,21 @@ from tensorflow.python.util.tf_export import tf_export auto out = cleaned_ops.mutable_op(); out->Reserve(ops.op_size()); for (const auto& op_def : ops.op()) { - bool is_hidden = false; - for (const string& hidden : hidden_ops) { - if (op_def.name() == hidden) { - is_hidden = true; - break; + const auto* api_def = api_defs.GetApiDef(op_def.name()); + + if (api_def->visibility() == ApiDef::SKIP) { + continue; + } + + // An op is hidden if either its ApiDef visibility is HIDDEN + // or it is in the hidden_ops list. + bool is_hidden = api_def->visibility() == ApiDef::HIDDEN; + if (!is_hidden) { + for (const string& hidden : hidden_ops) { + if (op_def.name() == hidden) { + is_hidden = true; + break; + } } } @@ -826,7 +833,6 @@ from tensorflow.python.util.tf_export import tf_export continue; } - const auto* api_def = api_defs.GetApiDef(op_def.name()); strings::StrAppend(&result, GetPythonOp(op_def, *api_def, function_name)); if (!require_shapes) { diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 5d648bb..4fafc36 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_logging_ops from tensorflow.python.ops import gen_state_ops @@ -143,7 +144,7 @@ class ControlFlowTest(test.TestCase): enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True) nine = constant_op.constant(9) - enter_nine = control_flow_ops.enter(nine, "foo_1") + enter_nine = gen_control_flow_ops._enter(nine, "foo_1") op = state_ops.assign(enter_v, enter_nine) v2 = control_flow_ops.with_dependencies([op], enter_v) v3 = control_flow_ops.exit(v2) @@ -163,9 +164,9 @@ class ControlFlowTest(test.TestCase): def testEnterMulExit(self): with self.test_session(): data = constant_op.constant([1, 2, 3, 4, 5, 6], name="data") - enter_data = control_flow_ops.enter(data, "foo_1", False) + enter_data = gen_control_flow_ops._enter(data, "foo_1", False) five = constant_op.constant(5) - enter_five = control_flow_ops.enter(five, "foo_1", False) + enter_five = gen_control_flow_ops._enter(five, "foo_1", False) mul_op = math_ops.multiply(enter_data, enter_five) exit_op = control_flow_ops.exit(mul_op) @@ -177,11 +178,12 @@ class ControlFlowTest(test.TestCase): v = variables.Variable([0.0, 0.0], dtype=dtypes.float32) # If is_constant=True, the shape information should be propagated. - enter_v_constant = control_flow_ops.enter(v, "frame1", is_constant=True) + enter_v_constant = gen_control_flow_ops._enter( + v, "frame1", is_constant=True) self.assertEqual(enter_v_constant.shape, [2]) # Otherwise, the shape should be unknown. - enter_v_non_constant = control_flow_ops.enter( + enter_v_non_constant = gen_control_flow_ops._enter( v, "frame2", is_constant=False) self.assertEqual(enter_v_non_constant.shape, None) @@ -255,8 +257,8 @@ class ControlFlowTest(test.TestCase): false = ops.convert_to_tensor(False) n = constant_op.constant(10) - enter_false = control_flow_ops.enter(false, "foo_1", False) - enter_n = control_flow_ops.enter(n, "foo_1", False) + enter_false = gen_control_flow_ops._enter(false, "foo_1", False) + enter_n = gen_control_flow_ops._enter(n, "foo_1", False) merge_n = control_flow_ops.merge([enter_n, enter_n], name="merge_n")[0] switch_n = control_flow_ops.switch(merge_n, enter_false) @@ -273,9 +275,9 @@ class ControlFlowTest(test.TestCase): one = constant_op.constant(1) n = constant_op.constant(10) - enter_i = control_flow_ops.enter(zero, "foo", False) - enter_one = control_flow_ops.enter(one, "foo", True) - enter_n = control_flow_ops.enter(n, "foo", True) + enter_i = gen_control_flow_ops._enter(zero, "foo", False) + enter_one = gen_control_flow_ops._enter(one, "foo", True) + enter_n = gen_control_flow_ops._enter(n, "foo", True) with ops.device(test.gpu_device_name()): merge_i = control_flow_ops.merge([enter_i, enter_i])[0] @@ -299,9 +301,9 @@ class ControlFlowTest(test.TestCase): one = constant_op.constant(1) n = constant_op.constant(10) - enter_i = control_flow_ops.enter(zero, "foo", False) - enter_one = control_flow_ops.enter(one, "foo", True) - enter_n = control_flow_ops.enter(n, "foo", True) + enter_i = gen_control_flow_ops._enter(zero, "foo", False) + enter_one = gen_control_flow_ops._enter(one, "foo", True) + enter_n = gen_control_flow_ops._enter(n, "foo", True) merge_i = control_flow_ops.merge([enter_i, enter_i])[0] @@ -322,8 +324,8 @@ class ControlFlowTest(test.TestCase): def testDifferentFrame(self): with self.test_session(): data = array_ops.placeholder(dtypes.float32, shape=[]) - enter_1 = control_flow_ops.enter(data, "foo_1", False) - enter_2 = control_flow_ops.enter(data, "foo_2", False) + enter_1 = gen_control_flow_ops._enter(data, "foo_1", False) + enter_2 = gen_control_flow_ops._enter(data, "foo_2", False) res = math_ops.add(enter_1, enter_2) with self.assertRaisesOpError("has inputs from different frames"): res.eval(feed_dict={data: 1.0}) diff --git a/tensorflow/python/kernel_tests/control_flow_util_test.py b/tensorflow/python/kernel_tests/control_flow_util_test.py index 39e96f7..23185ea 100644 --- a/tensorflow/python/kernel_tests/control_flow_util_test.py +++ b/tensorflow/python/kernel_tests/control_flow_util_test.py @@ -41,17 +41,17 @@ class ControlFlowUtilTest(test.TestCase): self.assertFalse(control_flow_util.IsSwitch(test_ops.int_output().op)) def testIsLoopEnter(self): - enter = gen_control_flow_ops.enter(1, frame_name="name").op + enter = gen_control_flow_ops._enter(1, frame_name="name").op self.assertTrue(control_flow_util.IsLoopEnter(enter)) self.assertFalse(control_flow_util.IsLoopConstantEnter(enter)) - ref_enter = gen_control_flow_ops.ref_enter(test_ops.ref_output(), - frame_name="name").op + ref_enter = gen_control_flow_ops._ref_enter(test_ops.ref_output(), + frame_name="name").op self.assertTrue(control_flow_util.IsLoopEnter(ref_enter)) self.assertFalse(control_flow_util.IsLoopConstantEnter(ref_enter)) - const_enter = gen_control_flow_ops.enter(1, frame_name="name", - is_constant=True).op + const_enter = gen_control_flow_ops._enter(1, frame_name="name", + is_constant=True).op self.assertTrue(control_flow_util.IsLoopEnter(const_enter)) self.assertTrue(control_flow_util.IsLoopConstantEnter(const_enter)) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 49191c6..33a9263 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -261,10 +261,10 @@ def _Enter(data, data = ops.internal_convert_to_tensor_or_indexed_slices(data, as_ref=True) if isinstance(data, ops.Tensor): if data.dtype._is_ref_dtype and use_ref: # pylint: disable=protected-access - result = ref_enter( + result = gen_control_flow_ops._ref_enter( data, frame_name, is_constant, parallel_iterations, name=name) else: - result = enter( + result = gen_control_flow_ops._enter( data, frame_name, is_constant, parallel_iterations, name=name) if use_input_shape: result.set_shape(data.get_shape()) @@ -279,7 +279,7 @@ def _Enter(data, parallel_iterations=parallel_iterations, use_input_shape=use_input_shape, name=name) - indices = enter( + indices = gen_control_flow_ops._enter( data.indices, frame_name, is_constant, @@ -290,7 +290,7 @@ def _Enter(data, if isinstance(data, ops.IndexedSlices): dense_shape = data.dense_shape if dense_shape is not None: - dense_shape = enter( + dense_shape = gen_control_flow_ops._enter( dense_shape, frame_name, is_constant, @@ -300,7 +300,7 @@ def _Enter(data, dense_shape.set_shape(data.dense_shape.get_shape()) return ops.IndexedSlices(values, indices, dense_shape) else: - dense_shape = enter( + dense_shape = gen_control_flow_ops._enter( data.dense_shape, frame_name, is_constant, diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 8fb6b1c..608a34a 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -17,10 +17,6 @@ py_test( name = "api_compatibility_test", srcs = ["api_compatibility_test.py"], data = [ - ":convert_from_multiline", - "//tensorflow/core/api_def:base_api_def", - "//tensorflow/core/api_def:python_api_def", - "//tensorflow/python:hidden_ops", "//tensorflow/tools/api/golden:api_golden", "//tensorflow/tools/api/tests:API_UPDATE_WARNING.txt", "//tensorflow/tools/api/tests:README.txt", @@ -29,7 +25,6 @@ py_test( deps = [ "//tensorflow:tensorflow_py", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/tools/api/lib:python_object_to_proto_visitor", diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index afcbf50..c1e09cc 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -28,10 +28,8 @@ from __future__ import division from __future__ import print_function import argparse -from collections import defaultdict import os import re -import subprocess import sys import unittest @@ -39,7 +37,6 @@ import tensorflow as tf from google.protobuf import text_format -from tensorflow.core.framework import api_def_pb2 from tensorflow.python.lib.io import file_io from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -67,11 +64,6 @@ _API_GOLDEN_FOLDER = 'tensorflow/tools/api/golden' _TEST_README_FILE = 'tensorflow/tools/api/tests/README.txt' _UPDATE_WARNING_FILE = 'tensorflow/tools/api/tests/API_UPDATE_WARNING.txt' -_CONVERT_FROM_MULTILINE_SCRIPT = 'tensorflow/tools/api/tests/convert_from_multiline' -_BASE_API_DIR = 'tensorflow/core/api_def/base_api' -_PYTHON_API_DIR = 'tensorflow/core/api_def/python_api' -_HIDDEN_OPS_FILE = 'tensorflow/python/ops/hidden_ops.txt' - def _KeyToFilePath(key): """From a given key, construct a filepath.""" @@ -96,55 +88,6 @@ def _FileNameToKey(filename): return api_object_key -def _GetSymbol(symbol_id): - """Get TensorFlow symbol based on the given identifier. - - Args: - symbol_id: Symbol identifier in the form module1.module2. ... .sym. - - Returns: - Symbol corresponding to the given id. - """ - # Ignore first module which should be tensorflow - symbol_id_split = symbol_id.split('.')[1:] - symbol = tf - for sym in symbol_id_split: - symbol = getattr(symbol, sym) - return symbol - - -def _IsGenModule(module_name): - if not module_name: - return False - module_name_split = module_name.split('.') - return module_name_split[-1].startswith('gen_') - - -def _GetHiddenOps(): - hidden_ops_file = file_io.FileIO(_HIDDEN_OPS_FILE, 'r') - hidden_ops = set() - for line in hidden_ops_file: - line = line.strip() - if not line: - continue - if line[0] == '#': # comment line - continue - # If line is of the form "op_name # comment", only keep the op_name. - line_split = line.split('#') - hidden_ops.add(line_split[0].strip()) - return hidden_ops - - -def _GetGoldenApiDefs(): - old_api_def_files = file_io.get_matching_files(_GetApiDefFilePath('*')) - return {file_path: file_io.read_file_to_string(file_path) - for file_path in old_api_def_files} - - -def _GetApiDefFilePath(graph_op_name): - return os.path.join(_PYTHON_API_DIR, 'api_def_%s.pbtxt' % graph_op_name) - - class ApiCompatibilityTest(test.TestCase): def __init__(self, *args, **kwargs): @@ -287,188 +230,6 @@ class ApiCompatibilityTest(test.TestCase): update_goldens=FLAGS.update_goldens) -class ApiDefTest(test.TestCase): - - def __init__(self, *args, **kwargs): - super(ApiDefTest, self).__init__(*args, **kwargs) - self._first_cap_pattern = re.compile('(.)([A-Z][a-z]+)') - self._all_cap_pattern = re.compile('([a-z0-9])([A-Z])') - - def _GenerateLowerCaseOpName(self, op_name): - lower_case_name = self._first_cap_pattern.sub(r'\1_\2', op_name) - return self._all_cap_pattern.sub(r'\1_\2', lower_case_name).lower() - - def _CreatePythonApiDef(self, base_api_def, endpoint_names): - """Creates Python ApiDef that overrides base_api_def if needed. - - Args: - base_api_def: (api_def_pb2.ApiDef) base ApiDef instance. - endpoint_names: List of Python endpoint names. - - Returns: - api_def_pb2.ApiDef instance with overrides for base_api_def - if module.name endpoint is different from any existing - endpoints in base_api_def. Otherwise, returns None. - """ - endpoint_names_set = set(endpoint_names) - - # If the only endpoint is equal to graph_op_name then - # it is equivalent to having no endpoints. - if (not base_api_def.endpoint and len(endpoint_names) == 1 - and endpoint_names[0] == - self._GenerateLowerCaseOpName(base_api_def.graph_op_name)): - return None - - base_endpoint_names_set = { - self._GenerateLowerCaseOpName(endpoint.name) - for endpoint in base_api_def.endpoint} - - if endpoint_names_set == base_endpoint_names_set: - return None # All endpoints are the same - - api_def = api_def_pb2.ApiDef() - api_def.graph_op_name = base_api_def.graph_op_name - - for endpoint_name in sorted(endpoint_names): - new_endpoint = api_def.endpoint.add() - new_endpoint.name = endpoint_name - - return api_def - - def _GetBaseApiMap(self): - """Get a map from graph op name to its base ApiDef. - - Returns: - Dictionary mapping graph op name to corresponding ApiDef. - """ - # Convert base ApiDef in Multiline format to Proto format. - converted_base_api_dir = os.path.join( - test.get_temp_dir(), 'temp_base_api_defs') - subprocess.check_call( - [os.path.join(resource_loader.get_root_dir_with_all_resources(), - _CONVERT_FROM_MULTILINE_SCRIPT), - _BASE_API_DIR, converted_base_api_dir]) - - name_to_base_api_def = {} - base_api_files = file_io.get_matching_files( - os.path.join(converted_base_api_dir, 'api_def_*.pbtxt')) - for base_api_file in base_api_files: - if file_io.file_exists(base_api_file): - api_defs = api_def_pb2.ApiDefs() - text_format.Merge( - file_io.read_file_to_string(base_api_file), api_defs) - for api_def in api_defs.op: - name_to_base_api_def[api_def.graph_op_name] = api_def - return name_to_base_api_def - - def _AddHiddenOpOverrides(self, name_to_base_api_def, api_def_map): - """Adds ApiDef overrides to api_def_map for hidden Python ops. - - Args: - name_to_base_api_def: Map from op name to base api_def_pb2.ApiDef. - api_def_map: Map from file path to api_def_pb2.ApiDefs for Python API - overrides. - """ - hidden_ops = _GetHiddenOps() - for hidden_op in hidden_ops: - if hidden_op not in name_to_base_api_def: - logging.warning('Unexpected hidden op name: %s' % hidden_op) - continue - - base_api_def = name_to_base_api_def[hidden_op] - if base_api_def.visibility != api_def_pb2.ApiDef.HIDDEN: - api_def = api_def_pb2.ApiDef() - api_def.graph_op_name = base_api_def.graph_op_name - api_def.visibility = api_def_pb2.ApiDef.HIDDEN - - file_path = _GetApiDefFilePath(base_api_def.graph_op_name) - api_def_map[file_path].op.extend([api_def]) - - @unittest.skipUnless( - sys.version_info.major == 2 and os.uname()[0] == 'Linux', - 'API compabitility test goldens are generated using python2 on Linux.') - def testAPIDefCompatibility(self): - # Get base ApiDef - name_to_base_api_def = self._GetBaseApiMap() - snake_to_camel_graph_op_names = { - self._GenerateLowerCaseOpName(name): name - for name in name_to_base_api_def.keys()} - # Extract Python API - visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor() - public_api_visitor = public_api.PublicAPIVisitor(visitor) - public_api_visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf, public_api_visitor) - proto_dict = visitor.GetProtos() - - # Map from file path to Python ApiDefs. - new_api_defs_map = defaultdict(api_def_pb2.ApiDefs) - # We need to override all endpoints even if 1 endpoint differs from base - # ApiDef. So, we first create a map from an op to all its endpoints. - op_to_endpoint_name = defaultdict(list) - - # Generate map from generated python op to endpoint names. - for public_module, value in proto_dict.items(): - module_obj = _GetSymbol(public_module) - for sym in value.tf_module.member_method: - obj = getattr(module_obj, sym.name) - - # Check if object is defined in gen_* module. That is, - # the object has been generated from OpDef. - if hasattr(obj, '__module__') and _IsGenModule(obj.__module__): - if obj.__name__ not in snake_to_camel_graph_op_names: - # Symbol might be defined only in Python and not generated from - # C++ api. - continue - relative_public_module = public_module[len('tensorflow.'):] - full_name = (relative_public_module + '.' + sym.name - if relative_public_module else sym.name) - op_to_endpoint_name[obj].append(full_name) - - # Generate Python ApiDef overrides. - for op, endpoint_names in op_to_endpoint_name.items(): - graph_op_name = snake_to_camel_graph_op_names[op.__name__] - api_def = self._CreatePythonApiDef( - name_to_base_api_def[graph_op_name], endpoint_names) - - if api_def: - file_path = _GetApiDefFilePath(graph_op_name) - api_defs = new_api_defs_map[file_path] - api_defs.op.extend([api_def]) - - self._AddHiddenOpOverrides(name_to_base_api_def, new_api_defs_map) - - old_api_defs_map = _GetGoldenApiDefs() - for file_path, new_api_defs in new_api_defs_map.items(): - # Get new ApiDef string. - new_api_defs_str = str(new_api_defs) - - # Get current ApiDef for the given file. - old_api_defs_str = ( - old_api_defs_map[file_path] if file_path in old_api_defs_map else '') - - if old_api_defs_str == new_api_defs_str: - continue - - if FLAGS.update_goldens: - logging.info('Updating %s...' % file_path) - file_io.write_string_to_file(file_path, new_api_defs_str) - else: - self.assertMultiLineEqual( - old_api_defs_str, new_api_defs_str, - 'To update golden API files, run api_compatibility_test locally ' - 'with --update_goldens=True flag.') - - for file_path in set(old_api_defs_map) - set(new_api_defs_map): - if FLAGS.update_goldens: - logging.info('Deleting %s...' % file_path) - file_io.delete_file(file_path) - else: - self.fail( - '%s file is no longer needed and should be removed.' - 'To update golden API files, run api_compatibility_test locally ' - 'with --update_goldens=True flag.' % file_path) - - if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( -- 2.7.4