else:
raise ValueError("`tensor` should be a tf.Tensor")
+ @classmethod
+ def is_bounded(cls):
+ del cls
+ return False
+
@property
def shape(self):
"""Returns the `TensorShape` that represents the shape of the tensor."""
"""Returns the name of the described tensor."""
return self._name
+ @property
+ def is_discrete(self):
+ """Whether spec is discrete."""
+ return self.dtype.is_integer
+
+ @property
+ def is_continuous(self):
+ """Whether spec is continuous."""
+ return self.dtype.is_floating
+
def is_compatible_with(self, spec_or_tensor):
"""True if the shape and dtype of `spec_or_tensor` are compatible."""
return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
self._maximum.setflags(write=False)
@classmethod
+ def is_bounded(cls):
+ del cls
+ return True
+
+ @classmethod
def from_spec(cls, spec):
dtype = dtypes.as_dtype(spec.dtype)
minimum = getattr(spec, "minimum", dtype.min)
self.assertEqual(bounded_spec.dtype, spec.dtype)
self.assertEqual(bounded_spec.name, spec.name)
+ def testIsDiscrete(self):
+ discrete_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+ continuous_spec = tensor_spec.TensorSpec((1, 2), dtypes.float32)
+ self.assertTrue(discrete_spec.is_discrete)
+ self.assertFalse(continuous_spec.is_discrete)
+
+ def testIsContinuous(self):
+ discrete_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+ continuous_spec = tensor_spec.TensorSpec((1, 2), dtypes.float32)
+ self.assertFalse(discrete_spec.is_continuous)
+ self.assertTrue(continuous_spec.is_continuous)
+
+ def testIsBounded(self):
+ unbounded_spec = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+ self.assertFalse(unbounded_spec.is_bounded())
+
class BoundedTensorSpecTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, "not compatible"):
tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, 0, (1, 1, 1))
+ def testIsBounded(self):
+ bounded_spec = tensor_spec.BoundedTensorSpec(
+ (1, 2), dtypes.int32, minimum=0, maximum=1)
+ self.assertTrue(bounded_spec.is_bounded())
+
def testMinimumMaximumAttributes(self):
spec = tensor_spec.BoundedTensorSpec(
(1, 2, 3), dtypes.float32, 0, (5, 5, 5))