TFE: Implement __r*__ operators for `Dimension`.
authorAkshay Agrawal <akshayka@google.com>
Fri, 9 Mar 2018 02:25:29 +0000 (18:25 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 9 Mar 2018 03:05:33 +0000 (19:05 -0800)
This lets you use Dimension objects in numerical computations; e.g.,
it lets you evaluate expressions like 3 + my_tensor.shape[0] when executing
eagerly.

At time of writing, without this change,

`matplotlib.pyplot.plt(my_tensor, my_other_tensor)`

fails when executing eagerly, but it works with this change.

This change also makes it possible to right-multiply a dimension by a list
(e.g., dimension * [3]); previously, only the left-multiply worked ([3] *
dimension).

PiperOrigin-RevId: 188424557

tensorflow/python/framework/tensor_shape.py
tensorflow/python/framework/tensor_shape_test.py

index 6f2ab84..af2a5b1 100644 (file)
@@ -156,7 +156,7 @@ class Dimension(object):
     ```
 
     Args:
-      other: Another Dimension.
+      other: Another Dimension, or a value accepted by `as_dimension`.
 
     Returns:
       A Dimension whose value is the sum of `self` and `other`.
@@ -167,6 +167,17 @@ class Dimension(object):
     else:
       return Dimension(self._value + other.value)
 
+  def __radd__(self, other):
+    """Returns the sum of `other` and `self`.
+
+    Args:
+      other: Another Dimension, or a value accepted by `as_dimension`.
+
+    Returns:
+      A Dimension whose value is the sum of `self` and `other`.
+    """
+    return self + other
+
   def __sub__(self, other):
     """Returns the subtraction of `other` from `self`.
 
@@ -180,10 +191,10 @@ class Dimension(object):
     ```
 
     Args:
-      other: Another Dimension.
+      other: Another Dimension, or a value accepted by `as_dimension`.
 
     Returns:
-      A Dimension whose value is the subtraction of sum of `other` from `self`.
+      A Dimension whose value is the subtraction of `other` from `self`.
     """
     other = as_dimension(other)
     if self._value is None or other.value is None:
@@ -191,6 +202,21 @@ class Dimension(object):
     else:
       return Dimension(self._value - other.value)
 
+  def __rsub__(self, other):
+    """Returns the subtraction of `self` from `other`.
+
+    Args:
+      other: Another Dimension, or a value accepted by `as_dimension`.
+
+    Returns:
+      A Dimension whose value is the subtraction of `self` from `other`.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return Dimension(None)
+    else:
+      return Dimension(other.value - self._value)
+
   def __mul__(self, other):
     """Returns the product of `self` and `other`.
 
@@ -204,17 +230,32 @@ class Dimension(object):
     ```
 
     Args:
-      other: Another Dimension.
+      other: Another Dimension, or a value accepted by `as_dimension`.
 
     Returns:
       A Dimension whose value is the product of `self` and `other`.
     """
-    other = as_dimension(other)
+    try:
+      other = as_dimension(other)
+    except (TypeError, ValueError):
+      return NotImplemented
+
     if self._value is None or other.value is None:
       return Dimension(None)
     else:
       return Dimension(self._value * other.value)
 
+  def __rmul__(self, other):
+    """Returns the product of `self` and `other`.
+
+    Args:
+      other: Another Dimension, or a value accepted by `as_dimension`.
+
+    Returns:
+      A Dimension whose value is the product of `self` and `other`.
+    """
+    return self * other
+
   def __floordiv__(self, other):
     """Returns the quotient of `self` and `other` rounded down.
 
@@ -228,17 +269,35 @@ class Dimension(object):
     ```
 
     Args:
-      other: Another `Dimension`.
+      other: Another Dimension, or a value accepted by `as_dimension`.
 
     Returns:
       A `Dimension` whose value is the integer quotient of `self` and `other`.
     """
-    other = as_dimension(other)
+    try:
+      other = as_dimension(other)
+    except (TypeError, ValueError):
+      return NotImplemented
     if self._value is None or other.value is None:
       return Dimension(None)
     else:
       return Dimension(self._value // other.value)
 
+  def __rfloordiv__(self, other):
+    """Returns the quotient of `other` and `self` rounded down.
+
+    Args:
+      other: Another Dimension, or a value accepted by `as_dimension`.
+
+    Returns:
+      A `Dimension` whose value is the integer quotient of `self` and `other`.
+    """
+    other = as_dimension(other)
+    if self._value is None or other.value is None:
+      return Dimension(None)
+    else:
+      return Dimension(other.value // self._value)
+
   def __div__(self, other):
     """DEPRECATED: Use `__floordiv__` via `x // y` instead.
 
@@ -256,7 +315,7 @@ class Dimension(object):
     return self // other
 
   def __mod__(self, other):
-    """Returns `self` modulo `other.
+    """Returns `self` modulo `other`.
 
     Dimension moduli are computed as follows:
 
@@ -268,17 +327,35 @@ class Dimension(object):
     ```
 
     Args:
-      other: Another Dimension.
+      other: Another Dimension, or a value accepted by `as_dimension`.
 
     Returns:
       A Dimension whose value is `self` modulo `other`.
     """
-    other = as_dimension(other)
+    try:
+      other = as_dimension(other)
+    except (TypeError, ValueError):
+      return NotImplemented
     if self._value is None or other.value is None:
       return Dimension(None)
     else:
       return Dimension(self._value % other.value)
 
+  def __rmod__(self, other):
+    """Returns `other` modulo `self`.
+
+    Args:
+      other: Another Dimension, or a value accepted by `as_dimension`.
+
+    Returns:
+      A Dimension whose value is `other` modulo `self`.
+    """
+    try:
+      other = as_dimension(other)
+    except (TypeError, ValueError):
+      return NotImplemented
+    return other % self
+
   def __lt__(self, other):
     """Returns True if `self` is known to be less than `other`.
 
index fffd86c..4e8ce4d 100644 (file)
@@ -34,12 +34,20 @@ class DimensionTest(test_util.TensorFlowTestCase):
     self.assertEqual(tensor_shape.Dimension(15),
                      dim + tensor_shape.Dimension(3))
     self.assertEqual(tensor_shape.Dimension(15), dim + 3)
+    self.assertEqual(tensor_shape.Dimension(15), 3 + dim)
+    self.assertEqual(tensor_shape.Dimension(9), dim - 3)
+    self.assertEqual(tensor_shape.Dimension(1), 13 - dim)
     self.assertEqual(tensor_shape.Dimension(24),
                      dim * tensor_shape.Dimension(2))
     self.assertEqual(tensor_shape.Dimension(24), dim * 2)
+    self.assertEqual(tensor_shape.Dimension(24), 2 * dim)
+    self.assertEqual([4] * 12, [4] * dim)
+    self.assertEqual(12 * [4], dim * [4])
+    self.assertEqual(tensor_shape.Dimension(24), 2 * dim)
     self.assertEqual(
         tensor_shape.Dimension(6), dim // tensor_shape.Dimension(2))
     self.assertEqual(tensor_shape.Dimension(6), dim // 2)
+    self.assertEqual(tensor_shape.Dimension(0), 2 // dim)
     self.assertEqual(tensor_shape.Dimension(12),
                      dim.merge_with(tensor_shape.Dimension(12)))
     self.assertEqual(tensor_shape.Dimension(12), dim.merge_with(12))
@@ -176,6 +184,14 @@ class DimensionTest(test_util.TensorFlowTestCase):
     self.assertEqual(str(tensor_shape.Dimension(7)), "7")
     self.assertEqual(str(tensor_shape.Dimension(None)), "?")
 
+  def testMod(self):
+    four = tensor_shape.Dimension(4)
+    nine = tensor_shape.Dimension(9)
+    self.assertEqual(nine % four, 1)
+    # test both __mod__ and __rmod__.
+    self.assertEqual(nine % 4, 1)
+    self.assertEqual(4 % nine, 4)
+
 
 class ShapeTest(test_util.TensorFlowTestCase):