Adds remaining validations in sequence_numeric_column.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 22 Mar 2018 18:12:10 +0000 (11:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 22 Mar 2018 18:14:40 +0000 (11:14 -0700)
PiperOrigin-RevId: 190094883

tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column.py
tensorflow/contrib/feature_column/python/feature_column/sequence_feature_column_test.py

index e601169..555bedd 100644 (file)
@@ -166,6 +166,10 @@ def sequence_categorical_column_with_identity(
 
   Returns:
     A `_SequenceCategoricalColumn`.
+
+  Raises:
+    ValueError: if `num_buckets` is less than one.
+    ValueError: if `default_value` is not in range `[0, num_buckets)`.
   """
   return fc._SequenceCategoricalColumn(
       fc.categorical_column_with_identity(
@@ -205,6 +209,10 @@ def sequence_categorical_column_with_hash_bucket(
 
   Returns:
     A `_SequenceCategoricalColumn`.
+
+  Raises:
+    ValueError: `hash_bucket_size` is not greater than 1.
+    ValueError: `dtype` is neither string nor integer.
   """
   return fc._SequenceCategoricalColumn(
       fc.categorical_column_with_hash_bucket(
@@ -257,6 +265,13 @@ def sequence_categorical_column_with_vocabulary_file(
 
   Returns:
     A `_SequenceCategoricalColumn`.
+
+  Raises:
+    ValueError: `vocabulary_file` is missing or cannot be opened.
+    ValueError: `vocabulary_size` is missing or < 1.
+    ValueError: `num_oov_buckets` is a negative integer.
+    ValueError: `num_oov_buckets` and `default_value` are both specified.
+    ValueError: `dtype` is neither string nor integer.
   """
   return fc._SequenceCategoricalColumn(
       fc.categorical_column_with_vocabulary_file(
@@ -311,6 +326,12 @@ def sequence_categorical_column_with_vocabulary_list(
 
   Returns:
     A `_SequenceCategoricalColumn`.
+
+  Raises:
+    ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
+    ValueError: `num_oov_buckets` is a negative integer.
+    ValueError: `num_oov_buckets` and `default_value` are both specified.
+    ValueError: if `dtype` is not integer or string.
   """
   return fc._SequenceCategoricalColumn(
       fc.categorical_column_with_vocabulary_list(
@@ -352,8 +373,17 @@ def sequence_numeric_column(
 
   Returns:
     A `_SequenceNumericColumn`.
+
+  Raises:
+    TypeError: if any dimension in shape is not an int.
+    ValueError: if any dimension in shape is not a positive integer.
+    ValueError: if `dtype` is not convertible to `tf.float32`.
   """
-  # TODO(b/73160931): Add validations.
+  shape = fc._check_shape(shape=shape, key=key)
+  if not (dtype.is_integer or dtype.is_floating):
+    raise ValueError('dtype must be convertible to float. '
+                     'dtype: {}, key: {}'.format(dtype, key))
+
   return _SequenceNumericColumn(
       key,
       shape=shape,
index b64f086..88f5d53 100644 (file)
@@ -662,6 +662,32 @@ class SequenceIndicatorColumnTest(test.TestCase):
 
 class SequenceNumericColumnTest(test.TestCase):
 
+  def test_defaults(self):
+    a = sfc.sequence_numeric_column('aaa')
+    self.assertEqual('aaa', a.key)
+    self.assertEqual('aaa', a.name)
+    self.assertEqual('aaa', a._var_scope_name)
+    self.assertEqual((1,), a.shape)
+    self.assertEqual(0., a.default_value)
+    self.assertEqual(dtypes.float32, a.dtype)
+
+  def test_shape_saved_as_tuple(self):
+    a = sfc.sequence_numeric_column('aaa', shape=[1, 2])
+    self.assertEqual((1, 2), a.shape)
+
+  def test_shape_must_be_positive_integer(self):
+    with self.assertRaisesRegexp(TypeError, 'shape dimensions must be integer'):
+      sfc.sequence_numeric_column('aaa', shape=[1.0])
+
+    with self.assertRaisesRegexp(
+        ValueError, 'shape dimensions must be greater than 0'):
+      sfc.sequence_numeric_column('aaa', shape=[0])
+
+  def test_dtype_is_convertible_to_float(self):
+    with self.assertRaisesRegexp(
+        ValueError, 'dtype must be convertible to float'):
+      sfc.sequence_numeric_column('aaa', dtype=dtypes.string)
+
   def test_get_sequence_dense_tensor(self):
     sparse_input = sparse_tensor.SparseTensorValue(
         # example 0, values [[0.], [1]]