Consolidate the builtin function overrides into a single module, and use a generic...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 27 Feb 2018 01:04:09 +0000 (17:04 -0800)
committerGunhan Gulsoy <gunan@google.com>
Tue, 27 Feb 2018 22:33:33 +0000 (14:33 -0800)
PiperOrigin-RevId: 187104685

tensorflow/contrib/py2tf/converters/builtin_functions.py
tensorflow/contrib/py2tf/utils/BUILD
tensorflow/contrib/py2tf/utils/__init__.py
tensorflow/contrib/py2tf/utils/builtins.py [moved from tensorflow/contrib/py2tf/utils/printing.py with 62% similarity]
tensorflow/contrib/py2tf/utils/builtins_test.py [moved from tensorflow/contrib/py2tf/utils/printing_test.py with 56% similarity]
tensorflow/contrib/py2tf/utils/misc.py
tensorflow/contrib/py2tf/utils/misc_test.py

index e69038a..b5aa975 100644 (file)
@@ -36,23 +36,24 @@ class BuiltinFunctionTransformer(transformer.Base):
 
   # pylint:disable=invalid-name
 
-  def _convert_len(self, node):
+  def _convert_builtin(self, node):
     template = """
-      py2tf_utils.dynamic_len(args)
+      py2tf_utils.dynamic_builtin(func, args)
     """
-    return templates.replace(template, args=node.args)[0].value
+    return templates.replace(template, func=node.func, args=node.args)[0].value
 
   def _convert_print(self, node):
     template = """
-      py2tf_utils.call_print(args)
+      py2tf_utils.dynamic_print(args)
     """
     return templates.replace(template, args=node.args)[0].value
 
   def visit_Call(self, node):
     self.generic_visit(node)
     # TODO(mdan): This won't work if the function was hidden.
-    if isinstance(node.func, gast.Name) and node.func.id == 'len':
-      return self._convert_len(node)
+    if isinstance(node.func, gast.Name) and node.func.id in ('len',):
+      return self._convert_builtin(node)
+    # Print needs to be handled separately because it can be read as statement.
     if isinstance(node.func, gast.Name) and node.func.id == 'print':
       return self._convert_print(node)
     return node
index c2fdd40..2086a9e 100644 (file)
@@ -20,10 +20,10 @@ py_library(
     name = "utils",
     srcs = [
         "__init__.py",
+        "builtins.py",
         "context_managers.py",
         "misc.py",
         "multiple_dispatch.py",
-        "printing.py",
         "py_func.py",
         "tensor_list.py",
         "type_check.py",
@@ -77,16 +77,6 @@ py_test(
 )
 
 py_test(
-    name = "printing_test",
-    srcs = ["printing_test.py"],
-    srcs_version = "PY2AND3",
-    deps = [
-        ":utils",
-        "//tensorflow/python:client_testlib",
-    ],
-)
-
-py_test(
     name = "type_check_test",
     srcs = ["type_check_test.py"],
     srcs_version = "PY2AND3",
index d931322..19bf227 100644 (file)
@@ -18,11 +18,11 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.contrib.py2tf.utils.builtins import dynamic_builtin
+from tensorflow.contrib.py2tf.utils.builtins import dynamic_print
 from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns
 from tensorflow.contrib.py2tf.utils.misc import alias_tensors
-from tensorflow.contrib.py2tf.utils.misc import dynamic_len
 from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond
 from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while
-from tensorflow.contrib.py2tf.utils.printing import call_print
 from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func
 from tensorflow.contrib.py2tf.utils.type_check import is_tensor
similarity index 62%
rename from tensorflow/contrib/py2tf/utils/printing.py
rename to tensorflow/contrib/py2tf/utils/builtins.py
index 95a62bd..0a50b80 100644 (file)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""TensorFlow printing support utilities."""
+"""Builtin conversion utilities."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.py2tf.utils import py_func
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import logging_ops
+from tensorflow.python.util import tf_inspect
+
+
+def dynamic_builtin(f, *args, **kwargs):
+  """Converts a builtin function call inline."""
+  if not tf_inspect.isbuiltin(f):
+    return f(*args, **kwargs)
+
+  if f is len:
+    return dynamic_len(*args, **kwargs)
+
+  raise NotImplementedError('The "%s" builtin is not yet supported.' % f)
+
+
+def dynamic_len(list_or_tensor):
+  """Implementation of len using dynamic dispatch."""
+  if tensor_util.is_tensor(list_or_tensor):
+    shape = list_or_tensor.shape
+    if not shape:
+      raise ValueError(
+          'len requires non-zero rank for tensor "%s"' % list_or_tensor)
+    return array_ops.shape(list_or_tensor)[0]
+
+  return len(list_or_tensor)
 
 
 def is_tf_print_compatible(value):
@@ -30,8 +56,8 @@ def is_tf_print_compatible(value):
   return False
 
 
-def call_print(*values):
-  """Compiled counterpart of the print builtin.
+def dynamic_print(*values):
+  """Implementartion of print using dynamic dispatch.
 
   The function attempts to use tf.Print if all the values are compatible.
   Otherwise, it will fall back to py_func.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for printing module."""
+"""Tests for builtins module."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -22,28 +22,53 @@ import sys
 
 import six
 
-from tensorflow.contrib.py2tf.utils import printing
+from tensorflow.contrib.py2tf.utils import builtins
+from tensorflow.python.framework import constant_op
 from tensorflow.python.platform import test
 
 
-class ContextManagersTest(test.TestCase):
+class BuiltinsTest(test.TestCase):
 
-  def test_call_print_tf(self):
+  def test_dynamic_len_tf_scalar(self):
+    a = constant_op.constant(1)
+
+    with self.assertRaises(ValueError):
+      with self.test_session() as sess:
+        sess.run(builtins.dynamic_builtin(len, a))
+
+  def test_dynamic_len_tf_array(self):
+    a = constant_op.constant([1, 2, 3])
+
+    with self.test_session() as sess:
+      self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a)))
+
+  def test_dynamic_len_tf_matrix(self):
+    a = constant_op.constant([[1, 2], [3, 4]])
+
+    with self.test_session() as sess:
+      self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a)))
+
+  def test_dynamic_len_py_list(self):
+    a = [3] * 5
+
+    self.assertEqual(5, builtins.dynamic_builtin(len, a))
+
+  def test_dynamic_print_tf(self):
     try:
       out_capturer = six.StringIO()
       sys.stdout = out_capturer
       with self.test_session() as sess:
