Add dedicated code for the print function instead of wrapping it generically to py_func.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 15 Feb 2018 00:01:26 +0000 (16:01 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 15 Feb 2018 00:05:12 +0000 (16:05 -0800)
For now, we keep tf.Print disabled until we can find a way to test it. This might require launching the compiled code in a Python subprocess.

PiperOrigin-RevId: 185759599

tensorflow/contrib/py2tf/converters/builtin_functions.py
tensorflow/contrib/py2tf/converters/builtin_functions_test.py
tensorflow/contrib/py2tf/impl/conversion.py
tensorflow/contrib/py2tf/utils/BUILD
tensorflow/contrib/py2tf/utils/__init__.py
tensorflow/contrib/py2tf/utils/printing.py [new file with mode: 0644]
tensorflow/contrib/py2tf/utils/printing_test.py [new file with mode: 0644]

index 310681dd016ca94bf2b28d27a4968cc0c10a5842..2eb00f90575920ac948e799b0e97a9cfccb42fad 100644 (file)
@@ -25,36 +25,36 @@ from tensorflow.contrib.py2tf.pyct import transformer
 
 
 class BuiltinFunctionTransformer(transformer.Base):
-  """Handles builtin functions and canonicalizes old-style print statement.
+  """Handles builtin functions.
 
   This transformer only covers functions that are translated into a
   TF equivalent, like `len`.
-  Note that the `print` statement is converted to a function call here, but
-  wrapping the print function to a `py_func` is done by `call_trees` as a
-  generic uncompilable function wrap.
   """
 
-  # TODO(mdan): Handle print entirely in here.
-  # Fully handling print here makes sense especially since we're considering
-  # using tf.Print instead.
-
   def __init__(self, context):
     super(BuiltinFunctionTransformer, self).__init__(context)
 
+  # pylint:disable=invalid-name
+
   def _convert_len(self, node):
     template = """
       tf.shape(args)[0]
     """
-    new_call = templates.replace(template, args=node.args)[0].value
-    return new_call
+    return templates.replace(template, args=node.args)[0].value
 
-  # pylint:disable=invalid-name
+  def _convert_print(self, node):
+    template = """
+      py2tf_utils.call_print(args)
+    """
+    return templates.replace(template, args=node.args)[0].value
 
   def visit_Call(self, node):
     self.generic_visit(node)
     # TODO(mdan): This won't work if the function was hidden.
     if isinstance(node.func, gast.Name) and node.func.id == 'len':
       return self._convert_len(node)
+    if isinstance(node.func, gast.Name) and node.func.id == 'print':
+      return self._convert_print(node)
     return node
 
   def visit_Print(self, node):
@@ -66,7 +66,8 @@ class BuiltinFunctionTransformer(transformer.Base):
     template = """
       fname(args)
     """
-    return templates.replace(template, fname='print', args=args)
+    function_call = templates.replace(template, fname='print', args=args)[0]
+    return self.visit(function_call)
 
   # pylint:enable=invalid-name
 
index 983d1ffc03466ab3e2148e8cdf6e54050b9d3947..b279ff77ef10b96586d3d68585adb0d5424afb90 100644 (file)
@@ -26,6 +26,8 @@ from tensorflow.contrib.py2tf.converters import builtin_functions
 from tensorflow.contrib.py2tf.converters import converter_test_base
 from tensorflow.python.framework import constant_op
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import script_ops
 from tensorflow.python.platform import test
 
 
@@ -45,7 +47,7 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
                          sess.run(
                              result.test_fn(constant_op.constant([0, 0, 0]))))
 
-  def test_print(self):
+  def test_print_with_op(self):
 
     def test_fn(a):
       print(a)
@@ -53,16 +55,41 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
     node = self.parse_and_analyze(test_fn, {'print': print})
     node = builtin_functions.transform(node, self.ctx)
 
-    with self.compiled(node) as result:
-      try:
-        out_capturer = six.StringIO()
-        sys.stdout = out_capturer
-        result.test_fn('a')
-        self.assertEqual(out_capturer.getvalue(), 'a\n')
-      finally:
-        sys.stdout = sys.__stdout__
+    # Note: it's relevant not to include script_ops.py_func here, to verify
+    # that tf.Print is used.
+    with self.compiled(node, logging_ops.Print) as result:
+      with self.test_session() as sess:
+        try:
+          out_capturer = six.StringIO()
+          sys.stdout = out_capturer
+          result.test_fn('a')
+          sess.run(sess.graph.get_operations())
+          self.assertEqual(out_capturer.getvalue(), 'a\n')
+        finally:
+          sys.stdout = sys.__stdout__
+
+  def test_print_with_op_multiple_values(self):
+
+    def test_fn(a, b):
+      print(a, b)
 
-  def test_print_tuple(self):
+    node = self.parse_and_analyze(test_fn, {'print': print})
+    node = builtin_functions.transform(node, self.ctx)
+
+    # Note: it's relevant not to include script_ops.py_func here, to verify
+    # that tf.Print is used.
+    with self.compiled(node, logging_ops.Print) as result:
+      with self.test_session() as sess:
+        try:
+          out_capturer = six.StringIO()
+          sys.stdout = out_capturer
+          result.test_fn('a', 1)
+          sess.run(sess.graph.get_operations())
+          self.assertEqual(out_capturer.getvalue(), 'a 1\n')
+        finally:
+          sys.stdout = sys.__stdout__
+
+  def test_print_with_py_func(self):
 
     def test_fn(a, b, c):
       print(a, b, c)
@@ -70,18 +97,18 @@ class BuiltinFunctionsTest(converter_test_base.TestCase):
     node = self.parse_and_analyze(test_fn, {'print': print})
     node = builtin_functions.transform(node, self.ctx)
 
-    with self.compiled(node) as result:
-      try:
-        out_capturer = six.StringIO()
-        sys.stdout = out_capturer
-        result.test_fn('a', 1, [2, 3])
-        # It appears that the print output looks odd only under Python 2.
-        if six.PY2:
-          self.assertEqual(out_capturer.getvalue(), "('a', 1, [2, 3])\n")
-        else:
+    # Note: it's relevant not to include logging_ops.Print here, to verify
+    # that py_func is used.
+    with self.compiled(node, script_ops.py_func) as result:
+      with self.test_session() as sess:
+        try:
+          out_capturer = six.StringIO()
+          sys.stdout = out_capturer
+          result.test_fn('a', 1, [2, 3])
+          sess.run(sess.graph.get_operations())
           self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n')
-      finally:
-        sys.stdout = sys.__stdout__
+        finally:
+          sys.stdout = sys.__stdout__
 
 
 if __name__ == '__main__':
index ca13910ae5cff2c914ab7a17c843fe963e02f0df..3d5624b187ed47e9eed8afbb2e101e1098f81c15 100644 (file)
@@ -268,9 +268,6 @@ def node_to_graph(node, ctx, nocompile_decorators):
   node = for_loops.transform(node, ctx)
   # for_loops may insert new global references.
   node = builtin_functions.transform(node, ctx)
-  # TODO(mdan): Kept for CL consistency. Remove.
-  # builtin_functions may insert new global references.
-  ctx.namespace['print'] = print
 
   node = _static_analysis_pass(node, ctx)
   node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES,
index a679cb90765f08f024b3b1bb52b19aa5a0bc06f6..c2fdd40707775783140390e4b5c0186c9c3e562e 100644 (file)
@@ -23,6 +23,7 @@ py_library(
         "context_managers.py",
         "misc.py",
         "multiple_dispatch.py",
+        "printing.py",
         "py_func.py",
         "tensor_list.py",
         "type_check.py",
@@ -75,6 +76,16 @@ py_test(
     ],
 )
 
+py_test(
+    name = "printing_test",
+    srcs = ["printing_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":utils",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
 py_test(
     name = "type_check_test",
     srcs = ["type_check_test.py"],
index 838c29aafd8ab4c6b0165995d916291fdfcff10b..0a1b993fd366e1317e5f7e01fe849d86c93b8fc2 100644 (file)
@@ -22,5 +22,6 @@ from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_o
 from tensorflow.contrib.py2tf.utils.misc import alias_tensors
 from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond
 from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while
+from tensorflow.contrib.py2tf.utils.printing import call_print
 from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func
 from tensorflow.contrib.py2tf.utils.type_check import is_tensor
diff --git a/tensorflow/contrib/py2tf/utils/printing.py b/tensorflow/contrib/py2tf/utils/printing.py
new file mode 100644 (file)
index 0000000..95a62bd
--- /dev/null
@@ -0,0 +1,47 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow printing support utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.utils import py_func
+from tensorflow.python.ops import logging_ops
+
+
+def is_tf_print_compatible(value):
+  # TODO(mdan): Enable once we can reliably test this.
+  # This is currently disabled because we can't capture the output of
+  # op kernels from Python.
+  del value
+  return False
+
+
+def call_print(*values):
+  """Compiled counterpart of the print builtin.
+
+  The function attempts to use tf.Print if all the values are compatible.
+  Otherwise, it will fall back to py_func.
+
+  Args:
+    *values: values to print
+  Returns:
+    A dummy value indicating the print completed. If tf.
+  """
+
+  if all(map(is_tf_print_compatible, values)):
+    return logging_ops.Print(1, values)
+  return py_func.wrap_py_func(print, None, values, use_dummy_return=True)
diff --git a/tensorflow/contrib/py2tf/utils/printing_test.py b/tensorflow/contrib/py2tf/utils/printing_test.py
new file mode 100644 (file)
index 0000000..2070deb
--- /dev/null
@@ -0,0 +1,53 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for printing module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import six
+
+from tensorflow.contrib.py2tf.utils import printing
+from tensorflow.python.platform import test
+
+
+class ContextManagersTest(test.TestCase):
+
+  def test_call_print_tf(self):
+    try:
+      out_capturer = six.StringIO()
+      sys.stdout = out_capturer
+      with self.test_session() as sess:
+        sess.run(printing.call_print('test message', 1))
+        self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
+    finally:
+      sys.stdout = sys.__stdout__
+
+  def test_call_print_py_func(self):
+    try:
+      out_capturer = six.StringIO()
+      sys.stdout = out_capturer
+      with self.test_session() as sess:
+        sess.run(printing.call_print('test message', [1, 2]))
+        self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
+    finally:
+      sys.stdout = sys.__stdout__
+
+
+if __name__ == '__main__':
+  test.main()