Remove the gradients function converter now that we can use the tape method.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 10 Jan 2018 21:09:04 +0000 (13:09 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 10 Jan 2018 21:18:58 +0000 (13:18 -0800)
PiperOrigin-RevId: 181506626

tensorflow/contrib/py2tf/conversion.py
tensorflow/contrib/py2tf/convert/BUILD
tensorflow/contrib/py2tf/convert/gradients_function.py [deleted file]
tensorflow/contrib/py2tf/convert/gradients_function_test.py [deleted file]

index 08acd6ca924a5aa2f09fb3492038a155bf7c81e4..12dd70e497cf797e7eac85f2259a0ea6e01027ad 100644 (file)
@@ -24,7 +24,6 @@ from tensorflow.contrib.py2tf import config
 from tensorflow.contrib.py2tf import naming
 from tensorflow.contrib.py2tf.convert import call_trees
 from tensorflow.contrib.py2tf.convert import control_flow
-from tensorflow.contrib.py2tf.convert import gradients_function
 from tensorflow.contrib.py2tf.convert import logical_expressions
 from tensorflow.contrib.py2tf.convert import print_functions
 from tensorflow.contrib.py2tf.convert import side_effect_guards
@@ -143,9 +142,6 @@ def node_to_graph(node, namer, namespace, value_hints):
         * deps: A set of strings, the fully qualified names of object
             dependencies that this node has.
   """
-  # TODO(mdan): Get rid of this.
-  node = gradients_function.transform(node)
-
   node = access.resolve(node)
   node = live_values.resolve(node, namespace, config.PYTHON_LITERALS)
   node = type_info.resolve(node, value_hints)
index 84a75ff7e158088c274d975fb638c51d86528bf8..ddbf3369470c6dcf6515a9e99e5bb0f56098eba6 100644 (file)
@@ -19,7 +19,6 @@ py_library(
     srcs = [
         "call_trees.py",
         "control_flow.py",
-        "gradients_function.py",
         "logical_expressions.py",
         "print_functions.py",
         "side_effect_guards.py",
@@ -53,18 +52,6 @@ py_test(
     ],
 )
 
-py_test(
-    name = "gradients_function_test",
-    srcs = ["gradients_function_test.py"],
-    deps = [
-        ":convert",
-        "//tensorflow/contrib/eager/python:tfe",
-        "//tensorflow/contrib/py2tf/pyct",
-        "//tensorflow/contrib/py2tf/pyct/static_analysis",
-        "//tensorflow/python:client_testlib",
-    ],
-)
-
 py_test(
     name = "logical_expressions_test",
     srcs = ["logical_expressions_test.py"],
diff --git a/tensorflow/contrib/py2tf/convert/gradients_function.py b/tensorflow/contrib/py2tf/convert/gradients_function.py
deleted file mode 100644 (file)
index f3c07db..0000000
+++ /dev/null
@@ -1,80 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Allows converting Eager-style gradients to graph versions."""
-# TODO(mdan): This is not needed. Remove once the static analysis works.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import gast
-
-from tensorflow.contrib.py2tf.pyct import templates
-
-
-class GradientsFunctionTransformer(gast.NodeTransformer):
-  """Hack: transforms eager-style gradients to TF compatible calls.
-
-  Requires an expression of exactly this form:
-      ... = tfe.value_and_gradients_function(...)(...)
-  """
-
-  # pylint:disable=invalid-name
-
-  def visit_Assign(self, node):
-    self.generic_visit(node)
-
-    val = node.value
-    if isinstance(val, gast.Call):
-      if isinstance(val.func, gast.Call):
-        if isinstance(val.func.func, gast.Attribute):
-          if isinstance(val.func.func.value, gast.Name):
-            if (val.func.func.value.id == 'tfe' and
-                val.func.func.attr == 'value_and_gradients_function'):
-
-              # pylint:disable=unused-argument,undefined-variable
-
-              def template(loss_var, loss_fn, args, d_vars, wrt_vars):
-                loss_var = loss_fn(args)
-                d_vars = tf.gradients(loss_var, [wrt_vars])
-
-              # pylint:enable=unused-argument,undefined-variable
-
-              # How to get these values? Print out the node.
-              loss_var = gast.Name(node.targets[0].elts[0].id, gast.Store(),
-                                   None)
-              loss_fn = gast.Name(val.func.args[0].id, gast.Load(), None)
-              args = tuple(
-                  gast.Name(a.id, gast.Param(), None) for a in val.args)
-              d_vars = node.targets[0].elts[1]
-              wrt_vars = [val.args[e.n] for e in val.func.args[1].elts]
-
-              node = templates.replace(
-                  template,
-                  loss_var=loss_var,
-                  loss_fn=loss_fn,
-                  args=args,
-                  d_vars=d_vars,
-                  wrt_vars=wrt_vars)
-
-    return node
-
-  # pylint:enable=invalid-name
-
-
-def transform(node):
-  transformer = GradientsFunctionTransformer()
-  node = transformer.visit(node)
-  return node
diff --git a/tensorflow/contrib/py2tf/convert/gradients_function_test.py b/tensorflow/contrib/py2tf/convert/gradients_function_test.py
deleted file mode 100644 (file)
index 7ef22f7..0000000
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for gradients_function module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.eager.python import tfe
-from tensorflow.contrib.py2tf.convert import gradients_function
-from tensorflow.contrib.py2tf.pyct import compiler
-from tensorflow.contrib.py2tf.pyct import parser
-from tensorflow.python.framework import constant_op
-from tensorflow.python.ops import gradients_impl
-from tensorflow.python.platform import test
-
-
-class GradientsFunctionTest(test.TestCase):
-
-  def test_transform(self):
-
-    def loss(x, w):
-      return x * w
-
-    def test_fn(x, w):
-      l, (dw,) = tfe.value_and_gradients_function(loss, [1])(x, w)  # pylint:disable=undefined-variable
-      return l, dw
-
-    node = parser.parse_object(test_fn)
-    node = gradients_function.transform(node)
-    result = compiler.ast_to_object(node)
-    setattr(result, 'tf', gradients_impl)
-    setattr(result, 'loss', loss)
-
-    with self.test_session() as sess:
-      self.assertEqual(
-          (12, 3),
-          sess.run(
-              result.test_fn(constant_op.constant(3), constant_op.constant(4))))
-
-
-if __name__ == '__main__':
-  test.main()