From 151e35680b0b2575aa8bdb6bddbb95536be4fed0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 17 May 2018 10:15:45 -0700 Subject: [PATCH] Change traverse_test.test_module to traverse a constructed dummy module rather than testcase itself. PiperOrigin-RevId: 197010681 --- tensorflow/tools/common/BUILD | 17 +++++++++++++++++ tensorflow/tools/common/test_module1.py | 31 +++++++++++++++++++++++++++++++ tensorflow/tools/common/test_module2.py | 29 +++++++++++++++++++++++++++++ tensorflow/tools/common/traverse_test.py | 15 +++++---------- 4 files changed, 82 insertions(+), 10 deletions(-) create mode 100644 tensorflow/tools/common/test_module1.py create mode 100644 tensorflow/tools/common/test_module2.py diff --git a/tensorflow/tools/common/BUILD b/tensorflow/tools/common/BUILD index b9032c0..8c01d15 100644 --- a/tensorflow/tools/common/BUILD +++ b/tensorflow/tools/common/BUILD @@ -40,7 +40,24 @@ py_test( srcs = ["traverse_test.py"], srcs_version = "PY2AND3", deps = [ + ":test_module1", + ":test_module2", ":traverse", "//tensorflow/python:platform_test", ], ) + +py_library( + name = "test_module1", + srcs = ["test_module1.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_module2", + ], +) + +py_library( + name = "test_module2", + srcs = ["test_module2.py"], + srcs_version = "PY2AND3", +) diff --git a/tensorflow/tools/common/test_module1.py b/tensorflow/tools/common/test_module1.py new file mode 100644 index 0000000..cc185cf --- /dev/null +++ b/tensorflow/tools/common/test_module1.py @@ -0,0 +1,31 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A module target for TraverseTest.test_module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.tools.common import test_module2 + + +class ModuleClass1(object): + + def __init__(self): + self._m2 = test_module2.ModuleClass2() + + def __model_class1_method__(self): + pass + diff --git a/tensorflow/tools/common/test_module2.py b/tensorflow/tools/common/test_module2.py new file mode 100644 index 0000000..d9da99d --- /dev/null +++ b/tensorflow/tools/common/test_module2.py @@ -0,0 +1,29 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A module target for TraverseTest.test_module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class ModuleClass2(object): + + def __init__(self): + pass + + def __model_class1_method__(self): + pass + diff --git a/tensorflow/tools/common/traverse_test.py b/tensorflow/tools/common/traverse_test.py index eb195ec..ed41069 100644 --- a/tensorflow/tools/common/traverse_test.py +++ b/tensorflow/tools/common/traverse_test.py @@ -18,9 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import sys - from tensorflow.python.platform import googletest +from tensorflow.tools.common import test_module1 +from tensorflow.tools.common import test_module2 from tensorflow.tools.common import traverse @@ -30,10 +30,6 @@ class TestVisitor(object): self.call_log = [] def __call__(self, path, parent, children): - # Do not traverse googletest, it's very deep. - for item in list(children): - if item[1] is googletest: - children.remove(item) self.call_log += [(path, parent, children)] @@ -51,13 +47,12 @@ class TraverseTest(googletest.TestCase): def test_module(self): visitor = TestVisitor() - traverse.traverse(sys.modules[__name__], visitor) + traverse.traverse(test_module1, visitor) called = [parent for _, parent, _ in visitor.call_log] - self.assertIn(TestVisitor, called) - self.assertIn(TraverseTest, called) - self.assertIn(traverse, called) + self.assertIn(test_module1.ModuleClass1, called) + self.assertIn(test_module2.ModuleClass2, called) def test_class(self): visitor = TestVisitor() -- 2.7.4