Use fully-qualified function names and avoid the need to replace attributes.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 19 Mar 2018 22:09:23 +0000 (15:09 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 19 Mar 2018 22:15:35 +0000 (15:15 -0700)
PiperOrigin-RevId: 189648496

tensorflow/contrib/py2tf/converters/logical_expressions.py

index 10192e6..e0abf74 100644 (file)
@@ -24,6 +24,7 @@ from __future__ import print_function
 import gast
 
 from tensorflow.contrib.py2tf.pyct import anno
+from tensorflow.contrib.py2tf.pyct import parser
 from tensorflow.contrib.py2tf.pyct import templates
 from tensorflow.contrib.py2tf.pyct import transformer
 
@@ -44,17 +45,18 @@ class LogicalExpressionTransformer(transformer.Base):
   def __init__(self, context):
     super(LogicalExpressionTransformer, self).__init__(context)
     # TODO(mdan): Look into replacing with bitwise operators instead.
+    # TODO(mdan): Skip replacing if the function is trivial.
     self.op_mapping = {
-        gast.And: 'logical_and',
-        gast.Eq: 'equal',
-        gast.Gt: 'greater',
-        gast.GtE: 'greater_equal',
-        gast.Lt: 'less',
-        gast.LtE: 'less_equal',
-        gast.Not: 'logical_not',
-        gast.NotEq: 'not_equal',
-        gast.Or: 'logical_or',
-        gast.USub: 'negative',
+        gast.And: 'tf.logical_and',
+        gast.Eq: 'tf.equal',
+        gast.Gt: 'tf.greater',
+        gast.GtE: 'tf.greater_equal',
+        gast.Lt: 'tf.less',
+        gast.LtE: 'tf.less_equal',
+        gast.Not: 'tf.logical_not',
+        gast.NotEq: 'tf.not_equal',
+        gast.Or: 'tf.logical_or',
+        gast.USub: 'tf.negative',
         gast.Is: 'py2tf_utils.dynamic_is',
         gast.IsNot: 'py2tf_utils.dynamic_is_not'
     }
@@ -70,27 +72,19 @@ class LogicalExpressionTransformer(transformer.Base):
         '"a.x or b"; for a workaround, assign the expression to a local '
         'variable and use that instead, for example "tmp = a.x", "tmp or b"')
 
-  def _matching_tf_op(self, operator):
+  def _matching_func(self, operator):
     op_type = type(operator)
     mapped_op = self.op_mapping.get(op_type)
     if not mapped_op:
       raise NotImplementedError('operator %s is not yet supported' % op_type)
     return mapped_op
 
-  def _inline_tf_op(self, op_name, args):
-    if 'py2tf_utils' in op_name:
-      # TODO(alexbw): explicitly spelling out the attribute function name
-      # until fix for issue highlighted in cl/188931581 lands.
-      template = """
-      py2tf_utils.op_name(args)
+  def _as_function(self, func_name, args):
+    template = """
+      func_name(args)
     """
-      op_name = op_name.replace('py2tf_utils.', '')
-    else:
-      template = """
-        tf.op_name(args)
-      """
     replacement = templates.replace_as_expression(
-        template, op_name=op_name, args=args)
+        template, func_name=parser.parse_expression(func_name), args=args)
     anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
     return replacement
 
@@ -104,14 +98,14 @@ class LogicalExpressionTransformer(transformer.Base):
     #   a < b < c   ->   a < b and b < c
     while ops_and_comps:
       op, right = ops_and_comps.pop(0)
-      binary_comparison = self._inline_tf_op(self._matching_tf_op(op),
-                                             (left, right))
+      binary_comparison = self._as_function(
+          self._matching_func(op), (left, right))
       if isinstance(left, gast.Name) and isinstance(right, gast.Name):
         anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True)
       if op_tree:
         self._expect_simple_symbol(right)
-        op_tree = self._inline_tf_op('logical_and',
-                                     (binary_comparison, op_tree))
+        op_tree = self._as_function('tf.logical_and',
+                                    (binary_comparison, op_tree))
       else:
         op_tree = binary_comparison
       left = right
@@ -120,7 +114,7 @@ class LogicalExpressionTransformer(transformer.Base):
 
   def visit_UnaryOp(self, node):
     node = self.generic_visit(node)
-    return self._inline_tf_op(self._matching_tf_op(node.op), node.operand)
+    return self._as_function(self._matching_func(node.op), node.operand)
 
   def visit_BoolOp(self, node):
     node = self.generic_visit(node)
@@ -130,7 +124,7 @@ class LogicalExpressionTransformer(transformer.Base):
     while node_values:
       left = node_values.pop()
       self._expect_simple_symbol(left)
-      right = self._inline_tf_op(self._matching_tf_op(node.op), (left, right))
+      right = self._as_function(self._matching_func(node.op), (left, right))
     return right