class BuiltinFunctionTransformer(transformer.Base):
- """Handles builtin functions and canonicalizes old-style print statement.
+ """Handles builtin functions.
This transformer only covers functions that are translated into a
TF equivalent, like `len`.
- Note that the `print` statement is converted to a function call here, but
- wrapping the print function to a `py_func` is done by `call_trees` as a
- generic uncompilable function wrap.
"""
- # TODO(mdan): Handle print entirely in here.
- # Fully handling print here makes sense especially since we're considering
- # using tf.Print instead.
-
def __init__(self, context):
super(BuiltinFunctionTransformer, self).__init__(context)
+ # pylint:disable=invalid-name
+
def _convert_len(self, node):
template = """
tf.shape(args)[0]
"""
- new_call = templates.replace(template, args=node.args)[0].value
- return new_call
+ return templates.replace(template, args=node.args)[0].value
- # pylint:disable=invalid-name
+ def _convert_print(self, node):
+ template = """
+ py2tf_utils.call_print(args)
+ """
+ return templates.replace(template, args=node.args)[0].value
def visit_Call(self, node):
self.generic_visit(node)
# TODO(mdan): This won't work if the function was hidden.
if isinstance(node.func, gast.Name) and node.func.id == 'len':
return self._convert_len(node)
+ if isinstance(node.func, gast.Name) and node.func.id == 'print':
+ return self._convert_print(node)
return node
def visit_Print(self, node):
template = """
fname(args)
"""
- return templates.replace(template, fname='print', args=args)
+ function_call = templates.replace(template, fname='print', args=args)[0]
+ return self.visit(function_call)
# pylint:enable=invalid-name
from tensorflow.contrib.py2tf.converters import converter_test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
sess.run(
result.test_fn(constant_op.constant([0, 0, 0]))))
- def test_print(self):
+ def test_print_with_op(self):
def test_fn(a):
print(a)
node = self.parse_and_analyze(test_fn, {'print': print})
node = builtin_functions.transform(node, self.ctx)
- with self.compiled(node) as result:
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- result.test_fn('a')
- self.assertEqual(out_capturer.getvalue(), 'a\n')
- finally:
- sys.stdout = sys.__stdout__
+ # Note: it's relevant not to include script_ops.py_func here, to verify
+ # that tf.Print is used.
+ with self.compiled(node, logging_ops.Print) as result:
+ with self.test_session() as sess:
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ result.test_fn('a')
+ sess.run(sess.graph.get_operations())
+ self.assertEqual(out_capturer.getvalue(), 'a\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_print_with_op_multiple_values(self):
+
+ def test_fn(a, b):
+ print(a, b)
- def test_print_tuple(self):
+ node = self.parse_and_analyze(test_fn, {'print': print})
+ node = builtin_functions.transform(node, self.ctx)
+
+ # Note: it's relevant not to include script_ops.py_func here, to verify
+ # that tf.Print is used.
+ with self.compiled(node, logging_ops.Print) as result:
+ with self.test_session() as sess:
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ result.test_fn('a', 1)
+ sess.run(sess.graph.get_operations())
+ self.assertEqual(out_capturer.getvalue(), 'a 1\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_print_with_py_func(self):
def test_fn(a, b, c):
print(a, b, c)
node = self.parse_and_analyze(test_fn, {'print': print})
node = builtin_functions.transform(node, self.ctx)
- with self.compiled(node) as result:
- try:
- out_capturer = six.StringIO()
- sys.stdout = out_capturer
- result.test_fn('a', 1, [2, 3])
- # It appears that the print output looks odd only under Python 2.
- if six.PY2:
- self.assertEqual(out_capturer.getvalue(), "('a', 1, [2, 3])\n")
- else:
+ # Note: it's relevant not to include logging_ops.Print here, to verify
+ # that py_func is used.
+ with self.compiled(node, script_ops.py_func) as result:
+ with self.test_session() as sess:
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ result.test_fn('a', 1, [2, 3])
+ sess.run(sess.graph.get_operations())
self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n')
- finally:
- sys.stdout = sys.__stdout__
+ finally:
+ sys.stdout = sys.__stdout__
if __name__ == '__main__':
node = for_loops.transform(node, ctx)
# for_loops may insert new global references.
node = builtin_functions.transform(node, ctx)
- # TODO(mdan): Kept for CL consistency. Remove.
- # builtin_functions may insert new global references.
- ctx.namespace['print'] = print
node = _static_analysis_pass(node, ctx)
node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES,
"context_managers.py",
"misc.py",
"multiple_dispatch.py",
+ "printing.py",
"py_func.py",
"tensor_list.py",
"type_check.py",
],
)
+py_test(
+ name = "printing_test",
+ srcs = ["printing_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":utils",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_test(
name = "type_check_test",
srcs = ["type_check_test.py"],
from tensorflow.contrib.py2tf.utils.misc import alias_tensors
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while
+from tensorflow.contrib.py2tf.utils.printing import call_print
from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func
from tensorflow.contrib.py2tf.utils.type_check import is_tensor
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TensorFlow printing support utilities."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.utils import py_func
+from tensorflow.python.ops import logging_ops
+
+
+def is_tf_print_compatible(value):
+ # TODO(mdan): Enable once we can reliably test this.
+ # This is currently disabled because we can't capture the output of
+ # op kernels from Python.
+ del value
+ return False
+
+
+def call_print(*values):
+ """Compiled counterpart of the print builtin.
+
+ The function attempts to use tf.Print if all the values are compatible.
+ Otherwise, it will fall back to py_func.
+
+ Args:
+ *values: values to print
+ Returns:
+ A dummy value indicating the print completed. If tf.
+ """
+
+ if all(map(is_tf_print_compatible, values)):
+ return logging_ops.Print(1, values)
+ return py_func.wrap_py_func(print, None, values, use_dummy_return=True)
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for printing module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import six
+
+from tensorflow.contrib.py2tf.utils import printing
+from tensorflow.python.platform import test
+
+
+class ContextManagersTest(test.TestCase):
+
+ def test_call_print_tf(self):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ with self.test_session() as sess:
+ sess.run(printing.call_print('test message', 1))
+ self.assertEqual(out_capturer.getvalue(), 'test message 1\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+ def test_call_print_py_func(self):
+ try:
+ out_capturer = six.StringIO()
+ sys.stdout = out_capturer
+ with self.test_session() as sess:
+ sess.run(printing.call_print('test message', [1, 2]))
+ self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n')
+ finally:
+ sys.stdout = sys.__stdout__
+
+
+if __name__ == '__main__':
+ test.main()