From 36882e882c3de9be4381c266af6049b08fe2326c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 9 Apr 2018 13:24:03 -0700 Subject: [PATCH] Add a utility that can detect the class that defined a method. This is useful when converting a e.g. a custom Keras model, to avoid recompiling the base model methods. PiperOrigin-RevId: 192177117 --- tensorflow/contrib/autograph/pyct/inspect_utils.py | 12 +++++++++++ .../contrib/autograph/pyct/inspect_utils_test.py | 24 ++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/contrib/autograph/pyct/inspect_utils.py index 30a5961..386a6d2 100644 --- a/tensorflow/contrib/autograph/pyct/inspect_utils.py +++ b/tensorflow/contrib/autograph/pyct/inspect_utils.py @@ -50,6 +50,18 @@ def getnamespace(f): return namespace +def getdefiningclass(m, owner_class): + """Resolves the class (e.g. one of the superclasses) that defined a method.""" + m = six.get_unbound_function(m) + last_defining = owner_class + for superclass in tf_inspect.getmro(owner_class): + if hasattr(superclass, m.__name__): + superclass_m = getattr(superclass, m.__name__) + if six.get_unbound_function(superclass_m) == m: + last_defining = superclass + return last_defining + + def getmethodclass(m): """Resolves a function's owner, e.g. a method's class. diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py index eda3fc1..58f827b 100644 --- a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py +++ b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py @@ -234,6 +234,30 @@ class InspectUtilsTest(test.TestCase): c = TestCallable() self.assertEqual(inspect_utils.getmethodclass(c), TestCallable) + def test_getdefiningclass(self): + class Superclass(object): + + def foo(self): + pass + + def bar(self): + pass + + class Subclass(Superclass): + + def foo(self): + pass + + def baz(self): + pass + + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.foo, Subclass) is Subclass) + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.bar, Subclass) is Superclass) + self.assertTrue( + inspect_utils.getdefiningclass(Subclass.baz, Subclass) is Subclass) + if __name__ == '__main__': test.main() -- 2.7.4