Added type check to feature column keys. So that users will get meaningful error...
authorMustafa Ispir <ispir@google.com>
Tue, 15 May 2018 05:04:50 +0000 (22:04 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 15 May 2018 05:07:42 +0000 (22:07 -0700)
PiperOrigin-RevId: 196616638

tensorflow/python/feature_column/feature_column.py
tensorflow/python/feature_column/feature_column_test.py

index c16c3cd..1d50892 100644 (file)
@@ -1068,6 +1068,7 @@ def numeric_column(key,
     raise TypeError(
         'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
 
+  _assert_key_is_string(key)
   return _NumericColumn(
       key,
       shape=shape,
@@ -1166,6 +1167,13 @@ def _assert_string_or_int(dtype, prefix):
         '{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
 
 
+def _assert_key_is_string(key):
+  if not isinstance(key, six.string_types):
+    raise ValueError(
+        'key must be a string. Got: type {}. Given key: {}.'.format(
+            type(key), key))
+
+
 @tf_export('feature_column.categorical_column_with_hash_bucket')
 def categorical_column_with_hash_bucket(key,
                                         hash_bucket_size,
@@ -1218,6 +1226,7 @@ def categorical_column_with_hash_bucket(key,
                      'hash_bucket_size: {}, key: {}'.format(
                          hash_bucket_size, key))
 
+  _assert_key_is_string(key)
   _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
 
   return _HashedCategoricalColumn(key, hash_bucket_size, dtype)
@@ -1334,6 +1343,7 @@ def categorical_column_with_vocabulary_file(key,
       raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
           num_oov_buckets, key))
   _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+  _assert_key_is_string(key)
   return _VocabularyFileCategoricalColumn(
       key=key,
       vocabulary_file=vocabulary_file,
@@ -1448,6 +1458,7 @@ def categorical_column_with_vocabulary_list(
         'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
             dtype, vocabulary_dtype, key))
   _assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+  _assert_key_is_string(key)
 
   return _VocabularyListCategoricalColumn(
       key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype,
@@ -1518,6 +1529,7 @@ def categorical_column_with_identity(key, num_buckets, default_value=None):
     raise ValueError(
         'default_value {} not in range [0, {}), column_name {}'.format(
             default_value, num_buckets, key))
+  _assert_key_is_string(key)
   return _IdentityCategoricalColumn(
       key=key, num_buckets=num_buckets, default_value=default_value)
 
index b065404..03c47ee 100644 (file)
@@ -182,6 +182,10 @@ class NumericColumnTest(test.TestCase):
     self.assertEqual(dtypes.float32, a.dtype)
     self.assertIsNone(a.normalizer_fn)
 
+  def test_key_should_be_string(self):
+    with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+      fc.numeric_column(key=('aaa',))
+
   def test_shape_saved_as_tuple(self):
     a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]])
     self.assertEqual((1, 2), a.shape)
@@ -645,6 +649,10 @@ class HashedCategoricalColumnTest(test.TestCase):
     self.assertEqual(10, a.hash_bucket_size)
     self.assertEqual(dtypes.string, a.dtype)
 
+  def test_key_should_be_string(self):
+    with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+      fc.categorical_column_with_hash_bucket(('key',), 10)
+
   def test_bucket_size_should_be_given(self):
     with self.assertRaisesRegexp(ValueError, 'hash_bucket_size must be set.'):
       fc.categorical_column_with_hash_bucket('aaa', None)
@@ -3327,6 +3335,11 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
         'aaa': parsing_ops.VarLenFeature(dtypes.string)
     }, column._parse_example_spec)
 
+  def test_key_should_be_string(self):
+    with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+      fc.categorical_column_with_vocabulary_file(
+          key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3)
+
   def test_all_constructor_args(self):
     column = fc.categorical_column_with_vocabulary_file(
         key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
@@ -3752,6 +3765,11 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
         'aaa': parsing_ops.VarLenFeature(dtypes.string)
     }, column._parse_example_spec)
 
+  def test_key_should_be_string(self):
+    with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+      fc.categorical_column_with_vocabulary_list(
+          key=('aaa',), vocabulary_list=('omar', 'stringer', 'marlo'))
+
   def test_defaults_int(self):
     column = fc.categorical_column_with_vocabulary_list(
         key='aaa', vocabulary_list=(12, 24, 36))
@@ -4143,6 +4161,10 @@ class IdentityCategoricalColumnTest(test.TestCase):
         'aaa': parsing_ops.VarLenFeature(dtypes.int64)
     }, column._parse_example_spec)
 
+  def test_key_should_be_string(self):
+    with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+      fc.categorical_column_with_identity(key=('aaa',), num_buckets=3)
+
   def test_deep_copy(self):
     original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
     for column in (original, copy.deepcopy(original)):