raise TypeError(
'normalizer_fn must be a callable. Given: {}'.format(normalizer_fn))
+ _assert_key_is_string(key)
return _NumericColumn(
key,
shape=shape,
'{} dtype must be string or integer. dtype: {}.'.format(prefix, dtype))
+def _assert_key_is_string(key):
+ if not isinstance(key, six.string_types):
+ raise ValueError(
+ 'key must be a string. Got: type {}. Given key: {}.'.format(
+ type(key), key))
+
+
@tf_export('feature_column.categorical_column_with_hash_bucket')
def categorical_column_with_hash_bucket(key,
hash_bucket_size,
'hash_bucket_size: {}, key: {}'.format(
hash_bucket_size, key))
+ _assert_key_is_string(key)
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
return _HashedCategoricalColumn(key, hash_bucket_size, dtype)
raise ValueError('Invalid num_oov_buckets {} in {}.'.format(
num_oov_buckets, key))
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ _assert_key_is_string(key)
return _VocabularyFileCategoricalColumn(
key=key,
vocabulary_file=vocabulary_file,
'dtype {} and vocabulary dtype {} do not match, column_name: {}'.format(
dtype, vocabulary_dtype, key))
_assert_string_or_int(dtype, prefix='column_name: {}'.format(key))
+ _assert_key_is_string(key)
return _VocabularyListCategoricalColumn(
key=key, vocabulary_list=tuple(vocabulary_list), dtype=dtype,
raise ValueError(
'default_value {} not in range [0, {}), column_name {}'.format(
default_value, num_buckets, key))
+ _assert_key_is_string(key)
return _IdentityCategoricalColumn(
key=key, num_buckets=num_buckets, default_value=default_value)
self.assertEqual(dtypes.float32, a.dtype)
self.assertIsNone(a.normalizer_fn)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.numeric_column(key=('aaa',))
+
def test_shape_saved_as_tuple(self):
a = fc.numeric_column('aaa', shape=[1, 2], default_value=[[3, 2.]])
self.assertEqual((1, 2), a.shape)
self.assertEqual(10, a.hash_bucket_size)
self.assertEqual(dtypes.string, a.dtype)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_hash_bucket(('key',), 10)
+
def test_bucket_size_should_be_given(self):
with self.assertRaisesRegexp(ValueError, 'hash_bucket_size must be set.'):
fc.categorical_column_with_hash_bucket('aaa', None)
'aaa': parsing_ops.VarLenFeature(dtypes.string)
}, column._parse_example_spec)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_vocabulary_file(
+ key=('aaa',), vocabulary_file='path_to_file', vocabulary_size=3)
+
def test_all_constructor_args(self):
column = fc.categorical_column_with_vocabulary_file(
key='aaa', vocabulary_file='path_to_file', vocabulary_size=3,
'aaa': parsing_ops.VarLenFeature(dtypes.string)
}, column._parse_example_spec)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_vocabulary_list(
+ key=('aaa',), vocabulary_list=('omar', 'stringer', 'marlo'))
+
def test_defaults_int(self):
column = fc.categorical_column_with_vocabulary_list(
key='aaa', vocabulary_list=(12, 24, 36))
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
}, column._parse_example_spec)
+ def test_key_should_be_string(self):
+ with self.assertRaisesRegexp(ValueError, 'key must be a string.'):
+ fc.categorical_column_with_identity(key=('aaa',), num_buckets=3)
+
def test_deep_copy(self):
original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
for column in (original, copy.deepcopy(original)):