From: Sergio Guadarrama Date: Mon, 5 Feb 2018 23:16:29 +0000 (-0800) Subject: Adding TensorSpec to represent the specification of Tensors. X-Git-Tag: upstream/v1.7.0~31^2~994 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=2074a568810ea20c267e9d063a066b81ad491eed;p=platform%2Fupstream%2Ftensorflow.git Adding TensorSpec to represent the specification of Tensors. PiperOrigin-RevId: 184594856 --- diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 503b868..fb101c3 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -86,6 +86,9 @@ See the @{$python/contrib.framework} guide. @@sort @@CriticalSection + +@@BoundedTensorSpec +@@TensorSpec """ from __future__ import absolute_import @@ -100,6 +103,9 @@ from tensorflow.contrib.framework.python.ops import * from tensorflow.python.framework.ops import prepend_name_scope from tensorflow.python.framework.ops import strip_name_scope +from tensorflow.python.framework.tensor_spec import BoundedTensorSpec +from tensorflow.python.framework.tensor_spec import TensorSpec + from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['nest'] diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index a89142d..2b4d5b8 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -576,6 +576,7 @@ py_library( ":pywrap_tensorflow", ":random_seed", ":sparse_tensor", + ":tensor_spec", ":tensor_util", ":util", "//tensorflow/python/eager:context", @@ -781,6 +782,18 @@ py_library( ) py_library( + name = "tensor_spec", + srcs = ["framework/tensor_spec.py"], + srcs_version = "PY2AND3", + deps = [ + ":common_shapes", + ":dtypes", + ":tensor_shape", + "//third_party/py/numpy", + ], +) + +py_library( name = "tensor_util", srcs = ["framework/tensor_util.py"], srcs_version = "PY2AND3", @@ -1149,6 +1162,21 @@ py_test( ) py_test( + name = "framework_tensor_spec_test", + size = "small", + srcs = ["framework/tensor_spec_test.py"], + main = "framework/tensor_spec_test.py", + srcs_version = "PY2AND3", + deps = [ + ":framework_for_generated_wrappers", + ":framework_test_lib", + ":platform_test", + ":tensor_spec", + "//third_party/py/numpy", + ], +) + +py_test( name = "framework_sparse_tensor_test", size = "small", srcs = ["framework/sparse_tensor_test.py"], diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py new file mode 100644 index 0000000..a0411bc --- /dev/null +++ b/tensorflow/python/framework/tensor_spec.py @@ -0,0 +1,201 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TensorSpec class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import common_shapes +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class TensorSpec(object): + """Describes a tf.Tensor. + + A TensorSpec allows an API to describe the Tensors that it accepts or + returns, before that Tensor exists. This allows dynamic and flexible graph + construction and configuration. + """ + + __slots__ = ["_shape", "_dtype", "_name"] + + def __init__(self, shape, dtype, name=None): + """Creates a TensorSpec. + + Args: + shape: Value convertible to `tf.TensorShape`. The shape of the tensor. + dtype: Value convertible to `tf.DType`. The type of the tensor values. + name: Optional name for the Tensor. + + Raises: + TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is + not convertible to a `tf.DType`. + """ + self._shape = tensor_shape.TensorShape(shape) + self._dtype = dtypes.as_dtype(dtype) + self._name = name + + @classmethod + def from_spec(cls, spec, name=None): + return cls(spec.shape, spec.dtype, name or spec.name) + + @classmethod + def from_tensor(cls, tensor, name=None): + if isinstance(tensor, ops.EagerTensor): + return TensorSpec(tensor.shape, tensor.dtype, name) + elif isinstance(tensor, ops.Tensor): + return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name) + else: + raise ValueError("`tensor` should be a tf.Tensor") + + @property + def shape(self): + """Returns the `TensorShape` that represents the shape of the tensor.""" + return self._shape + + @property + def dtype(self): + """Returns the `dtype` of elements in the tensor.""" + return self._dtype + + @property + def name(self): + """Returns the name of the described tensor.""" + return self._name + + def is_compatible_with(self, spec_or_tensor): + """True if the shape and dtype of `spec_or_tensor` are compatible.""" + return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and + self._shape.is_compatible_with(spec_or_tensor.shape)) + + def __repr__(self): + return "TensorSpec(shape={}, dtype={}, name={})".format( + self.shape, repr(self.dtype), repr(self.name)) + + def __eq__(self, other): + return self.shape == other.shape and self.dtype == other.dtype + + def __ne__(self, other): + return not self == other + + +class BoundedTensorSpec(TensorSpec): + """A `TensorSpec` that specifies minimum and maximum values. + + Example usage: + ```python + spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5)) + tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype) + tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype) + ``` + + Bounds are meant to be inclusive. This is especially important for + integer types. The following spec will be satisfied by tensors + with values in the set {0, 1, 2}: + ```python + spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2) + ``` + """ + + __slots__ = ("_minimum", "_maximum") + + def __init__(self, shape, dtype, minimum, maximum, name=None): + """Initializes a new `BoundedTensorSpec`. + + Args: + shape: Value convertible to `tf.TensorShape`. The shape of the tensor. + dtype: Value convertible to `tf.DType`. The type of the tensor values. + minimum: Number or sequence specifying the minimum element bounds + (inclusive). Must be broadcastable to `shape`. + maximum: Number or sequence specifying the maximum element bounds + (inclusive). Must be broadcastable to `shape`. + name: Optional string containing a semantic name for the corresponding + array. Defaults to `None`. + + Raises: + ValueError: If `minimum` or `maximum` are not provided or not + broadcastable to `shape`. + TypeError: If the shape is not an iterable or if the `dtype` is an invalid + numpy dtype. + """ + super(BoundedTensorSpec, self).__init__(shape, dtype, name) + + if minimum is None or maximum is None: + raise ValueError("minimum and maximum must be provided; but saw " + "'%s' and '%s'" % (minimum, maximum)) + + try: + minimum_shape = np.shape(minimum) + common_shapes.broadcast_shape( + tensor_shape.TensorShape(minimum_shape), self.shape) + except ValueError as exception: + raise ValueError("minimum is not compatible with shape. " + "Message: {!r}.".format(exception)) + + try: + maximum_shape = np.shape(maximum) + common_shapes.broadcast_shape( + tensor_shape.TensorShape(maximum_shape), self.shape) + except ValueError as exception: + raise ValueError("maximum is not compatible with shape. " + "Message: {!r}.".format(exception)) + + self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype()) + self._minimum.setflags(write=False) + + self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype()) + self._maximum.setflags(write=False) + + @classmethod + def from_spec(cls, spec): + dtype = dtypes.as_dtype(spec.dtype) + if dtype in [dtypes.float64, dtypes.float32]: + # Avoid under/over-flow for `dtype.maximum - dtype.minimum`. + low = dtype.min / 2 + high = dtype.max / 2 + else: + low = dtype.min + high = dtype.max + + minimum = getattr(spec, "minimum", low) + maximum = getattr(spec, "maximum", high) + return BoundedTensorSpec(spec.shape, dtype, minimum, maximum, spec.name) + + @property + def minimum(self): + """Returns a NumPy array specifying the minimum bounds (inclusive).""" + return self._minimum + + @property + def maximum(self): + """Returns a NumPy array specifying the maximum bounds (inclusive).""" + return self._maximum + + def __repr__(self): + s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})" + return s.format(self.shape, repr(self.dtype), repr(self.name), + repr(self.minimum), repr(self.maximum)) + + def __eq__(self, other): + tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other) + return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and + np.allclose(self.maximum, other.maximum)) + + diff --git a/tensorflow/python/framework/tensor_spec_test.py b/tensorflow/python/framework/tensor_spec_test.py new file mode 100644 index 0000000..54ca4d9 --- /dev/null +++ b/tensorflow/python/framework/tensor_spec_test.py @@ -0,0 +1,227 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensor_spec.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import googletest + + +class TensorSpecTest(test_util.TensorFlowTestCase): + + def testAcceptsNumpyDType(self): + desc = tensor_spec.TensorSpec([1], np.float32) + self.assertEqual(desc.dtype, dtypes.float32) + + def testAcceptsTensorShape(self): + desc = tensor_spec.TensorSpec(tensor_shape.TensorShape([1]), dtypes.float32) + self.assertEqual(desc.shape, tensor_shape.TensorShape([1])) + + def testUnknownShape(self): + desc = tensor_spec.TensorSpec(shape=None, dtype=dtypes.float32) + self.assertEqual(desc.shape, tensor_shape.TensorShape(None)) + + def testShapeCompatibility(self): + unknown = array_ops.placeholder(dtypes.int64) + partial = array_ops.placeholder(dtypes.int64, shape=[None, 1]) + full = array_ops.placeholder(dtypes.int64, shape=[2, 3]) + rank3 = array_ops.placeholder(dtypes.int64, shape=[4, 5, 6]) + + desc_unknown = tensor_spec.TensorSpec(None, dtypes.int64) + self.assertTrue(desc_unknown.is_compatible_with(unknown)) + self.assertTrue(desc_unknown.is_compatible_with(partial)) + self.assertTrue(desc_unknown.is_compatible_with(full)) + self.assertTrue(desc_unknown.is_compatible_with(rank3)) + + desc_partial = tensor_spec.TensorSpec([2, None], dtypes.int64) + self.assertTrue(desc_partial.is_compatible_with(unknown)) + self.assertTrue(desc_partial.is_compatible_with(partial)) + self.assertTrue(desc_partial.is_compatible_with(full)) + self.assertFalse(desc_partial.is_compatible_with(rank3)) + + desc_full = tensor_spec.TensorSpec([2, 3], dtypes.int64) + self.assertTrue(desc_full.is_compatible_with(unknown)) + self.assertFalse(desc_full.is_compatible_with(partial)) + self.assertTrue(desc_full.is_compatible_with(full)) + self.assertFalse(desc_full.is_compatible_with(rank3)) + + desc_rank3 = tensor_spec.TensorSpec([4, 5, 6], dtypes.int64) + self.assertTrue(desc_rank3.is_compatible_with(unknown)) + self.assertFalse(desc_rank3.is_compatible_with(partial)) + self.assertFalse(desc_rank3.is_compatible_with(full)) + self.assertTrue(desc_rank3.is_compatible_with(rank3)) + + def testTypeCompatibility(self): + floats = array_ops.placeholder(dtypes.float32, shape=[10, 10]) + ints = array_ops.placeholder(dtypes.int32, shape=[10, 10]) + desc = tensor_spec.TensorSpec(shape=(10, 10), dtype=dtypes.float32) + self.assertTrue(desc.is_compatible_with(floats)) + self.assertFalse(desc.is_compatible_with(ints)) + + def testName(self): + desc = tensor_spec.TensorSpec([1], dtypes.float32, name="beep") + self.assertEqual(desc.name, "beep") + + def testRepr(self): + desc1 = tensor_spec.TensorSpec([1], dtypes.float32, name="beep") + self.assertEqual( + repr(desc1), + "TensorSpec(shape=(1,), dtype=tf.float32, name='beep')") + desc2 = tensor_spec.TensorSpec([1, None], dtypes.int32) + self.assertEqual( + repr(desc2), + "TensorSpec(shape=(1, ?), dtype=tf.int32, name=None)") + + def testFromTensorSpec(self): + spec_1 = tensor_spec.TensorSpec((1, 2), dtypes.int32) + spec_2 = tensor_spec.TensorSpec.from_spec(spec_1) + self.assertEqual(spec_1, spec_2) + + def testFromTensor(self): + zero = constant_op.constant(0) + spec = tensor_spec.TensorSpec.from_tensor(zero) + self.assertEqual(spec.dtype, dtypes.int32) + self.assertEqual(spec.shape, []) + self.assertEqual(spec.name, "Const") + + def testFromPlaceholder(self): + unknown = array_ops.placeholder(dtypes.int64, name="unknown") + partial = array_ops.placeholder(dtypes.float32, + shape=[None, 1], + name="partial") + spec_1 = tensor_spec.TensorSpec.from_tensor(unknown) + self.assertEqual(spec_1.dtype, dtypes.int64) + self.assertEqual(spec_1.shape, None) + self.assertEqual(spec_1.name, "unknown") + spec_2 = tensor_spec.TensorSpec.from_tensor(partial) + self.assertEqual(spec_2.dtype, dtypes.float32) + self.assertEqual(spec_2.shape.as_list(), [None, 1]) + self.assertEqual(spec_2.name, "partial") + + def testFromBoundedTensorSpec(self): + bounded_spec = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, 0, 1) + spec = tensor_spec.TensorSpec.from_spec(bounded_spec) + self.assertEqual(bounded_spec.shape, spec.shape) + self.assertEqual(bounded_spec.dtype, spec.dtype) + self.assertEqual(bounded_spec.name, spec.name) + + +class BoundedTensorSpecTest(test_util.TensorFlowTestCase): + + def testInvalidMinimum(self): + with self.assertRaisesRegexp(ValueError, "not compatible"): + tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, (0, 0, 0), (1, 1)) + + def testInvalidMaximum(self): + with self.assertRaisesRegexp(ValueError, "not compatible"): + tensor_spec.BoundedTensorSpec((3, 5), dtypes.uint8, 0, (1, 1, 1)) + + def testMinimumMaximumAttributes(self): + spec = tensor_spec.BoundedTensorSpec( + (1, 2, 3), dtypes.float32, 0, (5, 5, 5)) + self.assertEqual(type(spec.minimum), np.ndarray) + self.assertEqual(type(spec.maximum), np.ndarray) + self.assertAllEqual(spec.minimum, np.array(0, dtype=np.float32)) + self.assertAllEqual(spec.maximum, np.array([5, 5, 5], dtype=np.float32)) + + def testNotWriteableNP(self): + spec = tensor_spec.BoundedTensorSpec( + (1, 2, 3), dtypes.float32, 0, (5, 5, 5)) + with self.assertRaisesRegexp(ValueError, "read-only"): + spec.minimum[0] = -1 + with self.assertRaisesRegexp(ValueError, "read-only"): + spec.maximum[0] = 100 + + def testReuseSpec(self): + spec_1 = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, + minimum=0, maximum=1) + spec_2 = tensor_spec.BoundedTensorSpec( + spec_1.shape, spec_1.dtype, spec_1.minimum, spec_1.maximum) + self.assertEqual(spec_1, spec_2) + + def testScalarBounds(self): + spec = tensor_spec.BoundedTensorSpec( + (), dtypes.float32, minimum=0.0, maximum=1.0) + + self.assertIsInstance(spec.minimum, np.ndarray) + self.assertIsInstance(spec.maximum, np.ndarray) + + # Sanity check that numpy compares correctly to a scalar for an empty shape. + self.assertEqual(0.0, spec.minimum) + self.assertEqual(1.0, spec.maximum) + + # Check that the spec doesn't fail its own input validation. + _ = tensor_spec.BoundedTensorSpec( + spec.shape, spec.dtype, spec.minimum, spec.maximum) + + def testFromBoundedTensorSpec(self): + spec_1 = tensor_spec.BoundedTensorSpec((1, 2), dtypes.int32, + minimum=0, maximum=1) + spec_2 = tensor_spec.BoundedTensorSpec.from_spec(spec_1) + self.assertEqual(spec_1, spec_2) + + def testEquality(self): + spec_1_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + 0, (5, 5, 5)) + spec_1_2 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + 0.00000001, + (5, 5, 5.00000000000000001)) + spec_2_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + 1, (5, 5, 5)) + spec_2_2 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + (1, 1, 1), (5, 5, 5)) + spec_2_3 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + (1, 1, 1), 5) + spec_3_1 = tensor_spec.BoundedTensorSpec((1, 2, 3), dtypes.float32, + (2, 1, 1), (5, 5, 5)) + + self.assertEqual(spec_1_1, spec_1_2) + self.assertEqual(spec_1_2, spec_1_1) + + self.assertNotEqual(spec_1_1, spec_2_2) + self.assertNotEqual(spec_1_1, spec_2_1) + self.assertNotEqual(spec_2_2, spec_1_1) + self.assertNotEqual(spec_2_1, spec_1_1) + + self.assertEqual(spec_2_1, spec_2_2) + self.assertEqual(spec_2_2, spec_2_1) + self.assertEqual(spec_2_2, spec_2_3) + + self.assertNotEqual(spec_1_1, spec_3_1) + self.assertNotEqual(spec_2_1, spec_3_1) + self.assertNotEqual(spec_2_2, spec_3_1) + + def testFromTensorSpec(self): + spec = tensor_spec.TensorSpec((1, 2), dtypes.int32) + bounded_spec = tensor_spec.BoundedTensorSpec.from_spec(spec) + self.assertEqual(spec.shape, bounded_spec.shape) + self.assertEqual(spec.dtype, bounded_spec.dtype) + self.assertEqual(spec.dtype.min, bounded_spec.minimum) + self.assertEqual(spec.dtype.max, bounded_spec.maximum) + self.assertEqual(spec.name, bounded_spec.name) + + +if __name__ == "__main__": + googletest.main()