From 7feb32b92448f722aa089f599f75c59c82b901ba Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 12 Mar 2018 13:05:26 -0700 Subject: [PATCH] Add is_discrete, is_continuous, is_bounded methods to TensorSpecs. PiperOrigin-RevId: 188766232 --- tensorflow/python/framework/tensor_spec.py | 20 ++++++++++++++++++++ tensorflow/python/framework/tensor_spec_test.py | 21 +++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py index 27a9ab8..546c48a 100644 --- a/tensorflow/python/framework/tensor_spec.py +++ b/tensorflow/python/framework/tensor_spec.py @@ -65,6 +65,11 @@ class TensorSpec(object): else: raise ValueError("`tensor` should be a tf.Tensor") + @classmethod + def is_bounded(cls): + del cls + return False + @property def shape(self): """Returns the `TensorShape` that represents the shape of the tensor.""" @@ -80,6 +85,16 @@ class TensorSpec(object): """Returns the name of the described tensor.""" return self._name + @property + def is_discrete(self): + """Whether spec is discrete.""" + return self.dtype.is_integer + + @property + def is_continuous(self): + """Whether spec is continuous.""" + return self.dtype.is_floating + def is_compatible_with(self, spec_or_tensor): """True if the shape and dtype of `spec_or_tensor` are compatible.""" return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and @@ -164,6 +179,11 @@ class BoundedTensorSpec(TensorSpec): self._maximum.setflags(write=False) @classmethod + def is_bounded(cls): + del cls + return True + + @classmethod def from_spec(cls, spec): dtype = dtypes.as_dtype(spec.dtype) minimum = getattr(spec, "minimum", dtype.min) diff --git a/tensorflow/python/framework/tensor_spec_test.py b/tensorflow/python/framework/tensor_spec_test.py index 54ca4d9..b33d769 100644 --- a/tensorflow/python/framework/tensor_spec_test.py +++ b/tensorflow/python/framework/tensor_spec_test.py @@ -127,6 +127,22 @@ class TensorSpecTest(test_util.TensorFlowTestCase): self.assertEqual(bounded_spec.dtype, spec.dtype) self.assertEqual(bounded_spec.name, spec.name) + def testIsDiscrete(self): + discrete_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32) + continuous_spec = tensor_spec.TensorSpec((1, 2), dtypes.float32) + self.assertTrue(discrete_spec.is_discrete) + self.assertFalse(continuous_spec.is_discrete) + + def testIsContinuous(self): + discrete_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32) + continuous_spec = tensor_spec.TensorSpec((1, 2), dtypes.float32) + self.assertFalse(discrete_spec.is_continuous) + self.assertTrue(continuous_spec.is_continuous) + + def testIsBounded(self): + unbounded_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32) + self.assertFalse(unbounded_spec.is_bounded()) + class BoundedTensorSpecTest(test_util.TensorFlowTestCase): @@ -138,6 +154,11 @@ class BoundedTensorSpecTest(test_util.TensorFlowTestCase): with self.assertRaisesRegexp(ValueError, "not compatible"): tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, 0, (1, 1, 1)) + def testIsBounded(self): + bounded_spec = tensor_spec.BoundedTensorSpec( + (1, 2), dtypes.int32, minimum=0, maximum=1) + self.assertTrue(bounded_spec.is_bounded()) + def testMinimumMaximumAttributes(self): spec = tensor_spec.BoundedTensorSpec( (1, 2, 3), dtypes.float32, 0, (5, 5, 5)) -- 2.7.4