Adding TensorSpec to represent the specification of Tensors.
authorSergio Guadarrama <sguada@google.com>
Mon, 5 Feb 2018 23:16:29 +0000 (15:16 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 5 Feb 2018 23:23:19 +0000 (15:23 -0800)
PiperOrigin-RevId: 184594856

tensorflow/contrib/framework/__init__.py
tensorflow/python/BUILD
tensorflow/python/framework/tensor_spec.py [new file with mode: 0644]
tensorflow/python/framework/tensor_spec_test.py [new file with mode: 0644]

index 503b868..fb101c3 100644 (file)
@@ -86,6 +86,9 @@ See the @{$python/contrib.framework} guide.
 @@sort
 
 @@CriticalSection
+
+@@BoundedTensorSpec
+@@TensorSpec
 """
 
 from __future__ import absolute_import
@@ -100,6 +103,9 @@ from tensorflow.contrib.framework.python.ops import *
 from tensorflow.python.framework.ops import prepend_name_scope
 from tensorflow.python.framework.ops import strip_name_scope
 
+from tensorflow.python.framework.tensor_spec import BoundedTensorSpec
+from tensorflow.python.framework.tensor_spec import TensorSpec
+
 from tensorflow.python.util.all_util import remove_undocumented
 
 _allowed_symbols = ['nest']
index a89142d..2b4d5b8 100644 (file)
@@ -576,6 +576,7 @@ py_library(
         ":pywrap_tensorflow",
         ":random_seed",
         ":sparse_tensor",
+        ":tensor_spec",
         ":tensor_util",
         ":util",
         "//tensorflow/python/eager:context",
@@ -781,6 +782,18 @@ py_library(
 )
 
 py_library(
+    name = "tensor_spec",
+    srcs = ["framework/tensor_spec.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":common_shapes",
+        ":dtypes",
+        ":tensor_shape",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_library(
     name = "tensor_util",
     srcs = ["framework/tensor_util.py"],
     srcs_version = "PY2AND3",
@@ -1149,6 +1162,21 @@ py_test(
 )
 
 py_test(
+    name = "framework_tensor_spec_test",
+    size = "small",
+    srcs = ["framework/tensor_spec_test.py"],
+    main = "framework/tensor_spec_test.py",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":framework_for_generated_wrappers",
+        ":framework_test_lib",
+        ":platform_test",
+        ":tensor_spec",
+        "//third_party/py/numpy",
+    ],
+)
+
+py_test(
     name = "framework_sparse_tensor_test",
     size = "small",
     srcs = ["framework/sparse_tensor_test.py"],
diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py
new file mode 100644 (file)
index 0000000..a0411bc
--- /dev/null
@@ -0,0 +1,201 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A TensorSpec class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import common_shapes
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class TensorSpec(object):
+  """Describes a tf.Tensor.
+
+  A TensorSpec allows an API to describe the Tensors that it accepts or
+  returns, before that Tensor exists. This allows dynamic and flexible graph
+  construction and configuration.
+  """
+
+  __slots__ = ["_shape", "_dtype", "_name"]
+
+  def __init__(self, shape, dtype, name=None):
+    """Creates a TensorSpec.
+
+    Args:
+      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
+      dtype: Value convertible to `tf.DType`. The type of the tensor values.
+      name: Optional name for the Tensor.
+
+    Raises:
+      TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
+        not convertible to a `tf.DType`.
+    """
+    self._shape = tensor_shape.TensorShape(shape)
+    self._dtype = dtypes.as_dtype(dtype)
+    self._name = name
+
+  @classmethod
+  def from_spec(cls, spec, name=None):
+    return cls(spec.shape, spec.dtype, name or spec.name)
+
+  @classmethod
+  def from_tensor(cls, tensor, name=None):
+    if isinstance(tensor, ops.EagerTensor):
+      return TensorSpec(tensor.shape, tensor.dtype, name)
+    elif isinstance(tensor, ops.Tensor):
+      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
+    else:
+      raise ValueError("`tensor` should be a tf.Tensor")
+
+  @property
+  def shape(self):
+    """Returns the `TensorShape` that represents the shape of the tensor."""
+    return self._shape
+
+  @property
+  def dtype(self):
+    """Returns the `dtype` of elements in the tensor."""
+    return self._dtype
+
+  @property
+  def name(self):
+    """Returns the name of the described tensor."""
+    return self._name
+
+  def is_compatible_with(self, spec_or_tensor):
+    """True if the shape and dtype of `spec_or_tensor` are compatible."""
+    return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
+            self._shape.is_compatible_with(spec_or_tensor.shape))
+
+  def __repr__(self):
+    return "TensorSpec(shape={}, dtype={}, name={})".format(
+        self.shape, repr(self.dtype), repr(self.name))
+
+  def __eq__(self, other):
+    return self.shape == other.shape and self.dtype == other.dtype
+
+  def __ne__(self, other):
+    return not self == other
+
+
+class BoundedTensorSpec(TensorSpec):
+  """A `TensorSpec` that specifies minimum and maximum values.
+
+  Example usage:
+  ```python
+  spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5))
+  tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype)
+  tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype)
+  ```
+
+  Bounds are meant to be inclusive. This is especially important for
+  integer types. The following spec will be satisfied by tensors
+  with values in the set {0, 1, 2}:
+  ```python
+  spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2)
+  ```
+  """
+
+  __slots__ = ("_minimum", "_maximum")
+
+  def __init__(self, shape, dtype, minimum, maximum, name=None):
+    """Initializes a new `BoundedTensorSpec`.
+
+    Args:
+      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
+      dtype: Value convertible to `tf.DType`. The type of the tensor values.
+      minimum: Number or sequence specifying the minimum element bounds
+        (inclusive). Must be broadcastable to `shape`.
+      maximum: Number or sequence specifying the maximum element bounds
+        (inclusive). Must be broadcastable to `shape`.
+      name: Optional string containing a semantic name for the corresponding
+        array. Defaults to `None`.
+
+    Raises:
+      ValueError: If `minimum` or `maximum` are not provided or not
+        broadcastable to `shape`.
+      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
+        numpy dtype.
+    """
+    super(BoundedTensorSpec, self).__init__(shape, dtype, name)
+
+    if minimum is None or maximum is None:
+      raise ValueError("minimum and maximum must be provided; but saw "
+                       "'%s' and '%s'" % (minimum, maximum))
+
+    try:
+      minimum_shape = np.shape(minimum)
+      common_shapes.broadcast_shape(
+          tensor_shape.TensorShape(minimum_shape), self.shape)
+    except ValueError as exception:
+      raise ValueError("minimum is not compatible with shape. "
+                       "Message: {!r}.".format(exception))
+
+    try:
+      maximum_shape = np.shape(maximum)
+      common_shapes.broadcast_shape(
+          tensor_shape.TensorShape(maximum_shape), self.shape)
+    except ValueError as exception:
+      raise ValueError("maximum is not compatible with shape. "
+                       "Message: {!r}.".format(exception))
+
+    self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype())
+    self._minimum.setflags(write=False)
+
+    self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype())
+    self._maximum.setflags(write=False)
+
+  @classmethod
+  def from_spec(cls, spec):
+    dtype = dtypes.as_dtype(spec.dtype)
+    if dtype in [dtypes.float64, dtypes.float32]:
+      # Avoid under/over-flow for `dtype.maximum - dtype.minimum`.
+      low = dtype.min / 2
+      high = dtype.max / 2
+    else:
+      low = dtype.min
+      high = dtype.max
+
+    minimum = getattr(spec, "minimum", low)
+    maximum = getattr(spec, "maximum", high)
+    return BoundedTensorSpec(spec.shape, dtype, minimum, maximum, spec.name)
+
+  @property
+  def minimum(self):
+    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
+    return self._minimum
+
+  @property
+  def maximum(self):
+    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
+    return self._maximum
+
+  def __repr__(self):
+    s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})"
+    return s.format(self.shape, repr(self.dtype), repr(self.name),
+                    repr(self.minimum), repr(self.maximum))
+
+  def __eq__(self, other):
+    tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other)
+    return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and
+            np.allclose(self.maximum, other.maximum))
+
+
diff --git a/tensorflow/python/framework/tensor_spec_test.py b/tensorflow/python/framework/tensor_spec_test.py
new file mode 100644 (file)
index 0000000..54ca4d9
--- /dev/null
@@ -0,0 +1,227 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tensor_spec."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import googletest
+
+
+class TensorSpecTest(test_util.TensorFlowTestCase):
+
+  def testAcceptsNumpyDType(self):
+    desc = tensor_spec.TensorSpec([1], np.float32)
+    self.assertEqual(desc.dtype, dtypes.float32)
+
+  def testAcceptsTensorShape(self):
+    desc = tensor_spec.TensorSpec(tensor_shape.TensorShape([1]), dtypes.float32)
+    self.assertEqual(desc.shape, tensor_shape.TensorShape([1]))
+
+  def testUnknownShape(self):
+    desc = tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32)
+    self.assertEqual(desc.shape, tensor_shape.TensorShape(None))
+
+  def testShapeCompatibility(self):
+    unknown = array_ops.placeholder(dtypes.int64)
+    partial = array_ops.placeholder(dtypes.int64, shape=[None, 1])
+    full = array_ops.placeholder(dtypes.int64, shape=[2, 3])
+    rank3 = array_ops.placeholder(dtypes.int64, shape=[4, 5, 6])
+
+    desc_unknown = tensor_spec.TensorSpec(None, dtypes.int64)
+    self.assertTrue(desc_unknown.is_compatible_with(unknown))
+    self.assertTrue(desc_unknown.is_compatible_with(partial))
+    self.assertTrue(desc_unknown.is_compatible_with(full))
+    self.assertTrue(desc_unknown.is_compatible_with(rank3))
+
+    desc_partial = tensor_spec.TensorSpec([2, None], dtypes.int64)
+    self.assertTrue(desc_partial.is_compatible_with(unknown))
+    self.assertTrue(desc_partial.is_compatible_with(partial))
+    self.assertTrue(desc_partial.is_compatible_with(full))
+    self.assertFalse(desc_partial.is_compatible_with(rank3))
+
+    desc_full = tensor_spec.TensorSpec([2, 3], dtypes.int64)
+    self.assertTrue(desc_full.is_compatible_with(unknown))
+    self.assertFalse(desc_full.is_compatible_with(partial))
+    self.assertTrue(desc_full.is_compatible_with(full))
+    self.assertFalse(desc_full.is_compatible_with(rank3))
+
+    desc_rank3 = tensor_spec.TensorSpec([4, 5, 6], dtypes.int64)
+    self.assertTrue(desc_rank3.is_compatible_with(unknown))
+    self.assertFalse(desc_rank3.is_compatible_with(partial))
+    self.assertFalse(desc_rank3.is_compatible_with(full))
+    self.assertTrue(desc_rank3.is_compatible_with(rank3))
+
+  def testTypeCompatibility(self):
+    floats = array_ops.placeholder(dtypes.float32, shape=[10, 10])
+    ints = array_ops.placeholder(dtypes.int32, shape=[10, 10])
+    desc = tensor_spec.TensorSpec(shape=(10, 10), dtype=dtypes.float32)
+    self.assertTrue(desc.is_compatible_with(floats))
+    self.assertFalse(desc.is_compatible_with(ints))
+
+  def testName(self):
+    desc = tensor_spec.TensorSpec([1], dtypes.float32, name="beep")
+    self.assertEqual(desc.name, "beep")
+
+  def testRepr(self):
+    desc1 = tensor_spec.TensorSpec([1], dtypes.float32, name="beep")
+    self.assertEqual(
+        repr(desc1),
+        "TensorSpec(shape=(1,), dtype=tf.float32, name='beep')")
+    desc2 = tensor_spec.TensorSpec([1, None], dtypes.int32)
+    self.assertEqual(
+        repr(desc2),
+        "TensorSpec(shape=(1, ?), dtype=tf.int32, name=None)")
+
+  def testFromTensorSpec(self):
+    spec_1 = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+    spec_2 = tensor_spec.TensorSpec.from_spec(spec_1)
+    self.assertEqual(spec_1, spec_2)
+
+  def testFromTensor(self):
+    zero = constant_op.constant(0)
+    spec = tensor_spec.TensorSpec.from_tensor(zero)
+    self.assertEqual(spec.dtype, dtypes.int32)
+    self.assertEqual(spec.shape, [])
+    self.assertEqual(spec.name, "Const")
+
+  def testFromPlaceholder(self):
+    unknown = array_ops.placeholder(dtypes.int64, name="unknown")
+    partial = array_ops.placeholder(dtypes.float32,
+                                    shape=[None, 1],
+                                    name="partial")
+    spec_1 = tensor_spec.TensorSpec.from_tensor(unknown)
+    self.assertEqual(spec_1.dtype, dtypes.int64)
+    self.assertEqual(spec_1.shape, None)
+    self.assertEqual(spec_1.name, "unknown")
+    spec_2 = tensor_spec.TensorSpec.from_tensor(partial)
+    self.assertEqual(spec_2.dtype, dtypes.float32)
+    self.assertEqual(spec_2.shape.as_list(), [None, 1])
+    self.assertEqual(spec_2.name, "partial")
+
+  def testFromBoundedTensorSpec(self):
+    bounded_spec = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, 0, 1)
+    spec = tensor_spec.TensorSpec.from_spec(bounded_spec)
+    self.assertEqual(bounded_spec.shape, spec.shape)
+    self.assertEqual(bounded_spec.dtype, spec.dtype)
+    self.assertEqual(bounded_spec.name, spec.name)
+
+
+class BoundedTensorSpecTest(test_util.TensorFlowTestCase):
+
+  def testInvalidMinimum(self):
+    with self.assertRaisesRegexp(ValueError, "not compatible"):
+      tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, (0, 0, 0), (1, 1))
+
+  def testInvalidMaximum(self):
+    with self.assertRaisesRegexp(ValueError, "not compatible"):
+      tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, 0, (1, 1, 1))
+
+  def testMinimumMaximumAttributes(self):
+    spec = tensor_spec.BoundedTensorSpec(
+        (1, 2, 3), dtypes.float32, 0, (5, 5, 5))
+    self.assertEqual(type(spec.minimum), np.ndarray)
+    self.assertEqual(type(spec.maximum), np.ndarray)
+    self.assertAllEqual(spec.minimum, np.array(0, dtype=np.float32))
+    self.assertAllEqual(spec.maximum, np.array([5, 5, 5], dtype=np.float32))
+
+  def testNotWriteableNP(self):
+    spec = tensor_spec.BoundedTensorSpec(
+        (1, 2, 3), dtypes.float32, 0, (5, 5, 5))
+    with self.assertRaisesRegexp(ValueError, "read-only"):
+      spec.minimum[0] = -1
+    with self.assertRaisesRegexp(ValueError, "read-only"):
+      spec.maximum[0] = 100
+
+  def testReuseSpec(self):
+    spec_1 = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32,
+                                           minimum=0, maximum=1)
+    spec_2 = tensor_spec.BoundedTensorSpec(
+        spec_1.shape, spec_1.dtype, spec_1.minimum, spec_1.maximum)
+    self.assertEqual(spec_1, spec_2)
+
+  def testScalarBounds(self):
+    spec = tensor_spec.BoundedTensorSpec(
+        (), dtypes.float32, minimum=0.0, maximum=1.0)
+
+    self.assertIsInstance(spec.minimum, np.ndarray)
+    self.assertIsInstance(spec.maximum, np.ndarray)
+
+    # Sanity check that numpy compares correctly to a scalar for an empty shape.
+    self.assertEqual(0.0, spec.minimum)
+    self.assertEqual(1.0, spec.maximum)
+
+    # Check that the spec doesn't fail its own input validation.
+    _ = tensor_spec.BoundedTensorSpec(
+        spec.shape, spec.dtype, spec.minimum, spec.maximum)
+
+  def testFromBoundedTensorSpec(self):
+    spec_1 = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32,
+                                           minimum=0, maximum=1)
+    spec_2 = tensor_spec.BoundedTensorSpec.from_spec(spec_1)
+    self.assertEqual(spec_1, spec_2)
+
+  def testEquality(self):
+    spec_1_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32,
+                                             0, (5, 5, 5))
+    spec_1_2 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32,
+                                             0.00000001,
+                                             (5, 5, 5.00000000000000001))
+    spec_2_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32,
+                                             1, (5, 5, 5))
+    spec_2_2 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32,
+                                             (1, 1, 1), (5, 5, 5))
+    spec_2_3 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32,
+                                             (1, 1, 1), 5)
+    spec_3_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32,
+                                             (2, 1, 1), (5, 5, 5))
+
+    self.assertEqual(spec_1_1, spec_1_2)
+    self.assertEqual(spec_1_2, spec_1_1)
+
+    self.assertNotEqual(spec_1_1, spec_2_2)
+    self.assertNotEqual(spec_1_1, spec_2_1)
+    self.assertNotEqual(spec_2_2, spec_1_1)
+    self.assertNotEqual(spec_2_1, spec_1_1)
+
+    self.assertEqual(spec_2_1, spec_2_2)
+    self.assertEqual(spec_2_2, spec_2_1)
+    self.assertEqual(spec_2_2, spec_2_3)
+
+    self.assertNotEqual(spec_1_1, spec_3_1)
+    self.assertNotEqual(spec_2_1, spec_3_1)
+    self.assertNotEqual(spec_2_2, spec_3_1)
+
+  def testFromTensorSpec(self):
+    spec = tensor_spec.TensorSpec((1, 2), dtypes.int32)
+    bounded_spec = tensor_spec.BoundedTensorSpec.from_spec(spec)
+    self.assertEqual(spec.shape, bounded_spec.shape)
+    self.assertEqual(spec.dtype, bounded_spec.dtype)
+    self.assertEqual(spec.dtype.min, bounded_spec.minimum)
+    self.assertEqual(spec.dtype.max, bounded_spec.maximum)
+    self.assertEqual(spec.name, bounded_spec.name)
+
+
+if __name__ == "__main__":
+  googletest.main()