from tensorflow.contrib.py2tf.utils.type_check import is_tensor
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import math_ops
def dynamic_is(left, right):
- if is_tensor(left, right):
- return math_ops.equal(left.name, right.name)
- else:
- return left is right
+ # TODO(alexbw) if we're sure we should leave 'is' in place,
+ # then change the semantics in converters/logical_expressions.py
+ return left is right
def dynamic_is_not(left, right):
- if is_tensor(left, right):
- return math_ops.not_equal(left.name, right.name)
- else:
- return left is not right
+ return left is not right
def run_cond(condition, true_fn, false_fn):
should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
- self.assertTrue(should_be_true1.eval())
- self.assertTrue(should_be_true2.eval())
- self.assertFalse(should_be_false1.eval())
- self.assertFalse(should_be_false2.eval())
+ self.assertTrue(should_be_true1)
+ self.assertTrue(should_be_true2)
+ self.assertFalse(should_be_false1)
+ self.assertFalse(should_be_false2)
def test_run_cond_python(self):
true_fn = lambda: 2.0