From 7bcc7ee1a9da4ec55395a935123a46b4ecb2364f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 26 Feb 2018 17:04:09 -0800 Subject: [PATCH] Consolidate the builtin function overrides into a single module, and use a generic `dynamic_builtin` function to dispatch between implementations. Use the generic dispatcher in the generated code. PiperOrigin-RevId: 187104685 --- .../contrib/py2tf/converters/builtin_functions.py | 13 ++++---- tensorflow/contrib/py2tf/utils/BUILD | 12 +------ tensorflow/contrib/py2tf/utils/__init__.py | 4 +-- .../py2tf/utils/{printing.py => builtins.py} | 32 ++++++++++++++++-- .../utils/{printing_test.py => builtins_test.py} | 39 ++++++++++++++++++---- tensorflow/contrib/py2tf/utils/misc.py | 13 -------- tensorflow/contrib/py2tf/utils/misc_test.py | 27 +-------------- 7 files changed, 72 insertions(+), 68 deletions(-) rename tensorflow/contrib/py2tf/utils/{printing.py => builtins.py} (62%) rename tensorflow/contrib/py2tf/utils/{printing_test.py => builtins_test.py} (56%) diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/py2tf/converters/builtin_functions.py index e69038a..b5aa975 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions.py +++ b/tensorflow/contrib/py2tf/converters/builtin_functions.py @@ -36,23 +36,24 @@ class BuiltinFunctionTransformer(transformer.Base): # pylint:disable=invalid-name - def _convert_len(self, node): + def _convert_builtin(self, node): template = """ - py2tf_utils.dynamic_len(args) + py2tf_utils.dynamic_builtin(func, args) """ - return templates.replace(template, args=node.args)[0].value + return templates.replace(template, func=node.func, args=node.args)[0].value def _convert_print(self, node): template = """ - py2tf_utils.call_print(args) + py2tf_utils.dynamic_print(args) """ return templates.replace(template, args=node.args)[0].value def visit_Call(self, node): self.generic_visit(node) # TODO(mdan): This won't work if the function was hidden. - if isinstance(node.func, gast.Name) and node.func.id == 'len': - return self._convert_len(node) + if isinstance(node.func, gast.Name) and node.func.id in ('len',): + return self._convert_builtin(node) + # Print needs to be handled separately because it can be read as statement. if isinstance(node.func, gast.Name) and node.func.id == 'print': return self._convert_print(node) return node diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/py2tf/utils/BUILD index c2fdd40..2086a9e 100644 --- a/tensorflow/contrib/py2tf/utils/BUILD +++ b/tensorflow/contrib/py2tf/utils/BUILD @@ -20,10 +20,10 @@ py_library( name = "utils", srcs = [ "__init__.py", + "builtins.py", "context_managers.py", "misc.py", "multiple_dispatch.py", - "printing.py", "py_func.py", "tensor_list.py", "type_check.py", @@ -77,16 +77,6 @@ py_test( ) py_test( - name = "printing_test", - srcs = ["printing_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":utils", - "//tensorflow/python:client_testlib", - ], -) - -py_test( name = "type_check_test", srcs = ["type_check_test.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py index d931322..19bf227 100644 --- a/tensorflow/contrib/py2tf/utils/__init__.py +++ b/tensorflow/contrib/py2tf/utils/__init__.py @@ -18,11 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.py2tf.utils.builtins import dynamic_builtin +from tensorflow.contrib.py2tf.utils.builtins import dynamic_print from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns from tensorflow.contrib.py2tf.utils.misc import alias_tensors -from tensorflow.contrib.py2tf.utils.misc import dynamic_len from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while -from tensorflow.contrib.py2tf.utils.printing import call_print from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func from tensorflow.contrib.py2tf.utils.type_check import is_tensor diff --git a/tensorflow/contrib/py2tf/utils/printing.py b/tensorflow/contrib/py2tf/utils/builtins.py similarity index 62% rename from tensorflow/contrib/py2tf/utils/printing.py rename to tensorflow/contrib/py2tf/utils/builtins.py index 95a62bd..0a50b80 100644 --- a/tensorflow/contrib/py2tf/utils/printing.py +++ b/tensorflow/contrib/py2tf/utils/builtins.py @@ -12,14 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorFlow printing support utilities.""" +"""Builtin conversion utilities.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.contrib.py2tf.utils import py_func +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops +from tensorflow.python.util import tf_inspect + + +def dynamic_builtin(f, *args, **kwargs): + """Converts a builtin function call inline.""" + if not tf_inspect.isbuiltin(f): + return f(*args, **kwargs) + + if f is len: + return dynamic_len(*args, **kwargs) + + raise NotImplementedError('The "%s" builtin is not yet supported.' % f) + + +def dynamic_len(list_or_tensor): + """Implementation of len using dynamic dispatch.""" + if tensor_util.is_tensor(list_or_tensor): + shape = list_or_tensor.shape + if not shape: + raise ValueError( + 'len requires non-zero rank for tensor "%s"' % list_or_tensor) + return array_ops.shape(list_or_tensor)[0] + + return len(list_or_tensor) def is_tf_print_compatible(value): @@ -30,8 +56,8 @@ def is_tf_print_compatible(value): return False -def call_print(*values): - """Compiled counterpart of the print builtin. +def dynamic_print(*values): + """Implementartion of print using dynamic dispatch. The function attempts to use tf.Print if all the values are compatible. Otherwise, it will fall back to py_func. diff --git a/tensorflow/contrib/py2tf/utils/printing_test.py b/tensorflow/contrib/py2tf/utils/builtins_test.py similarity index 56% rename from tensorflow/contrib/py2tf/utils/printing_test.py rename to tensorflow/contrib/py2tf/utils/builtins_test.py index 2070deb..19a72c6 100644 --- a/tensorflow/contrib/py2tf/utils/printing_test.py +++ b/tensorflow/contrib/py2tf/utils/builtins_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for printing module.""" +"""Tests for builtins module.""" from __future__ import absolute_import from __future__ import division @@ -22,28 +22,53 @@ import sys import six -from tensorflow.contrib.py2tf.utils import printing +from tensorflow.contrib.py2tf.utils import builtins +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test -class ContextManagersTest(test.TestCase): +class BuiltinsTest(test.TestCase): - def test_call_print_tf(self): + def test_dynamic_len_tf_scalar(self): + a = constant_op.constant(1) + + with self.assertRaises(ValueError): + with self.test_session() as sess: + sess.run(builtins.dynamic_builtin(len, a)) + + def test_dynamic_len_tf_array(self): + a = constant_op.constant([1, 2, 3]) + + with self.test_session() as sess: + self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) + + def test_dynamic_len_tf_matrix(self): + a = constant_op.constant([[1, 2], [3, 4]]) + + with self.test_session() as sess: + self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a))) + + def test_dynamic_len_py_list(self): + a = [3] * 5 + + self.assertEqual(5, builtins.dynamic_builtin(len, a)) + + def test_dynamic_print_tf(self): try: out_capturer = six.StringIO() sys.stdout = out_capturer with self.test_session() as sess: - sess.run(printing.call_print('test message', 1)) + sess.run(builtins.dynamic_print('test message', 1)) self.assertEqual(out_capturer.getvalue(), 'test message 1\n') finally: sys.stdout = sys.__stdout__ - def test_call_print_py_func(self): + def test_dynamic_print_complex(self): try: out_capturer = six.StringIO() sys.stdout = out_capturer with self.test_session() as sess: - sess.run(printing.call_print('test message', [1, 2])) + sess.run(builtins.dynamic_print('test message', [1, 2])) self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') finally: sys.stdout = sys.__stdout__ diff --git a/tensorflow/contrib/py2tf/utils/misc.py b/tensorflow/contrib/py2tf/utils/misc.py index 7548048..1b06caf 100644 --- a/tensorflow/contrib/py2tf/utils/misc.py +++ b/tensorflow/contrib/py2tf/utils/misc.py @@ -19,22 +19,9 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops -def dynamic_len(list_or_tensor): - """Implementation of len using dynamic dispatch.""" - if tensor_util.is_tensor(list_or_tensor): - shape = list_or_tensor.shape - if not shape: - raise ValueError( - 'len requires non-zero rank for tensor "%s"' % list_or_tensor) - return array_ops.shape(list_or_tensor)[0] - - return len(list_or_tensor) - - def alias_tensors(*args): """Wrap any Tensor arguments with an identity op. diff --git a/tensorflow/contrib/py2tf/utils/misc_test.py b/tensorflow/contrib/py2tf/utils/misc_test.py index ec88e7c..8aedd4c 100644 --- a/tensorflow/contrib/py2tf/utils/misc_test.py +++ b/tensorflow/contrib/py2tf/utils/misc_test.py @@ -19,37 +19,12 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.py2tf.utils.misc import alias_tensors -from tensorflow.contrib.py2tf.utils.misc import dynamic_len from tensorflow.python.framework.constant_op import constant from tensorflow.python.ops.variables import Variable from tensorflow.python.platform import test -class ContextManagersTest(test.TestCase): - - def test_dynamic_len_tf_scalar(self): - a = constant(1) - - with self.assertRaises(ValueError): - with self.test_session() as sess: - sess.run(dynamic_len(a)) - - def test_dynamic_len_tf_array(self): - a = constant([1, 2, 3]) - - with self.test_session() as sess: - self.assertEqual(3, sess.run(dynamic_len(a))) - - def test_dynamic_len_tf_matrix(self): - a = constant([[1, 2], [3, 4]]) - - with self.test_session() as sess: - self.assertEqual(2, sess.run(dynamic_len(a))) - - def test_dynamic_len_py_list(self): - a = [3] * 5 - - self.assertEqual(5, dynamic_len(a)) +class MiscTest(test.TestCase): def test_alias_single_tensor(self): a = constant(1) -- 2.7.4