From 7d89bfcd72bef4c5c9328a88ee520d81642b5284 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 12 Apr 2018 18:19:05 -0700 Subject: [PATCH] Adding autograph built-in function checker. PiperOrigin-RevId: 192703924 --- tensorflow/contrib/autograph/converters/call_trees.py | 3 +-- tensorflow/contrib/autograph/impl/api.py | 2 +- tensorflow/contrib/autograph/pyct/inspect_utils.py | 13 +++++++++++++ tensorflow/contrib/autograph/pyct/inspect_utils_test.py | 7 +++++++ 4 files changed, 22 insertions(+), 3 deletions(-) diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index 61f6bfd..9424966 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -23,7 +23,6 @@ from __future__ import division from __future__ import print_function from collections import namedtuple -import types import gast @@ -114,7 +113,7 @@ class CallTreeTransformer(transformer.Base): def _function_is_compilable(self, target_entity): """Determines whether an entity can be compiled at all.""" # TODO(mdan): This is just a placeholder. Implement. - return not isinstance(target_entity, types.BuiltinFunctionType) + return not inspect_utils.isbuiltin(target_entity) def _should_compile(self, node, fqn): """Determines whether an entity should be compiled in the context.""" diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index dce994e..a553813 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -137,7 +137,7 @@ def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): unknown_arg_value = object() # Sentinel for arguments of unknown value - if tf_inspect.isbuiltin(f): + if inspect_utils.isbuiltin(f): return builtins.dynamic_builtin(f, *args, **kwargs) if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils.py b/tensorflow/contrib/autograph/pyct/inspect_utils.py index 386a6d2..63361cc 100644 --- a/tensorflow/contrib/autograph/pyct/inspect_utils.py +++ b/tensorflow/contrib/autograph/pyct/inspect_utils.py @@ -22,12 +22,25 @@ from __future__ import division from __future__ import print_function import itertools +import types import six from tensorflow.python.util import tf_inspect +def isbuiltin(f): + # Note these return false for isinstance(f, types.BuiltinFunctionType) so we + # need to specifically check for them. + if f in (range, int, float): + return True + if isinstance(f, types.BuiltinFunctionType): + return True + if tf_inspect.isbuiltin(f): + return True + return False + + def getnamespace(f): """Returns the complete namespace of a function. diff --git a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py index 58f827b..cf841da 100644 --- a/tensorflow/contrib/autograph/pyct/inspect_utils_test.py +++ b/tensorflow/contrib/autograph/pyct/inspect_utils_test.py @@ -258,6 +258,13 @@ class InspectUtilsTest(test.TestCase): self.assertTrue( inspect_utils.getdefiningclass(Subclass.baz, Subclass) is Subclass) + def test_isbuiltin(self): + self.assertTrue(inspect_utils.isbuiltin(range)) + self.assertTrue(inspect_utils.isbuiltin(float)) + self.assertTrue(inspect_utils.isbuiltin(int)) + self.assertTrue(inspect_utils.isbuiltin(len)) + self.assertFalse(inspect_utils.isbuiltin(function_decorator)) + if __name__ == '__main__': test.main() -- 2.7.4