gast.NotEq: 'not_equal',
gast.Or: 'logical_or',
gast.USub: 'negative',
+ gast.Is: 'py2tf_utils.dynamic_is',
+ gast.IsNot: 'py2tf_utils.dynamic_is_not'
}
def _expect_simple_symbol(self, operand):
return mapped_op
def _inline_tf_op(self, op_name, args):
- template = """
- tf.op_name(args)
+ if 'py2tf_utils' in op_name:
+ # TODO(alexbw): explicitly spelling out the attribute function name
+ # until fix for issue highlighted in cl/188931581 lands.
+ template = """
+ py2tf_utils.op_name(args)
"""
- replacement = templates.replace(template, op_name=op_name, args=args)
- # It's a body with a single expression, we want its value.
- n = replacement[0].value
- anno.setanno(n, SAFE_BOOLEAN_OPERAND, True)
- return n
+ op_name = op_name.replace('py2tf_utils.', '')
+ else:
+ template = """
+ tf.op_name(args)
+ """
+ replacement = templates.replace_as_expression(
+ template, op_name=op_name, args=args)
+ anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
+ return replacement
def visit_Compare(self, node):
node = self.generic_visit(node)
from tensorflow.contrib.py2tf.utils.builtins import dynamic_range
from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns
from tensorflow.contrib.py2tf.utils.misc import alias_tensors
+from tensorflow.contrib.py2tf.utils.multiple_dispatch import dynamic_is
+from tensorflow.contrib.py2tf.utils.multiple_dispatch import dynamic_is_not
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond
from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while
from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func
from tensorflow.contrib.py2tf.utils.type_check import is_tensor
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+def dynamic_is(left, right):
+ if is_tensor(left, right):
+ return math_ops.equal(left.name, right.name)
+ else:
+ return left is right
+
+
+def dynamic_is_not(left, right):
+ if is_tensor(left, right):
+ return math_ops.not_equal(left.name, right.name)
+ else:
+ return left is not right
def run_cond(condition, true_fn, false_fn):
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+
+import numpy as np
+
from tensorflow.contrib.py2tf.utils import multiple_dispatch
from tensorflow.python.client.session import Session
from tensorflow.python.framework.constant_op import constant
class MultipleDispatchTest(test.TestCase):
+ def test_dynamic_is_python(self):
+ a = np.eye(3)
+ also_a = a
+ not_actually_a = np.eye(3)
+ should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
+ should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
+ should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
+ should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
+ self.assertTrue(should_be_true1)
+ self.assertTrue(should_be_true2)
+ self.assertFalse(should_be_false1)
+ self.assertFalse(should_be_false2)
+
+ def test_dynamic_is_tf(self):
+ with Session().as_default():
+ a = constant([2.0])
+ also_a = a
+ not_actually_a = constant([2.0])
+ should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
+ should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
+ should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
+ should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
+ self.assertTrue(should_be_true1.eval())
+ self.assertTrue(should_be_true2.eval())
+ self.assertFalse(should_be_false1.eval())
+ self.assertFalse(should_be_false2.eval())
+
def test_run_cond_python(self):
true_fn = lambda: 2.0
false_fn = lambda: 3.0