Add is_discrete, is_continuous, is_bounded methods to TensorSpecs.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Mar 2018 20:05:26 +0000 (13:05 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 20:11:36 +0000 (13:11 -0700)
PiperOrigin-RevId: 188766232

tensorflow/python/framework/tensor_spec.py
tensorflow/python/framework/tensor_spec_test.py

index 27a9ab8..546c48a 100644 (file)
@@ -65,6 +65,11 @@ class TensorSpec(object):
     else:
       raise ValueError("`tensor` should be a tf.Tensor")
 
+  @classmethod
+  def is_bounded(cls):
+    del cls
+    return False
+
   @property
   def shape(self):
     """Returns the `TensorShape` that represents the shape of the tensor."""
@@ -80,6 +85,16 @@ class TensorSpec(object):
     """Returns the name of the described tensor."""
     return self._name
 
+  @property
+  def is_discrete(self):
+    """Whether spec is discrete."""
+    return self.dtype.is_integer
+
+  @property
+  def is_continuous(self):
+    """Whether spec is continuous."""
+    return self.dtype.is_floating
+
   def is_compatible_with(self, spec_or_tensor):
     """True if the shape and dtype of `spec_or_tensor` are compatible."""
     return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
@@ -164,6 +179,11 @@ class BoundedTensorSpec(TensorSpec):
     self._maximum.setflags(write=False)
 
   @classmethod
+  def is_bounded(cls):
+    del cls
+    return True
+
+  @classmethod
   def from_spec(cls, spec):
     dtype = dtypes.as_dtype(spec.dtype)
     minimum = getattr(spec, "minimum", dtype.min)
index 54ca4d9..b33d769 100644 (file)
@@ -127,6 +127,22 @@ class TensorSpecTest(test_util.TensorFlowTestCase):
     self.assertEqual(bounded_spec.dtype, spec.dtype)
     self.assertEqual(bounded_spec.name, spec.name)
 
+  def testIsDiscrete(self):
+    discrete_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+    continuous_spec = tensor_spec.TensorSpec((1, 2), dtypes.float32)
+    self.assertTrue(discrete_spec.is_discrete)
+    self.assertFalse(continuous_spec.is_discrete)
+
+  def testIsContinuous(self):
+    discrete_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+    continuous_spec = tensor_spec.TensorSpec((1, 2), dtypes.float32)
+    self.assertFalse(discrete_spec.is_continuous)
+    self.assertTrue(continuous_spec.is_continuous)
+
+  def testIsBounded(self):
+    unbounded_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+    self.assertFalse(unbounded_spec.is_bounded())
+
 
 class BoundedTensorSpecTest(test_util.TensorFlowTestCase):
 
@@ -138,6 +154,11 @@ class BoundedTensorSpecTest(test_util.TensorFlowTestCase):
     with self.assertRaisesRegexp(ValueError, "not compatible"):
       tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, 0, (1, 1, 1))
 
+  def testIsBounded(self):
+    bounded_spec = tensor_spec.BoundedTensorSpec(
+        (1, 2), dtypes.int32, minimum=0, maximum=1)
+    self.assertTrue(bounded_spec.is_bounded())
+
   def testMinimumMaximumAttributes(self):
     spec = tensor_spec.BoundedTensorSpec(
         (1, 2, 3), dtypes.float32, 0, (5, 5, 5))