Core code transformers.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 8 Jan 2018 22:42:52 +0000 (14:42 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 8 Jan 2018 22:46:39 +0000 (14:46 -0800)
PiperOrigin-RevId: 181221299

16 files changed:
tensorflow/BUILD
tensorflow/contrib/py2tf/convert/BUILD [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/__init__.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/call_trees.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/call_trees_test.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/control_flow.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/control_flow_test.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/gradients_function.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/gradients_function_test.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/logical_expressions.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/logical_expressions_test.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/print_functions.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/print_functions_test.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/side_effect_guards.py [new file with mode: 0644]
tensorflow/contrib/py2tf/convert/side_effect_guards_test.py [new file with mode: 0644]
tensorflow/tools/pip_package/BUILD

index 06c9c2ba9d34f26fe57782d6da6b17dec23b358b..ebd2cd56926dc1c0c256a778019db28de8e8df2c 100644 (file)
@@ -521,6 +521,7 @@ filegroup(
         "//tensorflow/contrib/periodic_resample:all_files",
         "//tensorflow/contrib/predictor:all_files",
         "//tensorflow/contrib/py2tf:all_files",
+        "//tensorflow/contrib/py2tf/convert:all_files",
         "//tensorflow/contrib/py2tf/pyct:all_files",
         "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files",
         "//tensorflow/contrib/quantize:all_files",
diff --git a/tensorflow/contrib/py2tf/convert/BUILD b/tensorflow/contrib/py2tf/convert/BUILD
new file mode 100644 (file)
index 0000000..84a75ff
--- /dev/null
@@ -0,0 +1,100 @@
+licenses(["notice"])  # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+    name = "all_files",
+    srcs = glob(
+        ["**/*"],
+        exclude = [
+            "**/METADATA",
+            "**/OWNERS",
+        ],
+    ),
+    visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+    name = "convert",
+    srcs = [
+        "call_trees.py",
+        "control_flow.py",
+        "gradients_function.py",
+        "logical_expressions.py",
+        "print_functions.py",
+        "side_effect_guards.py",
+    ],
+    srcs_version = "PY2AND3",
+    visibility = ["//tensorflow:__subpackages__"],
+    deps = [
+        "@gast_archive//:gast",
+    ],
+)
+
+py_test(
+    name = "call_trees_test",
+    srcs = ["call_trees_test.py"],
+    deps = [
+        ":convert",
+        "//tensorflow/contrib/py2tf/pyct",
+        "//tensorflow/contrib/py2tf/pyct/static_analysis",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "control_flow_test",
+    srcs = ["control_flow_test.py"],
+    deps = [
+        ":convert",
+        "//tensorflow/contrib/py2tf/pyct",
+        "//tensorflow/contrib/py2tf/pyct/static_analysis",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "gradients_function_test",
+    srcs = ["gradients_function_test.py"],
+    deps = [
+        ":convert",
+        "//tensorflow/contrib/eager/python:tfe",
+        "//tensorflow/contrib/py2tf/pyct",
+        "//tensorflow/contrib/py2tf/pyct/static_analysis",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "logical_expressions_test",
+    srcs = ["logical_expressions_test.py"],
+    deps = [
+        ":convert",
+        "//tensorflow/contrib/py2tf/pyct",
+        "//tensorflow/contrib/py2tf/pyct/static_analysis",
+        "//tensorflow/python:client_testlib",
+    ],
+)
+
+py_test(
+    name = "print_functions_test",
+    srcs = ["print_functions_test.py"],
+    deps = [
+        ":convert",
+        "//tensorflow/contrib/py2tf/pyct",
+        "//tensorflow/contrib/py2tf/pyct/static_analysis",
+        "//tensorflow/python:client_testlib",
+        "@gast_archive//:gast",
+    ],
+)
+
+py_test(
+    name = "side_effect_guards_test",
+    srcs = ["side_effect_guards_test.py"],
+    deps = [
+        ":convert",
+        "//tensorflow/contrib/py2tf/pyct",
+        "//tensorflow/contrib/py2tf/pyct/static_analysis",
+        "//tensorflow/python:client_testlib",
+    ],
+)
diff --git a/tensorflow/contrib/py2tf/convert/__init__.py b/tensorflow/contrib/py2tf/convert/__init__.py
new file mode 100644 (file)
index 0000000..dba3b02
--- /dev/null
@@ -0,0 +1,21 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Code converters used by Py2TF."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(mdan): Define a base transformer class that can recognize skip_processing
diff --git a/tensorflow/contrib/py2tf/convert/call_trees.py b/tensorflow/contrib/py2tf/convert/call_trees.py
new file mode 100644 (file)
index 0000000..a0886f5
--- /dev/null
@@ -0,0 +1,158 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles function calls, by generating compiled function names and calls."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import types
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import templates
+
+
+class FunctionNamer(object):
+  """Describes the interface for CallTreeTransformer's namer."""
+
+  def compiled_function_name(self, original_name, live_object=None):
+    """Generate the name corresponding to the compiled version of a function.
+
+    Args:
+      original_name: String
+      live_object: Callable, the actual target function, if known.
+    Returns:
+      String.
+    """
+    raise NotImplementedError()
+
+
+class CallTreeTransformer(gast.NodeTransformer):
+  """Transforms the call tree by renaming transformed symbols."""
+
+  def __init__(self, namer, uncompiled_modules):
+    self.namer = namer
+    self.uncompiled_modules = uncompiled_modules
+
+  # pylint:disable=invalid-name
+
+  def visit_FunctionDef(self, node):
+    self.generic_visit(node)
+    node.name = self.namer.compiled_function_name(node.name)
+    return node
+
+  def _rename_compilable_function(self, node):
+    assert anno.hasanno(node.func, 'live_val')
+    assert anno.hasanno(node.func, 'fqn')
+    target_obj = anno.getanno(node.func, 'live_val')
+    target_fqn = anno.getanno(node.func, 'fqn')
+
+    fqn = ''
+    for s in target_fqn:
+      if fqn:
+        fqn += '.'
+      fqn += s
+      if fqn in self.uncompiled_modules:
+        return node
+
+    new_name = self.namer.compiled_function_name(fqn, live_object=target_obj)
+    node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
+    return node
+
+  def _rename_member_function_of_known_type(self, node):
+    target_fqn = anno.getanno(node.func, 'type_fqn')
+
+    fqn = ''
+    for s in target_fqn:
+      if fqn:
+        fqn += '.'
+      fqn += s
+      if fqn in self.uncompiled_modules:
+        return node
+
+    raise NotImplementedError('Member function call (of known type).')
+
+  def _wrap_to_py_func_no_return(self, node):
+    args_scope = anno.getanno(node, 'args_scope')
+    # TODO(mdan): Properly handle varargs, kwargs, etc.
+    args = tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used)
+
+    # pylint:disable=undefined-variable,unused-argument,function-redefined
+
+    def template(call, wrapper, args):
+
+      def wrapper(args):
+        call(args)
+        return 1
+
+      tf.py_func(wrapper, [args], [tf.int64])
+
+    # pylint:enable=undefined-variable,unused-argument,function-redefined
+
+    wrapper_name = self.namer.compiled_function_name(node.func.id)
+    wrapper_def, call_expr = templates.replace(
+        template,
+        call=node.func,
+        wrapper=gast.Name(wrapper_name, gast.Load(), None),
+        args=args)
+    anno.setanno(call_expr.value, 'args_scope', args_scope)
+    anno.setanno(wrapper_def, 'skip_processing', True)
+
+    return (wrapper_def, call_expr)
+
+  def visit_Expr(self, node):
+    if isinstance(node.value, gast.Call):
+      node = self._wrap_to_py_func_no_return(node.value)
+    else:
+      self.generic_visit(node)
+    return node
+
+  def visit_Call(self, node):
+    self.generic_visit(node)
+    if anno.hasanno(node.func, 'live_val'):
+      target_obj = anno.getanno(node.func, 'live_val')
+      if isinstance(target_obj, types.BuiltinFunctionType):
+        raise NotImplementedError('py_func with return values')
+      else:
+        node = self._rename_compilable_function(node)
+    elif anno.hasanno(node.func, 'type_fqn'):
+      node = self._rename_member_function_of_known_type(node)
+    else:
+      raise NotImplementedError(
+          'Member function call (of unknown type): %s.' % node.func.id)
+    return node
+
+  # pylint:enable=invalid-name
+
+
+def transform(node, namer, uncompiled_modules):
+  """Transform function call to the compiled counterparts.
+
+  Args:
+    node: AST to transform.
+    namer: FunctionNamer-like.
+    uncompiled_modules: set of string tuples, each tuple represents the fully
+        qualified name of a package containing functions that will not be
+        compiled.
+  Returns:
+    A tuple (node, new_names):
+        node: The transformed AST
+        new_names: set(string), containing any newly-generated names
+  """
+  transformer = CallTreeTransformer(namer, uncompiled_modules)
+  node = transformer.visit(node)
+  return node
diff --git a/tensorflow/contrib/py2tf/convert/call_trees_test.py b/tensorflow/contrib/py2tf/convert/call_trees_test.py
new file mode 100644 (file)
index 0000000..a27aa2d
--- /dev/null
@@ -0,0 +1,91 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for call_trees module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.convert import call_trees
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class TestNamer(call_trees.FunctionNamer):
+
+  def compiled_function_name(self, original_name, live_object=None):
+    return 'renamed_%s' % original_name
+
+
+class CallTreesTest(test.TestCase):
+
+  def _parse_and_analyze(self, test_fn, namespace):
+    node = parser.parse_object(test_fn)
+    node = access.resolve(node)
+    node = live_values.resolve(node, namespace, {})
+    node = type_info.resolve(node, None)
+    return node
+
+  def test_basic(self):
+
+    def test_fn_1(_):
+      raise ValueError('This should not be called in the compiled verison.')
+
+    def renamed_test_fn_1(a):
+      return a + 1
+
+    def test_fn_2(a):
+      return test_fn_1(a) + 1
+
+    node = self._parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
+    node = call_trees.transform(node, TestNamer(), set())
+    result = compiler.ast_to_object(node)
+    # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually.
+    setattr(result, 'renamed_test_fn_1', renamed_test_fn_1)
+
+    self.assertEquals(3, result.renamed_test_fn_2(1))
+
+  def test_uncompiled_modules(self):
+
+    def test_fn(a):
+      a = math_ops.multiply(a, constant_op.constant(2))
+      a = math_ops.add(a, constant_op.constant(1))
+      return a
+
+    node = self._parse_and_analyze(test_fn, {
+        'math_ops': math_ops,
+        'constant_op': constant_op
+    })
+    node = call_trees.transform(node, TestNamer(),
+                                set((math_ops.__name__, constant_op.__name__)))
+    result = compiler.ast_to_object(node)
+    setattr(result, 'math_ops', math_ops)
+    setattr(result, 'constant_op', constant_op)
+
+    with self.test_session() as sess:
+      result_tensor = result.renamed_test_fn(constant_op.constant(1))
+      result_val = sess.run(result_tensor)
+
+    self.assertEquals(3, result_val)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/py2tf/convert/control_flow.py b/tensorflow/contrib/py2tf/convert/control_flow.py
new file mode 100644 (file)
index 0000000..fed9ac4
--- /dev/null
@@ -0,0 +1,118 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Identity converter. Useful for testing and diagnostic."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import templates
+
+
+class SymbolNamer(object):
+  """Describes the interface for ControlFlowTransformer's namer."""
+
+  def new_symbol(self, name_root, reserved_locals):
+    """Generate a new unique symbol.
+
+    Args:
+      name_root: String, used as stem in the new name.
+      reserved_locals: Set(string), additional local symbols that are reserved
+          and which should not be used.
+    Returns:
+      String.
+    """
+    raise NotImplementedError()
+
+
+class ControlFlowTransformer(gast.NodeTransformer):
+  """Transforms control flow structures like loops an conditionals."""
+
+  def __init__(self, namer):
+    self.namer = namer
+
+  # pylint:disable=invalid-name
+
+  def _tuple_or_item(self, elts):
+    elts = tuple(elts)
+    if len(elts) == 1:
+      return elts[0]
+    return elts
+
+  def _ast_tuple_or_item(self, elts, ctx):
+    elts = list(elts)
+    if len(elts) == 1:
+      return elts[0]
+    return gast.Tuple(elts, ctx)
+
+  def visit_If(self, node):
+    raise NotImplementedError()
+
+  def visit_While(self, node):
+    self.generic_visit(node)
+    # Scrape out the data flow analysis
+    body_scope = anno.getanno(node, 'body_scope')
+    parent_scope_values = anno.getanno(node, 'parent_scope_values')
+    body_closure = tuple(body_scope.modified - body_scope.created)
+
+    def template(
+        state_args,  # pylint:disable=unused-argument
+        state_locals,
+        state_results,  # pylint:disable=unused-argument
+        test_name,
+        test,  # pylint:disable=unused-argument
+        body_name,
+        body,
+        state_init):
+
+      def test_name(state_args):  # pylint:disable=function-redefined,unused-argument
+        return test
+
+      def body_name(state_args):  # pylint:disable=function-redefined,unused-argument
+        body  # pylint:disable=pointless-statement
+        return state_locals
+
+      state_results = tf.while_loop(test_name, body_name, [state_init])  # pylint:disable=undefined-variable
+
+    test_name = self.namer.new_symbol('loop_test', body_scope.used)
+    body_name = self.namer.new_symbol('loop_body', body_scope.used)
+    node = templates.replace(
+        template,
+        state_args=self._tuple_or_item(
+            gast.Name(n, gast.Param(), None) for n in body_closure),
+        state_locals=self._ast_tuple_or_item(
+            (gast.Name(n, gast.Load(), None) for n in body_closure),
+            gast.Load()),
+        state_results=self._ast_tuple_or_item(
+            (gast.Name(n, gast.Store(), None) for n in body_closure),
+            gast.Store()),
+        test_name=gast.Name(test_name, gast.Load(), None),
+        test=node.test,
+        body_name=gast.Name(body_name, gast.Load(), None),
+        body=node.body,
+        state_init=[parent_scope_values.getval(n) for n in body_closure])
+
+    return node
+
+  # pylint:enable=invalid-name
+
+
+def transform(node, namer):
+  transformer = ControlFlowTransformer(namer)
+  node = transformer.visit(node)
+  return node
diff --git a/tensorflow/contrib/py2tf/convert/control_flow_test.py b/tensorflow/contrib/py2tf/convert/control_flow_test.py
new file mode 100644 (file)
index 0000000..c27a079
--- /dev/null
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for control_flow module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.convert import control_flow
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.platform import test
+
+
+class TestNamer(control_flow.SymbolNamer):
+
+  def new_symbol(self, name_root, _):
+    return name_root
+
+
+class ControlFlowTest(test.TestCase):
+
+  def _parse_and_analyze(self, test_fn, namespace):
+    node = parser.parse_object(test_fn)
+    node = access.resolve(node)
+    node = live_values.resolve(node, namespace, {})
+    node = type_info.resolve(node, None)
+    return node
+
+  def test_simple_while(self):
+
+    def test_fn(n):
+      i = 0
+      s = 0
+      while i < n:
+        s += i
+        i += 1
+      return s, i, n
+
+    node = self._parse_and_analyze(test_fn, {})
+    node = control_flow.transform(node, TestNamer())
+    result = compiler.ast_to_object(node)
+    setattr(result, 'tf', control_flow_ops)
+
+    with self.test_session() as sess:
+      self.assertEqual((10, 5, 5),
+                       sess.run(result.test_fn(constant_op.constant(5))))
+
+  def test_while_single_var(self):
+
+    def test_fn(n):
+      while n > 0:
+        n -= 1
+      return n
+
+    node = self._parse_and_analyze(test_fn, {})
+    node = control_flow.transform(node, TestNamer())
+    result = compiler.ast_to_object(node)
+    setattr(result, 'tf', control_flow_ops)
+
+    with self.test_session() as sess:
+      self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5))))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/py2tf/convert/gradients_function.py b/tensorflow/contrib/py2tf/convert/gradients_function.py
new file mode 100644 (file)
index 0000000..f3c07db
--- /dev/null
@@ -0,0 +1,80 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Allows converting Eager-style gradients to graph versions."""
+# TODO(mdan): This is not needed. Remove once the static analysis works.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import templates
+
+
+class GradientsFunctionTransformer(gast.NodeTransformer):
+  """Hack: transforms eager-style gradients to TF compatible calls.
+
+  Requires an expression of exactly this form:
+      ... = tfe.value_and_gradients_function(...)(...)
+  """
+
+  # pylint:disable=invalid-name
+
+  def visit_Assign(self, node):
+    self.generic_visit(node)
+
+    val = node.value
+    if isinstance(val, gast.Call):
+      if isinstance(val.func, gast.Call):
+        if isinstance(val.func.func, gast.Attribute):
+          if isinstance(val.func.func.value, gast.Name):
+            if (val.func.func.value.id == 'tfe' and
+                val.func.func.attr == 'value_and_gradients_function'):
+
+              # pylint:disable=unused-argument,undefined-variable
+
+              def template(loss_var, loss_fn, args, d_vars, wrt_vars):
+                loss_var = loss_fn(args)
+                d_vars = tf.gradients(loss_var, [wrt_vars])
+
+              # pylint:enable=unused-argument,undefined-variable
+
+              # How to get these values? Print out the node.
+              loss_var = gast.Name(node.targets[0].elts[0].id, gast.Store(),
+                                   None)
+              loss_fn = gast.Name(val.func.args[0].id, gast.Load(), None)
+              args = tuple(
+                  gast.Name(a.id, gast.Param(), None) for a in val.args)
+              d_vars = node.targets[0].elts[1]
+              wrt_vars = [val.args[e.n] for e in val.func.args[1].elts]
+
+              node = templates.replace(
+                  template,
+                  loss_var=loss_var,
+                  loss_fn=loss_fn,
+                  args=args,
+                  d_vars=d_vars,
+                  wrt_vars=wrt_vars)
+
+    return node
+
+  # pylint:enable=invalid-name
+
+
+def transform(node):
+  transformer = GradientsFunctionTransformer()
+  node = transformer.visit(node)
+  return node
diff --git a/tensorflow/contrib/py2tf/convert/gradients_function_test.py b/tensorflow/contrib/py2tf/convert/gradients_function_test.py
new file mode 100644 (file)
index 0000000..7ef22f7
--- /dev/null
@@ -0,0 +1,55 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for gradients_function module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.eager.python import tfe
+from tensorflow.contrib.py2tf.convert import gradients_function
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.platform import test
+
+
+class GradientsFunctionTest(test.TestCase):
+
+  def test_transform(self):
+
+    def loss(x, w):
+      return x * w
+
+    def test_fn(x, w):
+      l, (dw,) = tfe.value_and_gradients_function(loss, [1])(x, w)  # pylint:disable=undefined-variable
+      return l, dw
+
+    node = parser.parse_object(test_fn)
+    node = gradients_function.transform(node)
+    result = compiler.ast_to_object(node)
+    setattr(result, 'tf', gradients_impl)
+    setattr(result, 'loss', loss)
+
+    with self.test_session() as sess:
+      self.assertEqual(
+          (12, 3),
+          sess.run(
+              result.test_fn(constant_op.constant(3), constant_op.constant(4))))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/py2tf/convert/logical_expressions.py b/tensorflow/contrib/py2tf/convert/logical_expressions.py
new file mode 100644 (file)
index 0000000..c2f27a5
--- /dev/null
@@ -0,0 +1,55 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converter for logical expressions.
+
+e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import parser
+
+
+class LogicalExpressionTransformer(gast.NodeTransformer):
+  """Converts logical expressions to corresponding TF calls."""
+
+  def __init__(self):
+    # TODO(mdan): Look into replacing with bitwise operators instead.
+    self.op_mapping = {
+        gast.And: 'tf.logical_and',
+        gast.Or: 'tf.logical_or',
+    }
+
+  def visit_UnaryOp(self, node):
+    raise NotImplementedError()
+
+  def visit_BoolOp(self, node):
+    # TODO(mdan): A normalizer may be useful here. Use ANF?
+    tf_function = parser.parse_str(self.op_mapping[type(node.op)]).body[0].value
+    left = node.values[0]
+    for i in range(1, len(node.values)):
+      left = gast.Call(
+          func=tf_function, args=[left, node.values[i]], keywords=[])
+    return left
+
+
+def transform(node):
+  transformer = LogicalExpressionTransformer()
+  node = transformer.visit(node)
+  return node
diff --git a/tensorflow/contrib/py2tf/convert/logical_expressions_test.py b/tensorflow/contrib/py2tf/convert/logical_expressions_test.py
new file mode 100644 (file)
index 0000000..9679ec2
--- /dev/null
@@ -0,0 +1,45 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for logical_expressions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.convert import logical_expressions
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class GradientsFunctionTest(test.TestCase):
+
+  def test_transform(self):
+
+    def test_fn(a, b, c):
+      return (a or b) and (a or b or c)
+
+    node = parser.parse_object(test_fn)
+    node = logical_expressions.transform(node)
+    result = compiler.ast_to_object(node)
+    setattr(result, 'tf', math_ops)
+
+    with self.test_session() as sess:
+      self.assertTrue(sess.run(result.test_fn(True, False, True)))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/py2tf/convert/print_functions.py b/tensorflow/contrib/py2tf/convert/print_functions.py
new file mode 100644 (file)
index 0000000..5da738c
--- /dev/null
@@ -0,0 +1,51 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Compatibility support. Converts Print nodes to function calls."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+
+
+class PrintFunctionTransformer(gast.NodeTransformer):
+  """Transforms Print nodes to Call so they can be handled as functions."""
+
+  # pylint:disable=invalid-name
+
+  def visit_Print(self, node):
+    self.generic_visit(node)
+    for n in node.values:
+      n.ctx = gast.Param()
+    call_node = gast.Call(
+        func=gast.Name('print', gast.Load(), None),
+        args=node.values,
+        keywords=[])
+    anno.setanno(call_node.func, 'live_val', print)
+    anno.setanno(call_node.func, 'fqn', 'print')
+    anno.setanno(call_node, 'args_scope', anno.getanno(node, 'args_scope'))
+    node = gast.Expr(call_node)
+    return node
+
+  # pylint:enable=invalid-name
+
+
+def transform(node):
+  transformer = PrintFunctionTransformer()
+  node = transformer.visit(node)
+  return node
diff --git a/tensorflow/contrib/py2tf/convert/print_functions_test.py b/tensorflow/contrib/py2tf/convert/print_functions_test.py
new file mode 100644 (file)
index 0000000..f8fee87
--- /dev/null
@@ -0,0 +1,55 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for print_functions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.convert import print_functions
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.platform import test
+
+
+class PrintFunctionsTest(test.TestCase):
+
+  def _parse_and_analyze(self, test_fn, namespace):
+    node = parser.parse_object(test_fn)
+    node = access.resolve(node)
+    node = live_values.resolve(node, namespace, {})
+    node = type_info.resolve(node, None)
+    return node
+
+  def test_transform(self):
+
+    def test_fn(a):
+      print(a)
+
+    node = self._parse_and_analyze(test_fn, {'print': print})
+    node = print_functions.transform(node)
+    result = compiler.ast_to_object(node)
+
+    result.test_fn('a')
+    self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/py2tf/convert/side_effect_guards.py b/tensorflow/contrib/py2tf/convert/side_effect_guards.py
new file mode 100644 (file)
index 0000000..17f7607
--- /dev/null
@@ -0,0 +1,155 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Adds guards against function calls with side effects.
+
+Only standalone calls are guarded.
+
+WARNING: This mechanism is incomplete. Particularly, it only guards the
+arguments passed to functions, and does not account for indirectly modified
+state.
+
+Example:
+  y = tf.layers.dense(x)       # Creates TF variable 'foo'
+  loss = loss(y)
+  opt.minimize(loss)           # indirectly affects 'foo'
+  z = tf.get_variable('foo')   # Indirectly affects `loss` and 'foo'
+  # Here, `loss` can be guarded. But `z` cannot.
+
+# TODO(mdan): We should probably define a safe mode where we guard everything.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import templates
+
+
+class SymbolNamer(object):
+  """Describes the interface for SideEffectGuardTransformer's namer."""
+
+  def new_symbol(self, name_root, reserved_locals):
+    """Generate a new unique function_name.
+
+    Args:
+      name_root: String, used as stem in the new name.
+      reserved_locals: Set(string), additional local symbols that are reserved.
+    Returns:
+      String.
+    """
+    raise NotImplementedError()
+
+
+class SideEffectGuardTransformer(gast.NodeTransformer):
+  """Adds control dependencies to functions with side effects."""
+
+  def __init__(self, namer):
+    self.namer = namer
+    self.indent_next = False
+    self.next_indent_owner = None
+
+  # pylint:disable=invalid-name
+
+  def _visit_and_reindent(self, nodes):
+    new_nodes = []
+    current_dest = new_nodes
+    for n in nodes:
+      n = self.visit(n)
+      if isinstance(n, (list, tuple)):
+        current_dest.extend(n)
+      else:
+        current_dest.append(n)
+      if self.indent_next:
+        assert self.next_indent_owner is not None
+        current_dest.append(self.next_indent_owner)
+        current_dest = self.next_indent_owner.body
+        self.next_indent_owner = None
+        self.indent_next = False
+    if not current_dest:
+      # TODO(mdan): There may still be something that could be done.
+      raise ValueError('Unable to insert statement into the computation flow: '
+                       'it is not followed by any computation that can we can '
+                       'condition on the statement.')
+    return new_nodes
+
+  def visit_FunctionDef(self, node):
+    if anno.hasanno(node, 'skip_processing'):
+      return node
+    node.body = self._visit_and_reindent(node.body)
+    return node
+
+  def _gate_symbols(self, guard_statement, guarded_args):
+
+    def template(dst_args, src_args):  # pylint:disable=unused-argument
+      (dst_args,) = (tf.identity(a) for a in (src_args,))  # pylint:disable=undefined-variable
+
+    guards = templates.replace(
+        template,
+        dst_args=tuple(gast.Name(a, gast.Store(), None) for a in guarded_args),
+        src_args=tuple(gast.Name(a, gast.Load(), None) for a in guarded_args))
+    guard_statement.body.extend(guards)
+    return guard_statement
+
+  def visit_Expr(self, node):
+    self.generic_visit(node)
+    if isinstance(node.value, gast.Call):
+      # Patterns of single function calls, like:
+      #   opt.minimize(loss)
+      # or:
+      #   tf.py_func(...)
+
+      args_scope = anno.getanno(node.value, 'args_scope')
+      temp_name = self.namer.new_symbol('temp', args_scope.parent.used)
+      # TODO(mdan): Unsafe reference modification!
+      args_scope.mark_write(temp_name)
+
+      def template(call, temp_result):
+        temp_result = call
+        if not isinstance(temp_result, (list, tuple)):
+          temp_result = (temp_result,)
+        with tf.control_dependencies(temp_result):  # pylint:disable=undefined-variable
+          # TODO(mdan): Also insert ops to re-fetch if variables are involved.
+          pass  # Will be removed below.
+
+      guard_var_assign, arg_checker, control_deps_guard = templates.replace(
+          template,
+          call=node.value,
+          temp_result=gast.Name(temp_name, gast.Store(), None))
+      control_deps_guard.body = []
+
+      # First, attempt to gate future evaluation of args. If that's not
+      # possible, gate all remaining statements (and that may fail too, see
+      # _visit_and_reindent.
+      guarded_args = tuple(
+          n for n in args_scope.used if n in args_scope.parent.modified)
+      if guarded_args:
+        node = (guard_var_assign, arg_checker,
+                self._gate_symbols(control_deps_guard, guarded_args))
+      else:
+        node = (guard_var_assign, arg_checker)
+        # The mechanism will insert the guard statement later.
+        self.indent_next = True
+        self.next_indent_owner = control_deps_guard
+    return node
+
+  # pylint:enable=invalid-name
+
+
+def transform(node, namer):
+  transformer = SideEffectGuardTransformer(namer)
+  return transformer.visit(node)
diff --git a/tensorflow/contrib/py2tf/convert/side_effect_guards_test.py b/tensorflow/contrib/py2tf/convert/side_effect_guards_test.py
new file mode 100644 (file)
index 0000000..e8888ab
--- /dev/null
@@ -0,0 +1,71 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for side_effect_guards module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.convert import side_effect_guards
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class TestNamer(side_effect_guards.SymbolNamer):
+
+  def new_symbol(self, name_root, _):
+    return name_root
+
+
+class SideEffectGuardsTest(test.TestCase):
+
+  def _parse_and_analyze(self, test_fn, namespace):
+    node = parser.parse_object(test_fn)
+    node = access.resolve(node)
+    node = live_values.resolve(node, namespace, {})
+    node = type_info.resolve(node, None)
+    return node
+
+  def test_transform(self):
+
+    def test_fn(a):
+      state_ops.assign(a, a + 1)
+      return a
+
+    node = self._parse_and_analyze(test_fn, {'state_ops': state_ops})
+    node = side_effect_guards.transform(node, TestNamer())
+    result = compiler.ast_to_object(node)
+    setattr(result, 'state_ops', state_ops)
+
+    # TODO(mdan): Configure the namespaces instead of doing these hacks.
+    ops.identity = array_ops.identity
+    setattr(result, 'tf', ops)
+
+    with self.test_session() as sess:
+      v = variables.Variable(2)
+      sess.run(v.initializer)
+      self.assertEqual(3, sess.run(result.test_fn(v)))
+
+
+if __name__ == '__main__':
+  test.main()
index 72116f71d132a79a71ca9984d6d772512e3433dd..28cb95329766e838a653497cdde7c8115ccdb65b 100644 (file)
@@ -168,6 +168,7 @@ sh_binary(
             "//tensorflow/contrib/ndlstm:ndlstm",
             "//tensorflow/contrib/nn:nn_py",
             "//tensorflow/contrib/predictor:predictor_pip",
+            "//tensorflow/contrib/py2tf/convert:convert",
             "//tensorflow/contrib/py2tf/pyct:pyct",
             "//tensorflow/contrib/py2tf/pyct/static_analysis:static_analysis",
             "//tensorflow/contrib/receptive_field:receptive_field_pip",