"//tensorflow/contrib/periodic_resample:all_files",
"//tensorflow/contrib/predictor:all_files",
"//tensorflow/contrib/py2tf:all_files",
+ "//tensorflow/contrib/py2tf/convert:all_files",
"//tensorflow/contrib/py2tf/pyct:all_files",
"//tensorflow/contrib/py2tf/pyct/static_analysis:all_files",
"//tensorflow/contrib/quantize:all_files",
--- /dev/null
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "convert",
+ srcs = [
+ "call_trees.py",
+ "control_flow.py",
+ "gradients_function.py",
+ "logical_expressions.py",
+ "print_functions.py",
+ "side_effect_guards.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
+ name = "call_trees_test",
+ srcs = ["call_trees_test.py"],
+ deps = [
+ ":convert",
+ "//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "control_flow_test",
+ srcs = ["control_flow_test.py"],
+ deps = [
+ ":convert",
+ "//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "gradients_function_test",
+ srcs = ["gradients_function_test.py"],
+ deps = [
+ ":convert",
+ "//tensorflow/contrib/eager/python:tfe",
+ "//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "logical_expressions_test",
+ srcs = ["logical_expressions_test.py"],
+ deps = [
+ ":convert",
+ "//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "print_functions_test",
+ srcs = ["print_functions_test.py"],
+ deps = [
+ ":convert",
+ "//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "//tensorflow/python:client_testlib",
+ "@gast_archive//:gast",
+ ],
+)
+
+py_test(
+ name = "side_effect_guards_test",
+ srcs = ["side_effect_guards_test.py"],
+ deps = [
+ ":convert",
+ "//tensorflow/contrib/py2tf/pyct",
+ "//tensorflow/contrib/py2tf/pyct/static_analysis",
+ "//tensorflow/python:client_testlib",
+ ],
+)
--- /dev/null
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Code converters used by Py2TF."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(mdan): Define a base transformer class that can recognize skip_processing
--- /dev/null
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Handles function calls, by generating compiled function names and calls."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import types
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import templates
+
+
+class FunctionNamer(object):
+ """Describes the interface for CallTreeTransformer's namer."""
+
+ def compiled_function_name(self, original_name, live_object=None):
+ """Generate the name corresponding to the compiled version of a function.
+
+ Args:
+ original_name: String
+ live_object: Callable, the actual target function, if known.
+ Returns:
+ String.
+ """
+ raise NotImplementedError()
+
+
+class CallTreeTransformer(gast.NodeTransformer):
+ """Transforms the call tree by renaming transformed symbols."""
+
+ def __init__(self, namer, uncompiled_modules):
+ self.namer = namer
+ self.uncompiled_modules = uncompiled_modules
+
+ # pylint:disable=invalid-name
+
+ def visit_FunctionDef(self, node):
+ self.generic_visit(node)
+ node.name = self.namer.compiled_function_name(node.name)
+ return node
+
+ def _rename_compilable_function(self, node):
+ assert anno.hasanno(node.func, 'live_val')
+ assert anno.hasanno(node.func, 'fqn')
+ target_obj = anno.getanno(node.func, 'live_val')
+ target_fqn = anno.getanno(node.func, 'fqn')
+
+ fqn = ''
+ for s in target_fqn:
+ if fqn:
+ fqn += '.'
+ fqn += s
+ if fqn in self.uncompiled_modules:
+ return node
+
+ new_name = self.namer.compiled_function_name(fqn, live_object=target_obj)
+ node.func = gast.Name(id=new_name, ctx=gast.Load(), annotation=None)
+ return node
+
+ def _rename_member_function_of_known_type(self, node):
+ target_fqn = anno.getanno(node.func, 'type_fqn')
+
+ fqn = ''
+ for s in target_fqn:
+ if fqn:
+ fqn += '.'
+ fqn += s
+ if fqn in self.uncompiled_modules:
+ return node
+
+ raise NotImplementedError('Member function call (of known type).')
+
+ def _wrap_to_py_func_no_return(self, node):
+ args_scope = anno.getanno(node, 'args_scope')
+ # TODO(mdan): Properly handle varargs, kwargs, etc.
+ args = tuple(gast.Name(n, gast.Load(), None) for n in args_scope.used)
+
+ # pylint:disable=undefined-variable,unused-argument,function-redefined
+
+ def template(call, wrapper, args):
+
+ def wrapper(args):
+ call(args)
+ return 1
+
+ tf.py_func(wrapper, [args], [tf.int64])
+
+ # pylint:enable=undefined-variable,unused-argument,function-redefined
+
+ wrapper_name = self.namer.compiled_function_name(node.func.id)
+ wrapper_def, call_expr = templates.replace(
+ template,
+ call=node.func,
+ wrapper=gast.Name(wrapper_name, gast.Load(), None),
+ args=args)
+ anno.setanno(call_expr.value, 'args_scope', args_scope)
+ anno.setanno(wrapper_def, 'skip_processing', True)
+
+ return (wrapper_def, call_expr)
+
+ def visit_Expr(self, node):
+ if isinstance(node.value, gast.Call):
+ node = self._wrap_to_py_func_no_return(node.value)
+ else:
+ self.generic_visit(node)
+ return node
+
+ def visit_Call(self, node):
+ self.generic_visit(node)
+ if anno.hasanno(node.func, 'live_val'):
+ target_obj = anno.getanno(node.func, 'live_val')
+ if isinstance(target_obj, types.BuiltinFunctionType):
+ raise NotImplementedError('py_func with return values')
+ else:
+ node = self._rename_compilable_function(node)
+ elif anno.hasanno(node.func, 'type_fqn'):
+ node = self._rename_member_function_of_known_type(node)
+ else:
+ raise NotImplementedError(
+ 'Member function call (of unknown type): %s.' % node.func.id)
+ return node
+
+ # pylint:enable=invalid-name
+
+
+def transform(node, namer, uncompiled_modules):
+ """Transform function call to the compiled counterparts.
+
+ Args:
+ node: AST to transform.
+ namer: FunctionNamer-like.
+ uncompiled_modules: set of string tuples, each tuple represents the fully
+ qualified name of a package containing functions that will not be
+ compiled.
+ Returns:
+ A tuple (node, new_names):
+ node: The transformed AST
+ new_names: set(string), containing any newly-generated names
+ """
+ transformer = CallTreeTransformer(namer, uncompiled_modules)
+ node = transformer.visit(node)
+ return node
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for call_trees module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.convert import call_trees
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class TestNamer(call_trees.FunctionNamer):
+
+ def compiled_function_name(self, original_name, live_object=None):
+ return 'renamed_%s' % original_name
+
+
+class CallTreesTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn, namespace):
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
+ node = live_values.resolve(node, namespace, {})
+ node = type_info.resolve(node, None)
+ return node
+
+ def test_basic(self):
+
+ def test_fn_1(_):
+ raise ValueError('This should not be called in the compiled verison.')
+
+ def renamed_test_fn_1(a):
+ return a + 1
+
+ def test_fn_2(a):
+ return test_fn_1(a) + 1
+
+ node = self._parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
+ node = call_trees.transform(node, TestNamer(), set())
+ result = compiler.ast_to_object(node)
+ # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1 manually.
+ setattr(result, 'renamed_test_fn_1', renamed_test_fn_1)
+
+ self.assertEquals(3, result.renamed_test_fn_2(1))
+
+ def test_uncompiled_modules(self):
+
+ def test_fn(a):
+ a = math_ops.multiply(a, constant_op.constant(2))
+ a = math_ops.add(a, constant_op.constant(1))
+ return a
+
+ node = self._parse_and_analyze(test_fn, {
+ 'math_ops': math_ops,
+ 'constant_op': constant_op
+ })
+ node = call_trees.transform(node, TestNamer(),
+ set((math_ops.__name__, constant_op.__name__)))
+ result = compiler.ast_to_object(node)
+ setattr(result, 'math_ops', math_ops)
+ setattr(result, 'constant_op', constant_op)
+
+ with self.test_session() as sess:
+ result_tensor = result.renamed_test_fn(constant_op.constant(1))
+ result_val = sess.run(result_tensor)
+
+ self.assertEquals(3, result_val)
+
+
+if __name__ == '__main__':
+ test.main()
--- /dev/null
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Identity converter. Useful for testing and diagnostic."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import templates
+
+
+class SymbolNamer(object):
+ """Describes the interface for ControlFlowTransformer's namer."""
+
+ def new_symbol(self, name_root, reserved_locals):
+ """Generate a new unique symbol.
+
+ Args:
+ name_root: String, used as stem in the new name.
+ reserved_locals: Set(string), additional local symbols that are reserved
+ and which should not be used.
+ Returns:
+ String.
+ """
+ raise NotImplementedError()
+
+
+class ControlFlowTransformer(gast.NodeTransformer):
+ """Transforms control flow structures like loops an conditionals."""
+
+ def __init__(self, namer):
+ self.namer = namer
+
+ # pylint:disable=invalid-name
+
+ def _tuple_or_item(self, elts):
+ elts = tuple(elts)
+ if len(elts) == 1:
+ return elts[0]
+ return elts
+
+ def _ast_tuple_or_item(self, elts, ctx):
+ elts = list(elts)
+ if len(elts) == 1:
+ return elts[0]
+ return gast.Tuple(elts, ctx)
+
+ def visit_If(self, node):
+ raise NotImplementedError()
+
+ def visit_While(self, node):
+ self.generic_visit(node)
+ # Scrape out the data flow analysis
+ body_scope = anno.getanno(node, 'body_scope')
+ parent_scope_values = anno.getanno(node, 'parent_scope_values')
+ body_closure = tuple(body_scope.modified - body_scope.created)
+
+ def template(
+ state_args, # pylint:disable=unused-argument
+ state_locals,
+ state_results, # pylint:disable=unused-argument
+ test_name,
+ test, # pylint:disable=unused-argument
+ body_name,
+ body,
+ state_init):
+
+ def test_name(state_args): # pylint:disable=function-redefined,unused-argument
+ return test
+
+ def body_name(state_args): # pylint:disable=function-redefined,unused-argument
+ body # pylint:disable=pointless-statement
+ return state_locals
+
+ state_results = tf.while_loop(test_name, body_name, [state_init]) # pylint:disable=undefined-variable
+
+ test_name = self.namer.new_symbol('loop_test', body_scope.used)
+ body_name = self.namer.new_symbol('loop_body', body_scope.used)
+ node = templates.replace(
+ template,
+ state_args=self._tuple_or_item(
+ gast.Name(n, gast.Param(), None) for n in body_closure),
+ state_locals=self._ast_tuple_or_item(
+ (gast.Name(n, gast.Load(), None) for n in body_closure),
+ gast.Load()),
+ state_results=self._ast_tuple_or_item(
+ (gast.Name(n, gast.Store(), None) for n in body_closure),
+ gast.Store()),
+ test_name=gast.Name(test_name, gast.Load(), None),
+ test=node.test,
+ body_name=gast.Name(body_name, gast.Load(), None),
+ body=node.body,
+ state_init=[parent_scope_values.getval(n) for n in body_closure])
+
+ return node
+
+ # pylint:enable=invalid-name
+
+
+def transform(node, namer):
+ transformer = ControlFlowTransformer(namer)
+ node = transformer.visit(node)
+ return node
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for control_flow module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.convert import control_flow
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.platform import test
+
+
+class TestNamer(control_flow.SymbolNamer):
+
+ def new_symbol(self, name_root, _):
+ return name_root
+
+
+class ControlFlowTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn, namespace):
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
+ node = live_values.resolve(node, namespace, {})
+ node = type_info.resolve(node, None)
+ return node
+
+ def test_simple_while(self):
+
+ def test_fn(n):
+ i = 0
+ s = 0
+ while i < n:
+ s += i
+ i += 1
+ return s, i, n
+
+ node = self._parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, TestNamer())
+ result = compiler.ast_to_object(node)
+ setattr(result, 'tf', control_flow_ops)
+
+ with self.test_session() as sess:
+ self.assertEqual((10, 5, 5),
+ sess.run(result.test_fn(constant_op.constant(5))))
+
+ def test_while_single_var(self):
+
+ def test_fn(n):
+ while n > 0:
+ n -= 1
+ return n
+
+ node = self._parse_and_analyze(test_fn, {})
+ node = control_flow.transform(node, TestNamer())
+ result = compiler.ast_to_object(node)
+ setattr(result, 'tf', control_flow_ops)
+
+ with self.test_session() as sess:
+ self.assertEqual(0, sess.run(result.test_fn(constant_op.constant(5))))
+
+
+if __name__ == '__main__':
+ test.main()
--- /dev/null
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Allows converting Eager-style gradients to graph versions."""
+# TODO(mdan): This is not needed. Remove once the static analysis works.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import templates
+
+
+class GradientsFunctionTransformer(gast.NodeTransformer):
+ """Hack: transforms eager-style gradients to TF compatible calls.
+
+ Requires an expression of exactly this form:
+ ... = tfe.value_and_gradients_function(...)(...)
+ """
+
+ # pylint:disable=invalid-name
+
+ def visit_Assign(self, node):
+ self.generic_visit(node)
+
+ val = node.value
+ if isinstance(val, gast.Call):
+ if isinstance(val.func, gast.Call):
+ if isinstance(val.func.func, gast.Attribute):
+ if isinstance(val.func.func.value, gast.Name):
+ if (val.func.func.value.id == 'tfe' and
+ val.func.func.attr == 'value_and_gradients_function'):
+
+ # pylint:disable=unused-argument,undefined-variable
+
+ def template(loss_var, loss_fn, args, d_vars, wrt_vars):
+ loss_var = loss_fn(args)
+ d_vars = tf.gradients(loss_var, [wrt_vars])
+
+ # pylint:enable=unused-argument,undefined-variable
+
+ # How to get these values? Print out the node.
+ loss_var = gast.Name(node.targets[0].elts[0].id, gast.Store(),
+ None)
+ loss_fn = gast.Name(val.func.args[0].id, gast.Load(), None)
+ args = tuple(
+ gast.Name(a.id, gast.Param(), None) for a in val.args)
+ d_vars = node.targets[0].elts[1]
+ wrt_vars = [val.args[e.n] for e in val.func.args[1].elts]
+
+ node = templates.replace(
+ template,
+ loss_var=loss_var,
+ loss_fn=loss_fn,
+ args=args,
+ d_vars=d_vars,
+ wrt_vars=wrt_vars)
+
+ return node
+
+ # pylint:enable=invalid-name
+
+
+def transform(node):
+ transformer = GradientsFunctionTransformer()
+ node = transformer.visit(node)
+ return node
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for gradients_function module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.eager.python import tfe
+from tensorflow.contrib.py2tf.convert import gradients_function
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.platform import test
+
+
+class GradientsFunctionTest(test.TestCase):
+
+ def test_transform(self):
+
+ def loss(x, w):
+ return x * w
+
+ def test_fn(x, w):
+ l, (dw,) = tfe.value_and_gradients_function(loss, [1])(x, w) # pylint:disable=undefined-variable
+ return l, dw
+
+ node = parser.parse_object(test_fn)
+ node = gradients_function.transform(node)
+ result = compiler.ast_to_object(node)
+ setattr(result, 'tf', gradients_impl)
+ setattr(result, 'loss', loss)
+
+ with self.test_session() as sess:
+ self.assertEqual(
+ (12, 3),
+ sess.run(
+ result.test_fn(constant_op.constant(3), constant_op.constant(4))))
+
+
+if __name__ == '__main__':
+ test.main()
--- /dev/null
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converter for logical expressions.
+
+e.g. `a and b -> tf.logical_and(a, b)`. This is not done automatically in TF.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import parser
+
+
+class LogicalExpressionTransformer(gast.NodeTransformer):
+ """Converts logical expressions to corresponding TF calls."""
+
+ def __init__(self):
+ # TODO(mdan): Look into replacing with bitwise operators instead.
+ self.op_mapping = {
+ gast.And: 'tf.logical_and',
+ gast.Or: 'tf.logical_or',
+ }
+
+ def visit_UnaryOp(self, node):
+ raise NotImplementedError()
+
+ def visit_BoolOp(self, node):
+ # TODO(mdan): A normalizer may be useful here. Use ANF?
+ tf_function = parser.parse_str(self.op_mapping[type(node.op)]).body[0].value
+ left = node.values[0]
+ for i in range(1, len(node.values)):
+ left = gast.Call(
+ func=tf_function, args=[left, node.values[i]], keywords=[])
+ return left
+
+
+def transform(node):
+ transformer = LogicalExpressionTransformer()
+ node = transformer.visit(node)
+ return node
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for logical_expressions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.convert import logical_expressions
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class GradientsFunctionTest(test.TestCase):
+
+ def test_transform(self):
+
+ def test_fn(a, b, c):
+ return (a or b) and (a or b or c)
+
+ node = parser.parse_object(test_fn)
+ node = logical_expressions.transform(node)
+ result = compiler.ast_to_object(node)
+ setattr(result, 'tf', math_ops)
+
+ with self.test_session() as sess:
+ self.assertTrue(sess.run(result.test_fn(True, False, True)))
+
+
+if __name__ == '__main__':
+ test.main()
--- /dev/null
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Compatibility support. Converts Print nodes to function calls."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+
+
+class PrintFunctionTransformer(gast.NodeTransformer):
+ """Transforms Print nodes to Call so they can be handled as functions."""
+
+ # pylint:disable=invalid-name
+
+ def visit_Print(self, node):
+ self.generic_visit(node)
+ for n in node.values:
+ n.ctx = gast.Param()
+ call_node = gast.Call(
+ func=gast.Name('print', gast.Load(), None),
+ args=node.values,
+ keywords=[])
+ anno.setanno(call_node.func, 'live_val', print)
+ anno.setanno(call_node.func, 'fqn', 'print')
+ anno.setanno(call_node, 'args_scope', anno.getanno(node, 'args_scope'))
+ node = gast.Expr(call_node)
+ return node
+
+ # pylint:enable=invalid-name
+
+
+def transform(node):
+ transformer = PrintFunctionTransformer()
+ node = transformer.visit(node)
+ return node
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for print_functions module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.convert import print_functions
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.platform import test
+
+
+class PrintFunctionsTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn, namespace):
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
+ node = live_values.resolve(node, namespace, {})
+ node = type_info.resolve(node, None)
+ return node
+
+ def test_transform(self):
+
+ def test_fn(a):
+ print(a)
+
+ node = self._parse_and_analyze(test_fn, {'print': print})
+ node = print_functions.transform(node)
+ result = compiler.ast_to_object(node)
+
+ result.test_fn('a')
+ self.assertTrue(isinstance(node.body[0].body[0].value, gast.Call))
+
+
+if __name__ == '__main__':
+ test.main()
--- /dev/null
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Adds guards against function calls with side effects.
+
+Only standalone calls are guarded.
+
+WARNING: This mechanism is incomplete. Particularly, it only guards the
+arguments passed to functions, and does not account for indirectly modified
+state.
+
+Example:
+ y = tf.layers.dense(x) # Creates TF variable 'foo'
+ loss = loss(y)
+ opt.minimize(loss) # indirectly affects 'foo'
+ z = tf.get_variable('foo') # Indirectly affects `loss` and 'foo'
+ # Here, `loss` can be guarded. But `z` cannot.
+
+# TODO(mdan): We should probably define a safe mode where we guard everything.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gast
+
+from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import templates
+
+
+class SymbolNamer(object):
+ """Describes the interface for SideEffectGuardTransformer's namer."""
+
+ def new_symbol(self, name_root, reserved_locals):
+ """Generate a new unique function_name.
+
+ Args:
+ name_root: String, used as stem in the new name.
+ reserved_locals: Set(string), additional local symbols that are reserved.
+ Returns:
+ String.
+ """
+ raise NotImplementedError()
+
+
+class SideEffectGuardTransformer(gast.NodeTransformer):
+ """Adds control dependencies to functions with side effects."""
+
+ def __init__(self, namer):
+ self.namer = namer
+ self.indent_next = False
+ self.next_indent_owner = None
+
+ # pylint:disable=invalid-name
+
+ def _visit_and_reindent(self, nodes):
+ new_nodes = []
+ current_dest = new_nodes
+ for n in nodes:
+ n = self.visit(n)
+ if isinstance(n, (list, tuple)):
+ current_dest.extend(n)
+ else:
+ current_dest.append(n)
+ if self.indent_next:
+ assert self.next_indent_owner is not None
+ current_dest.append(self.next_indent_owner)
+ current_dest = self.next_indent_owner.body
+ self.next_indent_owner = None
+ self.indent_next = False
+ if not current_dest:
+ # TODO(mdan): There may still be something that could be done.
+ raise ValueError('Unable to insert statement into the computation flow: '
+ 'it is not followed by any computation that can we can '
+ 'condition on the statement.')
+ return new_nodes
+
+ def visit_FunctionDef(self, node):
+ if anno.hasanno(node, 'skip_processing'):
+ return node
+ node.body = self._visit_and_reindent(node.body)
+ return node
+
+ def _gate_symbols(self, guard_statement, guarded_args):
+
+ def template(dst_args, src_args): # pylint:disable=unused-argument
+ (dst_args,) = (tf.identity(a) for a in (src_args,)) # pylint:disable=undefined-variable
+
+ guards = templates.replace(
+ template,
+ dst_args=tuple(gast.Name(a, gast.Store(), None) for a in guarded_args),
+ src_args=tuple(gast.Name(a, gast.Load(), None) for a in guarded_args))
+ guard_statement.body.extend(guards)
+ return guard_statement
+
+ def visit_Expr(self, node):
+ self.generic_visit(node)
+ if isinstance(node.value, gast.Call):
+ # Patterns of single function calls, like:
+ # opt.minimize(loss)
+ # or:
+ # tf.py_func(...)
+
+ args_scope = anno.getanno(node.value, 'args_scope')
+ temp_name = self.namer.new_symbol('temp', args_scope.parent.used)
+ # TODO(mdan): Unsafe reference modification!
+ args_scope.mark_write(temp_name)
+
+ def template(call, temp_result):
+ temp_result = call
+ if not isinstance(temp_result, (list, tuple)):
+ temp_result = (temp_result,)
+ with tf.control_dependencies(temp_result): # pylint:disable=undefined-variable
+ # TODO(mdan): Also insert ops to re-fetch if variables are involved.
+ pass # Will be removed below.
+
+ guard_var_assign, arg_checker, control_deps_guard = templates.replace(
+ template,
+ call=node.value,
+ temp_result=gast.Name(temp_name, gast.Store(), None))
+ control_deps_guard.body = []
+
+ # First, attempt to gate future evaluation of args. If that's not
+ # possible, gate all remaining statements (and that may fail too, see
+ # _visit_and_reindent.
+ guarded_args = tuple(
+ n for n in args_scope.used if n in args_scope.parent.modified)
+ if guarded_args:
+ node = (guard_var_assign, arg_checker,
+ self._gate_symbols(control_deps_guard, guarded_args))
+ else:
+ node = (guard_var_assign, arg_checker)
+ # The mechanism will insert the guard statement later.
+ self.indent_next = True
+ self.next_indent_owner = control_deps_guard
+ return node
+
+ # pylint:enable=invalid-name
+
+
+def transform(node, namer):
+ transformer = SideEffectGuardTransformer(namer)
+ return transformer.visit(node)
--- /dev/null
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for side_effect_guards module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.py2tf.convert import side_effect_guards
+from tensorflow.contrib.py2tf.pyct import compiler
+from tensorflow.contrib.py2tf.pyct import parser
+from tensorflow.contrib.py2tf.pyct.static_analysis import access
+from tensorflow.contrib.py2tf.pyct.static_analysis import live_values
+from tensorflow.contrib.py2tf.pyct.static_analysis import type_info
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class TestNamer(side_effect_guards.SymbolNamer):
+
+ def new_symbol(self, name_root, _):
+ return name_root
+
+
+class SideEffectGuardsTest(test.TestCase):
+
+ def _parse_and_analyze(self, test_fn, namespace):
+ node = parser.parse_object(test_fn)
+ node = access.resolve(node)
+ node = live_values.resolve(node, namespace, {})
+ node = type_info.resolve(node, None)
+ return node
+
+ def test_transform(self):
+
+ def test_fn(a):
+ state_ops.assign(a, a + 1)
+ return a
+
+ node = self._parse_and_analyze(test_fn, {'state_ops': state_ops})
+ node = side_effect_guards.transform(node, TestNamer())
+ result = compiler.ast_to_object(node)
+ setattr(result, 'state_ops', state_ops)
+
+ # TODO(mdan): Configure the namespaces instead of doing these hacks.
+ ops.identity = array_ops.identity
+ setattr(result, 'tf', ops)
+
+ with self.test_session() as sess:
+ v = variables.Variable(2)
+ sess.run(v.initializer)
+ self.assertEqual(3, sess.run(result.test_fn(v)))
+
+
+if __name__ == '__main__':
+ test.main()
"//tensorflow/contrib/ndlstm:ndlstm",
"//tensorflow/contrib/nn:nn_py",
"//tensorflow/contrib/predictor:predictor_pip",
+ "//tensorflow/contrib/py2tf/convert:convert",
"//tensorflow/contrib/py2tf/pyct:pyct",
"//tensorflow/contrib/py2tf/pyct/static_analysis:static_analysis",
"//tensorflow/contrib/receptive_field:receptive_field_pip",