Support for the Is and IsNot comparator nodes
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 14 Mar 2018 14:01:56 +0000 (07:01 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Mar 2018 14:06:20 +0000 (07:06 -0700)
PiperOrigin-RevId: 189022386

tensorflow/contrib/py2tf/converters/logical_expressions.py
tensorflow/contrib/py2tf/utils/__init__.py
tensorflow/contrib/py2tf/utils/multiple_dispatch.py
tensorflow/contrib/py2tf/utils/multiple_dispatch_test.py

index 766aa11..10192e6 100644 (file)
@@ -55,6 +55,8 @@ class LogicalExpressionTransformer(transformer.Base):
         gast.NotEq: 'not_equal',
         gast.Or: 'logical_or',
         gast.USub: 'negative',
+        gast.Is: 'py2tf_utils.dynamic_is',
+        gast.IsNot: 'py2tf_utils.dynamic_is_not'
     }
 
   def _expect_simple_symbol(self, operand):
@@ -76,14 +78,21 @@ class LogicalExpressionTransformer(transformer.Base):
     return mapped_op
 
   def _inline_tf_op(self, op_name, args):
-    template = """
-      tf.op_name(args)
+    if 'py2tf_utils' in op_name:
+      # TODO(alexbw): explicitly spelling out the attribute function name
+      # until fix for issue highlighted in cl/188931581 lands.
+      template = """
+      py2tf_utils.op_name(args)
     """
-    replacement = templates.replace(template, op_name=op_name, args=args)
-    # It's a body with a single expression, we want its value.
-    n = replacement[0].value
-    anno.setanno(n, SAFE_BOOLEAN_OPERAND, True)
-    return n
+      op_name = op_name.replace('py2tf_utils.', '')
+    else:
+      template = """
+        tf.op_name(args)
+      """
+    replacement = templates.replace_as_expression(
+        template, op_name=op_name, args=args)
+    anno.setanno(replacement, SAFE_BOOLEAN_OPERAND, True)
+    return replacement
 
   def visit_Compare(self, node):
     node = self.generic_visit(node)
index 4fc0121..d9d8e34 100644 (file)
@@ -23,6 +23,8 @@ from tensorflow.contrib.py2tf.utils.builtins import dynamic_print
 from tensorflow.contrib.py2tf.utils.builtins import dynamic_range
 from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns
 from tensorflow.contrib.py2tf.utils.misc import alias_tensors
+from tensorflow.contrib.py2tf.utils.multiple_dispatch import dynamic_is
+from tensorflow.contrib.py2tf.utils.multiple_dispatch import dynamic_is_not
 from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond
 from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while
 from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func
index a855fdc..da7a942 100644 (file)
@@ -22,6 +22,21 @@ import six
 
 from tensorflow.contrib.py2tf.utils.type_check import is_tensor
 from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+def dynamic_is(left, right):
+  if is_tensor(left, right):
+    return math_ops.equal(left.name, right.name)
+  else:
+    return left is right
+
+
+def dynamic_is_not(left, right):
+  if is_tensor(left, right):
+    return math_ops.not_equal(left.name, right.name)
+  else:
+    return left is not right
 
 
 def run_cond(condition, true_fn, false_fn):
index 5bb4d40..8d89b68 100644 (file)
@@ -17,6 +17,9 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+
+import numpy as np
+
 from tensorflow.contrib.py2tf.utils import multiple_dispatch
 from tensorflow.python.client.session import Session
 from tensorflow.python.framework.constant_op import constant
@@ -25,6 +28,33 @@ from tensorflow.python.platform import test
 
 class MultipleDispatchTest(test.TestCase):
 
+  def test_dynamic_is_python(self):
+    a = np.eye(3)
+    also_a = a
+    not_actually_a = np.eye(3)
+    should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
+    should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
+    should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
+    should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
+    self.assertTrue(should_be_true1)
+    self.assertTrue(should_be_true2)
+    self.assertFalse(should_be_false1)
+    self.assertFalse(should_be_false2)
+
+  def test_dynamic_is_tf(self):
+    with Session().as_default():
+      a = constant([2.0])
+      also_a = a
+      not_actually_a = constant([2.0])
+      should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
+      should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
+      should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
+      should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
+      self.assertTrue(should_be_true1.eval())
+      self.assertTrue(should_be_true2.eval())
+      self.assertFalse(should_be_false1.eval())
+      self.assertFalse(should_be_false2.eval())
+
   def test_run_cond_python(self):
     true_fn = lambda: 2.0
     false_fn = lambda: 3.0