-        sess.run(printing.call_print('test message', 1))
+        sess.run(builtins.dynamic_print('test message', 1))
         self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
     finally:
       sys.stdout = sys.__stdout__
 
-  def test_call_print_py_func(self):
+  def test_dynamic_print_complex(self):
     try:
       out_capturer = six.StringIO()
       sys.stdout = out_capturer
       with self.test_session() as sess:
-        sess.run(printing.call_print('test message', [1, 2]))
+        sess.run(builtins.dynamic_print('test message', [1, 2]))
         self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
     finally:
       sys.stdout = sys.__stdout__
index 7548048..1b06caf 100644 (file)
@@ -19,22 +19,9 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 
 
-def dynamic_len(list_or_tensor):
-  """Implementation of len using dynamic dispatch."""
-  if tensor_util.is_tensor(list_or_tensor):
-    shape = list_or_tensor.shape
-    if not shape:
-      raise ValueError(
-          'len requires non-zero rank for tensor "%s"' % list_or_tensor)
-    return array_ops.shape(list_or_tensor)[0]
-
-  return len(list_or_tensor)
-
-
 def alias_tensors(*args):
   """Wrap any Tensor arguments with an identity op.
 
index ec88e7c..8aedd4c 100644 (file)
@@ -19,37 +19,12 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.contrib.py2tf.utils.misc import alias_tensors
-from tensorflow.contrib.py2tf.utils.misc import dynamic_len
 from tensorflow.python.framework.constant_op import constant
 from tensorflow.python.ops.variables import Variable
 from tensorflow.python.platform import test
 
 
-class ContextManagersTest(test.TestCase):
-
-  def test_dynamic_len_tf_scalar(self):
-    a = constant(1)
-
-    with self.assertRaises(ValueError):
-      with self.test_session() as sess:
-        sess.run(dynamic_len(a))
-
-  def test_dynamic_len_tf_array(self):
-    a = constant([1, 2, 3])
-
-    with self.test_session() as sess:
-      self.assertEqual(3, sess.run(dynamic_len(a)))
-
-  def test_dynamic_len_tf_matrix(self):
-    a = constant([[1, 2], [3, 4]])
-
-    with self.test_session() as sess:
-      self.assertEqual(2, sess.run(dynamic_len(a)))
-
-  def test_dynamic_len_py_list(self):
-    a = [3] * 5
-
-    self.assertEqual(5, dynamic_len(a))
+class MiscTest(test.TestCase):
 
   def test_alias_single_tensor(self):
     a = constant(1)