From 3a6ca86801093895eb37155b1fe09e43cd4c9f80 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Thu, 22 Mar 2018 14:14:33 -0700 Subject: [PATCH] Update tf.keras preprocessing to Keras 2.1.5 API PiperOrigin-RevId: 190123773 --- .../keras/_impl/keras/preprocessing/image.py | 211 ++++++++++++++----- .../keras/_impl/keras/preprocessing/image_test.py | 151 +++++++++++++- .../keras/_impl/keras/preprocessing/sequence.py | 232 +++++++++++++++++---- .../_impl/keras/preprocessing/sequence_test.py | 82 +++++++- .../python/keras/_impl/keras/preprocessing/text.py | 31 ++- .../keras/_impl/keras/preprocessing/text_test.py | 42 +++- .../python/keras/preprocessing/image/__init__.py | 1 + .../keras/preprocessing/sequence/__init__.py | 1 + .../python/keras/preprocessing/text/__init__.py | 1 + ...s.preprocessing.image.-directory-iterator.pbtxt | 2 +- ...preprocessing.image.-image-data-generator.pbtxt | 6 +- ...preprocessing.image.-numpy-array-iterator.pbtxt | 2 +- .../tensorflow.keras.preprocessing.image.pbtxt | 4 + ...processing.sequence.-timeseries-generator.pbtxt | 14 ++ .../tensorflow.keras.preprocessing.sequence.pbtxt | 4 + .../tensorflow.keras.preprocessing.text.pbtxt | 4 + 16 files changed, 665 insertions(+), 123 deletions(-) create mode 100644 tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image.py b/tensorflow/python/keras/_impl/keras/preprocessing/image.py index d12f108..6299445 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/image.py @@ -43,6 +43,7 @@ except ImportError: try: + from PIL import ImageEnhance from PIL import Image as pil_image except ImportError: pil_image = None @@ -227,6 +228,32 @@ def random_channel_shift(x, intensity, channel_axis=0): return x +@tf_export('keras.preprocessing.image.random_brightness') +def random_brightness(x, brightness_range): + """Performs a random adjustment of brightness of a Numpy image tensor. + + Arguments: + x: Input tensor. Must be 3D. + brightness_range: Tuple of floats; range to pick a brightness value from. + + Returns: + Brightness adjusted Numpy image tensor. + + Raises: + ValueError: if `brightness_range` isn't a tuple. + """ + if len(brightness_range) != 2: + raise ValueError('`brightness_range should be tuple or list of two floats. ' + 'Received arg: ', brightness_range) + + x = array_to_img(x) + x = ImageEnhance.Brightness(x) + u = np.random.uniform(brightness_range[0], brightness_range[1]) + x = x.enhance(u) + x = img_to_array(x) + return x + + def transform_matrix_offset_center(matrix, x, y): o_x = float(x) / 2 + 0.5 o_y = float(y) / 2 + 0.5 @@ -265,7 +292,7 @@ def apply_transform(x, x_channel, final_affine_matrix, final_offset, - order=0, + order=1, mode=fill_mode, cval=cval) for x_channel in x ] @@ -436,6 +463,7 @@ class ImageDataGenerator(object): rotation_range: degrees (0 to 180). width_shift_range: fraction of total width, if < 1, or pixels if >= 1. height_shift_range: fraction of total height, if < 1, or pixels if >= 1. + brightness_range: the range of brightness to apply shear_range: shear intensity (shear angle in degrees). zoom_range: amount of zoom. if scalar z, zoom will be randomly picked in the range [1-z, 1+z]. A sequence of two can be passed instead @@ -469,6 +497,8 @@ class ImageDataGenerator(object): It defaults to the `image_data_format` value found in your Keras config file at `~/.keras/keras.json`. If you never set it, then it will be "channels_last". + validation_split: fraction of images reserved for validation (strictly + between 0 and 1). """ def __init__(self, @@ -481,6 +511,7 @@ class ImageDataGenerator(object): rotation_range=0., width_shift_range=0., height_shift_range=0., + brightness_range=None, shear_range=0., zoom_range=0., channel_shift_range=0., @@ -490,7 +521,8 @@ class ImageDataGenerator(object): vertical_flip=False, rescale=None, preprocessing_function=None, - data_format=None): + data_format=None, + validation_split=0.0): if data_format is None: data_format = K.image_data_format() self.featurewise_center = featurewise_center @@ -502,6 +534,7 @@ class ImageDataGenerator(object): self.rotation_range = rotation_range self.width_shift_range = width_shift_range self.height_shift_range = height_shift_range + self.brightness_range = brightness_range self.shear_range = shear_range self.zoom_range = zoom_range self.channel_shift_range = channel_shift_range @@ -526,6 +559,10 @@ class ImageDataGenerator(object): self.channel_axis = 3 self.row_axis = 1 self.col_axis = 2 + if validation_split and not 0 < validation_split < 1: + raise ValueError('`validation_split` must be strictly between 0 and 1. ' + 'Received arg: ', validation_split) + self.validation_split = validation_split self.mean = None self.std = None @@ -574,7 +611,8 @@ class ImageDataGenerator(object): seed=None, save_to_dir=None, save_prefix='', - save_format='png'): + save_format='png', + subset=None): return NumpyArrayIterator( x, y, @@ -585,7 +623,8 @@ class ImageDataGenerator(object): data_format=self.data_format, save_to_dir=save_to_dir, save_prefix=save_prefix, - save_format=save_format) + save_format=save_format, + subset=subset) def flow_from_directory(self, directory, @@ -600,6 +639,7 @@ class ImageDataGenerator(object): save_prefix='', save_format='png', follow_links=False, + subset=None, interpolation='nearest'): return DirectoryIterator( directory, @@ -616,6 +656,7 @@ class ImageDataGenerator(object): save_prefix=save_prefix, save_format=save_format, follow_links=follow_links, + subset=subset, interpolation=interpolation) def standardize(self, x): @@ -628,7 +669,7 @@ class ImageDataGenerator(object): The inputs, normalized. """ if self.preprocessing_function: - x = self.preprocessing_function(x) + x = self.image_data_generator.preprocessing_function(x) if self.rescale: x *= self.rescale if self.samplewise_center: @@ -762,6 +803,9 @@ class ImageDataGenerator(object): if np.random.random() < 0.5: x = flip_axis(x, img_row_axis) + if self.brightness_range is not None: + x = random_brightness(x, self.brightness_range) + return x def fit(self, x, augment=False, rounds=1, seed=None): @@ -828,12 +872,10 @@ class ImageDataGenerator(object): raise ImportError('Scipy is required for zca_whitening.') flat_x = np.reshape(x, (x.shape[0], x.shape[1] * x.shape[2] * x.shape[3])) - num_examples = flat_x.shape[0] - _, s, vt = linalg.svd(flat_x / np.sqrt(num_examples)) - s_expand = np.hstack( - (s, np.zeros(vt.shape[0] - num_examples, dtype=flat_x.dtype))) - self.principal_components = ( - vt.T / np.sqrt(s_expand**2 + self.zca_epsilon)).dot(vt) + sigma = np.dot(flat_x.T, flat_x) / flat_x.shape[0] + u, s, _ = linalg.svd(sigma) + s_inv = 1. / np.sqrt(s[np.newaxis] + self.zca_epsilon) + self.principal_components = (u * s_inv).dot(u.T) @tf_export('keras.preprocessing.image.Iterator') @@ -947,6 +989,8 @@ class NumpyArrayIterator(Iterator): images (if `save_to_dir` is set). save_format: Format to use for saving sample images (if `save_to_dir` is set). + subset: Subset of data (`"training"` or `"validation"`) if + validation_split is set in ImageDataGenerator. """ def __init__(self, @@ -959,17 +1003,29 @@ class NumpyArrayIterator(Iterator): data_format=None, save_to_dir=None, save_prefix='', - save_format='png'): + save_format='png', + subset=None): if y is not None and len(x) != len(y): - raise ValueError('X (images tensor) and y (labels) ' + raise ValueError('`x` (images tensor) and `y` (labels) ' 'should have the same length. ' - 'Found: X.shape = %s, y.shape = %s' % + 'Found: x.shape = %s, y.shape = %s' % (np.asarray(x).shape, np.asarray(y).shape)) - + if subset is not None: + if subset not in {'training', 'validation'}: + raise ValueError('Invalid subset name:', subset, + '; expected "training" or "validation".') + split_idx = int(len(x) * image_data_generator.validation_split) + if subset == 'validation': + x = x[:split_idx] + if y is not None: + y = y[:split_idx] + else: + x = x[split_idx:] + if y is not None: + y = y[split_idx:] if data_format is None: data_format = K.image_data_format() self.x = np.asarray(x, dtype=K.floatx()) - if self.x.ndim != 4: raise ValueError('Input data in `NumpyArrayIterator` ' 'should have rank 4. You passed an array ' @@ -1032,8 +1088,7 @@ class NumpyArrayIterator(Iterator): return self._get_batches_of_transformed_samples(index_array) -def _count_valid_files_in_directory(directory, white_list_formats, - follow_links): +def _iter_valid_files(directory, white_list_formats, follow_links): """Count files with extension in `white_list_formats` contained in directory. Arguments: @@ -1043,29 +1098,54 @@ def _count_valid_files_in_directory(directory, white_list_formats, the files to be counted. follow_links: boolean. - Returns: - the count of files with extension in `white_list_formats` contained in - the directory. + Yields: + tuple of (root, filename) with extension in `white_list_formats`. """ def _recursive_list(subpath): return sorted( - os.walk(subpath, followlinks=follow_links), key=lambda tpl: tpl[0]) + os.walk(subpath, followlinks=follow_links), key=lambda x: x[0]) - samples = 0 - for _, _, files in _recursive_list(directory): - for fname in files: - is_valid = False + for root, _, files in _recursive_list(directory): + for fname in sorted(files): for extension in white_list_formats: + if fname.lower().endswith('.tiff'): + logging.warning( + 'Using \'.tiff\' files with multiple bands will cause ' + 'distortion. Please verify your output.') if fname.lower().endswith('.' + extension): - is_valid = True - break - if is_valid: - samples += 1 - return samples + yield root, fname -def _list_valid_filenames_in_directory(directory, white_list_formats, +def _count_valid_files_in_directory(directory, white_list_formats, split, + follow_links): + """Count files with extension in `white_list_formats` contained in directory. + + Arguments: + directory: absolute path to the directory + containing files to be counted + white_list_formats: set of strings containing allowed extensions for + the files to be counted. + split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into + account a certain fraction of files in each directory. + E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent + of images in each directory. + follow_links: boolean. + + Returns: + the count of files with extension in `white_list_formats` contained in + the directory. + """ + num_files = len( + list(_iter_valid_files(directory, white_list_formats, follow_links))) + if split: + start, stop = int(split[0] * num_files), int(split[1] * num_files) + else: + start, stop = 0, num_files + return stop - start + + +def _list_valid_filenames_in_directory(directory, white_list_formats, split, class_indices, follow_links): """List paths of files in `subdir` with extensions in `white_list_formats`. @@ -1075,6 +1155,10 @@ def _list_valid_filenames_in_directory(directory, white_list_formats, `class_indices`. white_list_formats: set of strings containing allowed extensions for the files to be counted. + split: tuple of floats (e.g. `(0.2, 0.6)`) to only take into + account a certain fraction of files in each directory. + E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent + of images in each directory. class_indices: dictionary mapping a class name to its index. follow_links: boolean. @@ -1084,27 +1168,26 @@ def _list_valid_filenames_in_directory(directory, white_list_formats, `directory`'s parent (e.g., if `directory` is "dataset/class1", the filenames will be ["class1/file1.jpg", "class1/file2.jpg", ...]). """ - - def _recursive_list(subpath): - return sorted( - os.walk(subpath, followlinks=follow_links), key=lambda tpl: tpl[0]) + dirname = os.path.basename(directory) + if split: + num_files = len( + list(_iter_valid_files(directory, white_list_formats, follow_links))) + start, stop = int(split[0] * num_files), int(split[1] * num_files) + valid_files = list( + _iter_valid_files(directory, white_list_formats, + follow_links))[start:stop] + else: + valid_files = _iter_valid_files(directory, white_list_formats, follow_links) classes = [] filenames = [] - subdir = os.path.basename(directory) - basedir = os.path.dirname(directory) - for root, _, files in _recursive_list(directory): - for fname in sorted(files): - is_valid = False - for extension in white_list_formats: - if fname.lower().endswith('.' + extension): - is_valid = True - break - if is_valid: - classes.append(class_indices[subdir]) - # add filename relative to directory - absolute_path = os.path.join(root, fname) - filenames.append(os.path.relpath(absolute_path, basedir)) + for root, fname in valid_files: + classes.append(class_indices[dirname]) + absolute_path = os.path.join(root, fname) + relative_path = os.path.join(dirname, + os.path.relpath(absolute_path, directory)) + filenames.append(relative_path) + return classes, filenames @@ -1144,6 +1227,8 @@ class DirectoryIterator(Iterator): images (if `save_to_dir` is set). save_format: Format to use for saving sample images (if `save_to_dir` is set). + subset: Subset of data (`"training"` or `"validation"`) if + validation_split is set in ImageDataGenerator. interpolation: Interpolation method used to resample the image if the target size is different from that of the loaded image. Supported methods are "nearest", "bilinear", and "bicubic". @@ -1167,6 +1252,7 @@ class DirectoryIterator(Iterator): save_prefix='', save_format='png', follow_links=False, + subset=None, interpolation='nearest'): if data_format is None: data_format = K.image_data_format() @@ -1200,7 +1286,20 @@ class DirectoryIterator(Iterator): self.save_format = save_format self.interpolation = interpolation - white_list_formats = {'png', 'jpg', 'jpeg', 'bmp', 'ppm'} + if subset is not None: + validation_split = self.image_data_generator.validation_split + if subset == 'validation': + split = (0, validation_split) + elif subset == 'training': + split = (validation_split, 1) + else: + raise ValueError('Invalid subset name: ', subset, + '; expected "training" or "validation"') + else: + split = None + self.subset = subset + + white_list_formats = {'png', 'jpg', 'jpeg', 'bmp', 'ppm', 'tif', 'tiff'} # first, count the number of samples and classes self.samples = 0 @@ -1217,7 +1316,8 @@ class DirectoryIterator(Iterator): function_partial = partial( _count_valid_files_in_directory, white_list_formats=white_list_formats, - follow_links=follow_links) + follow_links=follow_links, + split=split) self.samples = sum( pool.map(function_partial, (os.path.join(directory, subdir) for subdir in classes))) @@ -1233,14 +1333,15 @@ class DirectoryIterator(Iterator): i = 0 for dirpath in (os.path.join(directory, subdir) for subdir in classes): results.append( - pool.apply_async( - _list_valid_filenames_in_directory, - (dirpath, white_list_formats, self.class_indices, follow_links))) + pool.apply_async(_list_valid_filenames_in_directory, + (dirpath, white_list_formats, split, + self.class_indices, follow_links))) for res in results: classes, filenames = res.get() self.classes[i:i + len(classes)] = classes self.filenames += filenames i += len(classes) + pool.close() pool.join() super(DirectoryIterator, self).__init__(self.samples, batch_size, shuffle, diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py index c0790b5..001fee9 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os import shutil +import tempfile import numpy as np @@ -74,6 +75,7 @@ class TestImage(test.TestCase): shear_range=0.5, zoom_range=0.2, channel_shift_range=0., + brightness_range=(1, 5), fill_mode='nearest', cval=0.5, horizontal_flip=True, @@ -92,6 +94,47 @@ class TestImage(test.TestCase): self.assertEqual(x.shape[1:], images.shape[1:]) break + def test_image_data_generator_with_validation_split(self): + if PIL is None: + return # Skip test if PIL is not available. + + for test_images in _generate_test_images(): + img_list = [] + for im in test_images: + img_list.append(keras.preprocessing.image.img_to_array(im)[None, ...]) + + images = np.vstack(img_list) + generator = keras.preprocessing.image.ImageDataGenerator( + validation_split=0.5) + seq = generator.flow( + images, + np.arange(images.shape[0]), + shuffle=False, + batch_size=3, + subset='validation') + _, y = seq[0] + self.assertEqual(list(y), [0, 1, 2]) + seq = generator.flow( + images, + np.arange(images.shape[0]), + shuffle=False, + batch_size=3, + subset='training') + _, y2 = seq[0] + self.assertEqual(list(y2), [4, 5, 6]) + + with self.assertRaises(ValueError): + generator.flow( + images, + np.arange(images.shape[0]), + shuffle=False, + batch_size=3, + subset='foo') + + def test_image_data_generator_with_split_value_error(self): + with self.assertRaises(ValueError): + keras.preprocessing.image.ImageDataGenerator(validation_split=5) + def test_image_data_generator_invalid_data(self): generator = keras.preprocessing.image.ImageDataGenerator( featurewise_center=True, @@ -202,9 +245,80 @@ class TestImage(test.TestCase): # check number of classes and images self.assertEqual(len(dir_iterator.class_indices), num_classes) self.assertEqual(len(dir_iterator.classes), count) - self.assertEqual(sorted(dir_iterator.filenames), sorted(filenames)) + self.assertEqual(set(dir_iterator.filenames), set(filenames)) _ = dir_iterator.next() + def directory_iterator_with_validation_split_test_helper( + self, validation_split): + if PIL is None: + return # Skip test if PIL is not available. + + num_classes = 2 + tmp_folder = tempfile.mkdtemp(prefix='test_images') + + # create folders and subfolders + paths = [] + for cl in range(num_classes): + class_directory = 'class-{}'.format(cl) + classpaths = [ + class_directory, + os.path.join(class_directory, 'subfolder-1'), + os.path.join(class_directory, 'subfolder-2'), + os.path.join(class_directory, 'subfolder-1', 'sub-subfolder') + ] + for path in classpaths: + os.mkdir(os.path.join(tmp_folder, path)) + paths.append(classpaths) + + # save the images in the paths + count = 0 + filenames = [] + for test_images in _generate_test_images(): + for im in test_images: + # rotate image class + im_class = count % num_classes + # rotate subfolders + classpaths = paths[im_class] + filename = os.path.join(classpaths[count % len(classpaths)], + 'image-{}.jpg'.format(count)) + filenames.append(filename) + im.save(os.path.join(tmp_folder, filename)) + count += 1 + + # create iterator + generator = keras.preprocessing.image.ImageDataGenerator( + validation_split=validation_split) + + with self.assertRaises(ValueError): + generator.flow_from_directory(tmp_folder, subset='foo') + + num_validation = int(count * validation_split) + num_training = count - num_validation + train_iterator = generator.flow_from_directory( + tmp_folder, subset='training') + self.assertEqual(train_iterator.samples, num_training) + + valid_iterator = generator.flow_from_directory( + tmp_folder, subset='validation') + self.assertEqual(valid_iterator.samples, num_validation) + + # check number of classes and images + self.assertEqual(len(train_iterator.class_indices), num_classes) + self.assertEqual(len(train_iterator.classes), num_training) + self.assertEqual( + len(set(train_iterator.filenames) & set(filenames)), num_training) + + shutil.rmtree(tmp_folder) + + def test_directory_iterator_with_validation_split_25_percent(self): + self.directory_iterator_with_validation_split_test_helper(0.25) + + def test_directory_iterator_with_validation_split_40_percent(self): + self.directory_iterator_with_validation_split_test_helper(0.40) + + def test_directory_iterator_with_validation_split_50_percent(self): + self.directory_iterator_with_validation_split_test_helper(0.50) + def test_img_utils(self): if PIL is None: return # Skip test if PIL is not available. @@ -241,6 +355,41 @@ class TestImage(test.TestCase): x = keras.preprocessing.image.img_to_array(img, data_format='channels_last') self.assertEqual(x.shape, (height, width, 1)) + def test_batch_standardize(self): + if PIL is None: + return # Skip test if PIL is not available. + + # ImageDataGenerator.standardize should work on batches + for test_images in _generate_test_images(): + img_list = [] + for im in test_images: + img_list.append(keras.preprocessing.image.img_to_array(im)[None, ...]) + + images = np.vstack(img_list) + generator = keras.preprocessing.image.ImageDataGenerator( + featurewise_center=True, + samplewise_center=True, + featurewise_std_normalization=True, + samplewise_std_normalization=True, + zca_whitening=True, + rotation_range=90., + width_shift_range=0.1, + height_shift_range=0.1, + shear_range=0.5, + zoom_range=0.2, + channel_shift_range=0., + brightness_range=(1, 5), + fill_mode='nearest', + cval=0.5, + horizontal_flip=True, + vertical_flip=True) + generator.fit(images, augment=True) + + transformed = np.copy(images) + for i, im in enumerate(transformed): + transformed[i] = generator.random_transform(im) + transformed = generator.standardize(transformed) + def test_img_transforms(self): x = np.random.random((3, 200, 200)) _ = keras.preprocessing.image.random_rotation(x, 20) diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py index a423d96..e68c171 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence.py @@ -22,6 +22,8 @@ import random import numpy as np from six.moves import range # pylint: disable=redefined-builtin + +from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence from tensorflow.python.util.tf_export import tf_export @@ -32,29 +34,40 @@ def pad_sequences(sequences, padding='pre', truncating='pre', value=0.): - """Pads each sequence to the same length (length of the longest sequence). + """Pads sequences to the same length. + + This function transforms a list of + `num_samples` sequences (lists of integers) + into a 2D Numpy array of shape `(num_samples, num_timesteps)`. + `num_timesteps` is either the `maxlen` argument if provided, + or the length of the longest sequence otherwise. + + Sequences that are shorter than `num_timesteps` + are padded with `value` at the end. - If maxlen is provided, any sequence longer - than maxlen is truncated to maxlen. - Truncation happens off either the beginning (default) or - the end of the sequence. + Sequences longer than `num_timesteps` are truncated + so that they fit the desired length. + The position where padding or truncation happens is determined by + the arguments `padding` and `truncating`, respectively. - Supports post-padding and pre-padding (default). + Pre-padding is the default. Arguments: - sequences: list of lists where each element is a sequence - maxlen: int, maximum length - dtype: type to cast the resulting sequence. - padding: 'pre' or 'post', pad either before or after each sequence. - truncating: 'pre' or 'post', remove values from sequences larger than - maxlen either in the beginning or in the end of the sequence - value: float, value to pad the sequences to the desired value. + sequences: List of lists, where each element is a sequence. + maxlen: Int, maximum length of all sequences. + dtype: Type of the output sequences. + padding: String, 'pre' or 'post': + pad either before or after each sequence. + truncating: String, 'pre' or 'post': + remove values from sequences larger than + `maxlen`, either at the beginning or at the end of the sequences. + value: Float, padding value. Returns: - x: numpy array with dimensions (number_of_sequences, maxlen) + x: Numpy array with shape `(len(sequences), maxlen)` Raises: - ValueError: in case of invalid values for `truncating` or `padding`, + ValueError: In case of invalid values for `truncating` or `padding`, or in case of invalid shape for a `sequences` entry. """ if not hasattr(sequences, '__len__'): @@ -92,10 +105,9 @@ def pad_sequences(sequences, # check `trunc` has expected shape trunc = np.asarray(trunc, dtype=dtype) if trunc.shape[1:] != sample_shape: - raise ValueError( - 'Shape of sample %s of sequence at position %s is different from ' - 'expected shape %s' - % (trunc.shape[1:], idx, sample_shape)) + raise ValueError('Shape of sample %s of sequence at position %s ' + 'is different from expected shape %s' % + (trunc.shape[1:], idx, sample_shape)) if padding == 'post': x[idx, :len(trunc)] = trunc @@ -110,22 +122,26 @@ def pad_sequences(sequences, def make_sampling_table(size, sampling_factor=1e-5): """Generates a word rank-based probabilistic sampling table. - This generates an array where the ith element - is the probability that a word of rank i would be sampled, - according to the sampling distribution used in word2vec. + Used for generating the `sampling_table` argument for `skipgrams`. + `sampling_table[i]` is the probability of sampling + the word i-th most common word in a dataset + (more common words should be sampled less frequently, for balance). - The word2vec formula is: - p(word) = min(1, sqrt(word.frequency/sampling_factor) / - (word.frequency/sampling_factor)) + The sampling probabilities are generated according + to the sampling distribution used in word2vec: + + `p(word) = min(1, sqrt(word_frequency / sampling_factor) / (word_frequency / + sampling_factor))` We assume that the word frequencies follow Zipf's law (s=1) to derive a numerical approximation of frequency(rank): - frequency(rank) ~ 1/(rank * (log(rank) + gamma) + 1/2 - 1/(12*rank)) - where gamma is the Euler-Mascheroni constant. + + `frequency(rank) ~ 1/(rank * (log(rank) + gamma) + 1/2 - 1/(12*rank))` + where `gamma` is the Euler-Mascheroni constant. Arguments: - size: int, number of possible words to sample. - sampling_factor: the sampling factor in the word2vec formula. + size: Int, number of possible words to sample. + sampling_factor: The sampling factor in the word2vec formula. Returns: A 1D Numpy array of length `size` where the ith entry @@ -151,30 +167,37 @@ def skipgrams(sequence, seed=None): """Generates skipgram word pairs. - Takes a sequence (list of indexes of words), - returns couples of [word_index, other_word index] and labels (1s or 0s), - where label = 1 if 'other_word' belongs to the context of 'word', - and label=0 if 'other_word' is randomly sampled + This function transforms a sequence of word indexes (list of integers) + into tuples of words of the form: + + - (word, word in the same window), with label 1 (positive samples). + - (word, random word from the vocabulary), with label 0 (negative samples). + + Read more about Skipgram in this gnomic paper by Mikolov et al.: + [Efficient Estimation of Word Representations in + Vector Space](http://arxiv.org/pdf/1301.3781v3.pdf) Arguments: - sequence: a word sequence (sentence), encoded as a list + sequence: A word sequence (sentence), encoded as a list of word indices (integers). If using a `sampling_table`, word indices are expected to match the rank of the words in a reference dataset (e.g. 10 would encode the 10-th most frequently occurring token). Note that index 0 is expected to be a non-word and will be skipped. - vocabulary_size: int. maximum possible word index + 1 - window_size: int. actually half-window. - The window of a word wi will be [i-window_size, i+window_size+1] - negative_samples: float >= 0. 0 for no negative (=random) samples. - 1 for same number as positive samples. etc. - shuffle: whether to shuffle the word couples before returning them. + vocabulary_size: Int, maximum possible word index + 1 + window_size: Int, size of sampling windows (technically half-window). + The window of a word `w_i` will be + `[i - window_size, i + window_size+1]`. + negative_samples: Float >= 0. 0 for no negative (i.e. random) samples. + 1 for same number as positive samples. + shuffle: Whether to shuffle the word couples before returning them. categorical: bool. if False, labels will be - integers (eg. [0, 1, 1 .. ]), - if True labels will be categorical eg. [[1,0],[0,1],[0,1] .. ] + integers (eg. `[0, 1, 1 .. ]`), + if `True`, labels will be categorical, e.g. + `[[1,0],[0,1],[0,1] .. ]`. sampling_table: 1D array of size `vocabulary_size` where the entry i encodes the probability to sample a word of rank i. - seed: random seed. + seed: Random seed. Returns: couples, labels: where `couples` are int pairs and @@ -234,9 +257,9 @@ def _remove_long_seq(maxlen, seq, label): """Removes sequences that exceed the maximum length. Arguments: - maxlen: int, maximum length - seq: list of lists where each sublist is a sequence - label: list where each element is an integer + maxlen: Int, maximum length of the output sequences. + seq: List of lists, where each sublist is a sequence. + label: List where each element is an integer. Returns: new_seq, new_label: shortened lists for `seq` and `label`. @@ -247,3 +270,120 @@ def _remove_long_seq(maxlen, seq, label): new_seq.append(x) new_label.append(y) return new_seq, new_label + + +@tf_export('keras.preprocessing.sequence.TimeseriesGenerator') +class TimeseriesGenerator(Sequence): + """Utility class for generating batches of temporal data. + + This class takes in a sequence of data-points gathered at + equal intervals, along with time series parameters such as + stride, length of history, etc., to produce batches for + training/validation. + + Arguments: + data: Indexable generator (such as list or Numpy array) + containing consecutive data points (timesteps). + The data should be at 2D, and axis 0 is expected + to be the time dimension. + targets: Targets corresponding to timesteps in `data`. + It should have same length as `data`. + length: Length of the output sequences (in number of timesteps). + sampling_rate: Period between successive individual timesteps + within sequences. For rate `r`, timesteps + `data[i]`, `data[i-r]`, ... `data[i - length]` + are used for create a sample sequence. + stride: Period between successive output sequences. + For stride `s`, consecutive output samples would + be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc. + start_index, end_index: Data points earlier than `start_index` + or later than `end_index` will not be used in the output sequences. + This is useful to reserve part of the data for test or validation. + shuffle: Whether to shuffle output samples, + or instead draw them in chronological order. + reverse: Boolean: if `true`, timesteps in each output sample will be + in reverse chronological order. + batch_size: Number of timeseries samples in each batch + (except maybe the last one). + + Returns: + A [Sequence](/utils/#sequence) instance. + + Examples: + + ```python + from keras.preprocessing.sequence import TimeseriesGenerator + import numpy as np + + data = np.array([[i] for i in range(50)]) + targets = np.array([[i] for i in range(50)]) + + data_gen = TimeseriesGenerator(data, targets, + length=10, sampling_rate=2, + batch_size=2) + assert len(data_gen) == 20 + + batch_0 = data_gen[0] + x, y = batch_0 + assert np.array_equal(x, + np.array([[[0], [2], [4], [6], [8]], + [[1], [3], [5], [7], [9]]])) + assert np.array_equal(y, + np.array([[10], [11]])) + ``` + """ + + def __init__(self, + data, + targets, + length, + sampling_rate=1, + stride=1, + start_index=0, + end_index=None, + shuffle=False, + reverse=False, + batch_size=128): + self.data = data + self.targets = targets + self.length = length + self.sampling_rate = sampling_rate + self.stride = stride + self.start_index = start_index + length + if end_index is None: + end_index = len(data) - 1 + self.end_index = end_index + self.shuffle = shuffle + self.reverse = reverse + self.batch_size = batch_size + + def __len__(self): + length = int( + np.ceil((self.end_index - self.start_index) / + (self.batch_size * self.stride))) + return length if length >= 0 else 0 + + def _empty_batch(self, num_rows): + samples_shape = [num_rows, self.length // self.sampling_rate] + samples_shape.extend(self.data.shape[1:]) + targets_shape = [num_rows] + targets_shape.extend(self.targets.shape[1:]) + return np.empty(samples_shape), np.empty(targets_shape) + + def __getitem__(self, index): + if self.shuffle: + rows = np.random.randint( + self.start_index, self.end_index, size=self.batch_size) + else: + i = self.start_index + self.batch_size * self.stride * index + rows = np.arange(i, min(i + self.batch_size * self.stride, + self.end_index), self.stride) + + samples, targets = self._empty_batch(len(rows)) + for j in range(len(rows)): + indices = range(rows[j] - self.length, rows[j], self.sampling_rate) + samples[j] = self.data[indices] + targets[j] = self.targets[rows[j]] + if self.reverse: + return samples[:, ::-1, ...], targets + return samples, targets diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py index 4529e6e..b9bfdd0 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/sequence_test.py @@ -84,15 +84,91 @@ class TestSequence(test.TestCase): couples, labels = keras.preprocessing.sequence.skipgrams( np.arange(3), vocabulary_size=3) for couple in couples: - assert couple[0] in [0, 1, 2] and couple[1] in [0, 1, 2] + self.assertIn(couple[0], [0, 1, 2]) + self.assertIn(couple[1], [0, 1, 2]) # test window size and categorical labels couples, labels = keras.preprocessing.sequence.skipgrams( np.arange(5), vocabulary_size=5, window_size=1, categorical=True) for couple in couples: - assert couple[0] - couple[1] <= 3 + self.assertLessEqual(couple[0] - couple[1], 3) for l in labels: - assert len(l) == 2 + self.assertEqual(len(l), 2) + + def test_TimeseriesGenerator(self): + data = np.array([[i] for i in range(50)]) + targets = np.array([[i] for i in range(50)]) + + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, targets, length=10, sampling_rate=2, batch_size=2) + self.assertEqual(len(data_gen), 20) + self.assertAllClose(data_gen[0][0], + np.array([[[0], [2], [4], [6], [8]], [[1], [3], [5], + [7], [9]]])) + self.assertAllClose(data_gen[0][1], np.array([[10], [11]])) + self.assertAllClose(data_gen[1][0], + np.array([[[2], [4], [6], [8], [10]], [[3], [5], [7], + [9], [11]]])) + self.assertAllClose(data_gen[1][1], np.array([[12], [13]])) + + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, targets, length=10, sampling_rate=2, reverse=True, batch_size=2) + self.assertEqual(len(data_gen), 20) + self.assertAllClose(data_gen[0][0], + np.array([[[8], [6], [4], [2], [0]], [[9], [7], [5], + [3], [1]]])) + self.assertAllClose(data_gen[0][1], np.array([[10], [11]])) + + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, targets, length=10, sampling_rate=2, shuffle=True, batch_size=1) + batch = data_gen[0] + r = batch[1][0][0] + self.assertAllClose(batch[0], + np.array([[[r - 10], [r - 8], [r - 6], [r - 4], + [r - 2]]])) + self.assertAllClose(batch[1], np.array([ + [r], + ])) + + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, targets, length=10, sampling_rate=2, stride=2, batch_size=2) + self.assertEqual(len(data_gen), 10) + self.assertAllClose(data_gen[1][0], + np.array([[[4], [6], [8], [10], [12]], [[6], [8], [10], + [12], [14]]])) + self.assertAllClose(data_gen[1][1], np.array([[14], [16]])) + + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, + targets, + length=10, + sampling_rate=2, + start_index=10, + end_index=30, + batch_size=2) + self.assertEqual(len(data_gen), 5) + self.assertAllClose(data_gen[0][0], + np.array([[[10], [12], [14], [16], [18]], + [[11], [13], [15], [17], [19]]])) + self.assertAllClose(data_gen[0][1], np.array([[20], [21]])) + + data = np.array([np.random.random_sample((1, 2, 3, 4)) for i in range(50)]) + targets = np.array([np.random.random_sample((3, 2, 1)) for i in range(50)]) + data_gen = keras.preprocessing.sequence.TimeseriesGenerator( + data, + targets, + length=10, + sampling_rate=2, + start_index=10, + end_index=30, + batch_size=2) + + self.assertEqual(len(data_gen), 5) + self.assertAllClose(data_gen[0][0], + np.array( + [np.array(data[10:19:2]), + np.array(data[11:20:2])])) + self.assertAllClose(data_gen[0][1], np.array([targets[20], targets[21]])) if __name__ == '__main__': diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text.py b/tensorflow/python/keras/_impl/keras/preprocessing/text.py index 1e3828c..f652f31 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/text.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/text.py @@ -91,6 +91,7 @@ def one_hot(text, text, n, hash_function=hash, filters=filters, lower=lower, split=split) +@tf_export('keras.preprocessing.text.hashing_trick') def hashing_trick(text, n, hash_function=None, @@ -187,21 +188,27 @@ class Tokenizer(object): self.document_count = 0 self.char_level = char_level self.oov_token = oov_token + self.index_docs = {} def fit_on_texts(self, texts): """Updates internal vocabulary based on a list of texts. + In the case where texts contains lists, we assume each entry of the lists + to be a token. + Required before using `texts_to_sequences` or `texts_to_matrix`. Arguments: texts: can be a list of strings, - or a generator of strings (for memory-efficiency) + a generator of strings (for memory-efficiency), + or a list of list of strings. """ - self.document_count = 0 for text in texts: self.document_count += 1 - seq = text if self.char_level else text_to_word_sequence( - text, self.filters, self.lower, self.split) + if self.char_level or isinstance(text, list): + seq = text + else: + seq = text_to_word_sequence(text, self.filters, self.lower, self.split) for w in seq: if w in self.word_counts: self.word_counts[w] += 1 @@ -226,7 +233,6 @@ class Tokenizer(object): if i is None: self.word_index[self.oov_token] = len(self.word_index) + 1 - self.index_docs = {} for w, c in list(self.word_docs.items()): self.index_docs[self.word_index[w]] = c @@ -240,8 +246,7 @@ class Tokenizer(object): sequences: A list of sequence. A "sequence" is a list of integer word indices. """ - self.document_count = len(sequences) - self.index_docs = {} + self.document_count += len(sequences) for seq in sequences: seq = set(seq) for i in seq: @@ -268,7 +273,11 @@ class Tokenizer(object): return res def texts_to_sequences_generator(self, texts): - """Transforms each text in texts in a sequence of integers. + """Transforms each text in `texts` in a sequence of integers. + + Each item in texts can also be a list, in which case we assume each item of + that list + to be a token. Only top "num_words" most frequent words will be taken into account. Only words known by the tokenizer will be taken into account. @@ -281,8 +290,10 @@ class Tokenizer(object): """ num_words = self.num_words for text in texts: - seq = text if self.char_level else text_to_word_sequence( - text, self.filters, self.lower, self.split) + if self.char_level or isinstance(text, list): + seq = text + else: + seq = text_to_word_sequence(text, self.filters, self.lower, self.split) vect = [] for w in seq: i = self.word_index.get(w) diff --git a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py index a934e33..c6a267e 100644 --- a/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py +++ b/tensorflow/python/keras/_impl/keras/preprocessing/text_test.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -80,17 +81,52 @@ class TestText(test.TestCase): x_train = ['This text has only known words'] x_test = ['This text has some unknown words'] # 2 OOVs: some, unknown - # Defalut, without OOV flag + # Default, without OOV flag tokenizer = keras.preprocessing.text.Tokenizer() tokenizer.fit_on_texts(x_train) x_test_seq = tokenizer.texts_to_sequences(x_test) - assert len(x_test_seq[0]) == 4 # discards 2 OOVs + self.assertEqual(len(x_test_seq[0]), 4) # discards 2 OOVs # With OOV feature tokenizer = keras.preprocessing.text.Tokenizer(oov_token='') tokenizer.fit_on_texts(x_train) x_test_seq = tokenizer.texts_to_sequences(x_test) - assert len(x_test_seq[0]) == 6 # OOVs marked in place + self.assertEqual(len(x_test_seq[0]), 6) # OOVs marked in place + + def test_sequential_fit(self): + texts = [ + 'The cat sat on the mat.', 'The dog sat on the log.', + 'Dogs and cats living together.' + ] + word_sequences = [['The', 'cat', 'is', 'sitting'], + ['The', 'dog', 'is', 'standing']] + tokenizer = keras.preprocessing.text.Tokenizer() + tokenizer.fit_on_texts(texts) + tokenizer.fit_on_texts(word_sequences) + + self.assertEqual(tokenizer.document_count, 5) + + tokenizer.texts_to_matrix(texts) + tokenizer.texts_to_matrix(word_sequences) + + def test_text_to_word_sequence(self): + text = 'hello! ? world!' + seq = keras.preprocessing.text.text_to_word_sequence(text) + self.assertEqual(seq, ['hello', 'world']) + + def test_text_to_word_sequence_unicode(self): + text = u'ali! veli? kırk dokuz elli' + seq = keras.preprocessing.text.text_to_word_sequence(text) + self.assertEqual(seq, [u'ali', u'veli', u'kırk', u'dokuz', u'elli']) + + def test_tokenizer_unicode(self): + texts = [ + u'ali veli kırk dokuz elli', u'ali veli kırk dokuz elli veli kırk dokuz' + ] + tokenizer = keras.preprocessing.text.Tokenizer(num_words=5) + tokenizer.fit_on_texts(texts) + + self.assertEqual(len(tokenizer.word_counts), 5) if __name__ == '__main__': diff --git a/tensorflow/python/keras/preprocessing/image/__init__.py b/tensorflow/python/keras/preprocessing/image/__init__.py index b96e767..6aba5fc 100644 --- a/tensorflow/python/keras/preprocessing/image/__init__.py +++ b/tensorflow/python/keras/preprocessing/image/__init__.py @@ -27,6 +27,7 @@ from tensorflow.python.keras._impl.keras.preprocessing.image import img_to_array from tensorflow.python.keras._impl.keras.preprocessing.image import Iterator from tensorflow.python.keras._impl.keras.preprocessing.image import load_img from tensorflow.python.keras._impl.keras.preprocessing.image import NumpyArrayIterator +from tensorflow.python.keras._impl.keras.preprocessing.image import random_brightness from tensorflow.python.keras._impl.keras.preprocessing.image import random_channel_shift from tensorflow.python.keras._impl.keras.preprocessing.image import random_rotation from tensorflow.python.keras._impl.keras.preprocessing.image import random_shear diff --git a/tensorflow/python/keras/preprocessing/sequence/__init__.py b/tensorflow/python/keras/preprocessing/sequence/__init__.py index 112f6af..b7a7149 100644 --- a/tensorflow/python/keras/preprocessing/sequence/__init__.py +++ b/tensorflow/python/keras/preprocessing/sequence/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function from tensorflow.python.keras._impl.keras.preprocessing.sequence import make_sampling_table from tensorflow.python.keras._impl.keras.preprocessing.sequence import pad_sequences from tensorflow.python.keras._impl.keras.preprocessing.sequence import skipgrams +from tensorflow.python.keras._impl.keras.preprocessing.sequence import TimeseriesGenerator del absolute_import del division diff --git a/tensorflow/python/keras/preprocessing/text/__init__.py b/tensorflow/python/keras/preprocessing/text/__init__.py index 5bf1a2f..000ad68 100644 --- a/tensorflow/python/keras/preprocessing/text/__init__.py +++ b/tensorflow/python/keras/preprocessing/text/__init__.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.keras._impl.keras.preprocessing.text import hashing_trick from tensorflow.python.keras._impl.keras.preprocessing.text import one_hot from tensorflow.python.keras._impl.keras.preprocessing.text import text_to_word_sequence from tensorflow.python.keras._impl.keras.preprocessing.text import Tokenizer diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt index 04174bf..ec0f3d8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-directory-iterator.pbtxt @@ -6,7 +6,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'directory\', \'image_data_generator\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'None\', \'\', \'png\', \'False\', \'nearest\'], " + argspec: "args=[\'self\', \'directory\', \'image_data_generator\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\', \'subset\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'None\', \'\', \'png\', \'False\', \'None\', \'nearest\'], " } member_method { name: "next" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt index 41f27d1..f5bc04e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-image-data-generator.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'featurewise_center\', \'samplewise_center\', \'featurewise_std_normalization\', \'samplewise_std_normalization\', \'zca_whitening\', \'zca_epsilon\', \'rotation_range\', \'width_shift_range\', \'height_shift_range\', \'shear_range\', \'zoom_range\', \'channel_shift_range\', \'fill_mode\', \'cval\', \'horizontal_flip\', \'vertical_flip\', \'rescale\', \'preprocessing_function\', \'data_format\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'1e-06\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'nearest\', \'0.0\', \'False\', \'False\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'featurewise_center\', \'samplewise_center\', \'featurewise_std_normalization\', \'samplewise_std_normalization\', \'zca_whitening\', \'zca_epsilon\', \'rotation_range\', \'width_shift_range\', \'height_shift_range\', \'brightness_range\', \'shear_range\', \'zoom_range\', \'channel_shift_range\', \'fill_mode\', \'cval\', \'horizontal_flip\', \'vertical_flip\', \'rescale\', \'preprocessing_function\', \'data_format\', \'validation_split\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'1e-06\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'0.0\', \'0.0\', \'0.0\', \'nearest\', \'0.0\', \'False\', \'False\', \'None\', \'None\', \'None\', \'0.0\'], " } member_method { name: "fit" @@ -12,11 +12,11 @@ tf_class { } member_method { name: "flow" - argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\'], varargs=None, keywords=None, defaults=[\'None\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'subset\'], varargs=None, keywords=None, defaults=[\'None\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\', \'None\'], " } member_method { name: "flow_from_directory" - argspec: "args=[\'self\', \'directory\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\', \'False\', \'nearest\'], " + argspec: "args=[\'self\', \'directory\', \'target_size\', \'color_mode\', \'classes\', \'class_mode\', \'batch_size\', \'shuffle\', \'seed\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'follow_links\', \'subset\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'(256, 256)\', \'rgb\', \'None\', \'categorical\', \'32\', \'True\', \'None\', \'None\', \'\', \'png\', \'False\', \'None\', \'nearest\'], " } member_method { name: "random_transform" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt index 4ef6e6e..42196dd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.-numpy-array-iterator.pbtxt @@ -6,7 +6,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'x\', \'y\', \'image_data_generator\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'None\', \'None\', \'\', \'png\'], " + argspec: "args=[\'self\', \'x\', \'y\', \'image_data_generator\', \'batch_size\', \'shuffle\', \'seed\', \'data_format\', \'save_to_dir\', \'save_prefix\', \'save_format\', \'subset\'], varargs=None, keywords=None, defaults=[\'32\', \'False\', \'None\', \'None\', \'None\', \'\', \'png\', \'None\'], " } member_method { name: "next" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt index d28fef6..6b850dd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.image.pbtxt @@ -37,6 +37,10 @@ tf_module { argspec: "args=[\'path\', \'grayscale\', \'target_size\', \'interpolation\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'nearest\'], " } member_method { + name: "random_brightness" + argspec: "args=[\'x\', \'brightness_range\'], varargs=None, keywords=None, defaults=None" + } + member_method { name: "random_channel_shift" argspec: "args=[\'x\', \'intensity\', \'channel_axis\'], varargs=None, keywords=None, defaults=[\'0\'], " } diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt new file mode 100644 index 0000000..d9c3215 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.-timeseries-generator.pbtxt @@ -0,0 +1,14 @@ +path: "tensorflow.keras.preprocessing.sequence.TimeseriesGenerator" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'data\', \'targets\', \'length\', \'sampling_rate\', \'stride\', \'start_index\', \'end_index\', \'shuffle\', \'reverse\', \'batch_size\'], varargs=None, keywords=None, defaults=[\'1\', \'1\', \'0\', \'None\', \'False\', \'False\', \'128\'], " + } + member_method { + name: "on_epoch_end" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt index 1b01935..cf59f8a 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.sequence.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.keras.preprocessing.sequence" tf_module { + member { + name: "TimeseriesGenerator" + mtype: "" + } member_method { name: "make_sampling_table" argspec: "args=[\'size\', \'sampling_factor\'], varargs=None, keywords=None, defaults=[\'1e-05\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt index d106429..50b54fc 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.preprocessing.text.pbtxt @@ -5,6 +5,10 @@ tf_module { mtype: "" } member_method { + name: "hashing_trick" + argspec: "args=[\'text\', \'n\', \'hash_function\', \'filters\', \'lower\', \'split\'], varargs=None, keywords=None, defaults=[\'None\', \'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \'], " + } + member_method { name: "one_hot" argspec: "args=[\'text\', \'n\', \'filters\', \'lower\', \'split\'], varargs=None, keywords=None, defaults=[\'!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n\', \'True\', \' \'], " } -- 2.7.4