From 12baea6c9a2ccb15f24ca79f18bcdd639b149592 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 19 Mar 2018 15:09:23 -0700 Subject: [PATCH] Use fully-qualified function names and avoid the need to replace attributes. PiperOrigin-RevId: 189648496 --- .../py2tf/converters/logical_expressions.py | 52 ++++++++++------------ 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/tensorflow/contrib/py2tf/converters/logical_expressions.py b/tensorflow/contrib/py2tf/converters/logical_expressions.py index 10192e6..e0abf74 100644 --- a/tensorflow/contrib/py2tf/converters/logical_expressions.py +++ b/tensorflow/contrib/py2tf/converters/logical_expressions.py @@ -24,6 +24,7 @@ from __future__ import print_function import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import parser from tensorflow.contrib.py2tf.pyct import templates from tensorflow.contrib.py2tf.pyct import transformer @@ -44,17 +45,18 @@ class LogicalExpressionTransformer(transformer.Base): def __init__(self, context): super(LogicalExpressionTransformer, self).__init__(context) # TODO(mdan): Look into replacing with bitwise operators instead. + # TODO(mdan): Skip replacing if the function is trivial. self.op_mapping = { - gast.And: 'logical_and', - gast.Eq: 'equal', - gast.Gt: 'greater', - gast.GtE: 'greater_equal', - gast.Lt: 'less', - gast.LtE: 'less_equal', - gast.Not: 'logical_not', - gast.NotEq: 'not_equal', - gast.Or: 'logical_or', - gast.USub: 'negative', + gast.And: 'tf.logical_and', + gast.Eq: 'tf.equal', + gast.Gt: 'tf.greater', + gast.GtE: 'tf.greater_equal', + gast.Lt: 'tf.less', + gast.LtE: 'tf.less_equal', + gast.Not: 'tf.logical_not', + gast.NotEq: 'tf.not_equal', + gast.Or: 'tf.logical_or', + gast.USub: 'tf.negative', gast.Is: 'py2tf_utils.dynamic_is', gast.IsNot: 'py2tf_utils.dynamic_is_not' } @@ -70,27 +72,19 @@ class LogicalExpressionTransformer(transformer.Base): '"a.x or b"; for a workaround, assign the expression to a local ' 'variable and use that instead, for example "tmp = a.x", "tmp or b"') - def _matching_tf_op(self, operator): + def _matching_func(self, operator): op_type = type(operator) mapped_op = self.op_mapping.get(op_type) if not mapped_op: raise NotImplementedError('operator %s is not yet supported' % op_type) return mapped_op - def _inline_tf_op(self, op_name, args): - if 'py2tf_utils' in op_name: - # TODO(alexbw): explicitly spelling out the attribute function name - # until fix for issue highlighted in cl/188931581 lands. - template = """ - py2tf_utils.op_name(args) + def _as_function(self, func_name, args): + template = """ + func_name(args) """ - op_name = op_name.replace('py2tf_utils.', '') - else: - template = """ - tf.op_name(args) - """ replacement = templates.replace_as_expression( - template, op_name=op_name, args=args) + template, func_name=parser.parse_expression(func_name), args=args) anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True) return replacement @@ -104,14 +98,14 @@ class LogicalExpressionTransformer(transformer.Base): # a < b < c -> a < b and b < c while ops_and_comps: op, right = ops_and_comps.pop(0) - binary_comparison = self._inline_tf_op(self._matching_tf_op(op), - (left, right)) + binary_comparison = self._as_function( + self._matching_func(op), (left, right)) if isinstance(left, gast.Name) and isinstance(right, gast.Name): anno.setanno(binary_comparison, SAFE_BOOLEAN_OPERAND, True) if op_tree: self._expect_simple_symbol(right) - op_tree = self._inline_tf_op('logical_and', - (binary_comparison, op_tree)) + op_tree = self._as_function('tf.logical_and', + (binary_comparison, op_tree)) else: op_tree = binary_comparison left = right @@ -120,7 +114,7 @@ class LogicalExpressionTransformer(transformer.Base): def visit_UnaryOp(self, node): node = self.generic_visit(node) - return self._inline_tf_op(self._matching_tf_op(node.op), node.operand) + return self._as_function(self._matching_func(node.op), node.operand) def visit_BoolOp(self, node): node = self.generic_visit(node) @@ -130,7 +124,7 @@ class LogicalExpressionTransformer(transformer.Base): while node_values: left = node_values.pop() self._expect_simple_symbol(left) - right = self._inline_tf_op(self._matching_tf_op(node.op), (left, right)) + right = self._as_function(self._matching_func(node.op), (left, right)) return right -- 2.7.4