--- /dev/null
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Smoke test for reading records from GCS to TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+import time
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.core.example import example_pb2
+from tensorflow.python.lib.io import file_io
+
+flags = tf.app.flags
+flags.DEFINE_string("gcs_bucket_url", "",
+                    "The GCS path used by this test: the temporary tfrecord "
+                    "file is written to and read back from this exact path, "
+                    "and temporary test directories are created under it, "
+                    "e.g., gs://my-gcs-bucket/test-directory")
+flags.DEFINE_integer("num_examples", 10, "Number of examples to generate")
+
+FLAGS = flags.FLAGS
+
+
+def create_examples(num_examples, input_mean):
+  """Creates Example protos containing random regression data."""
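+  # Each Example carries three features: an integer "id" (stored as bytes),
+  # a noisy "inputs" value drawn from a normal distribution centered at
+  # input_mean, and the zero-mean "target" (inputs - input_mean).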
+ ids = np.arange(num_examples).reshape([num_examples, 1])
+ inputs = np.random.randn(num_examples, 1) + input_mean
+ target = inputs - input_mean
+ examples = []
+ for row in range(num_examples):
+ ex = example_pb2.Example()
+    # bytes_list requires bytes; encode the id string so this also works
+    # under Python 3.
+    ex.features.feature["id"].bytes_list.value.append(
+        str(ids[row, 0]).encode("utf-8"))
+ ex.features.feature["target"].float_list.value.append(target[row, 0])
+ ex.features.feature["inputs"].float_list.value.append(inputs[row, 0])
+ examples.append(ex)
+ return examples
+
+
+def create_dir_test():
+ """Verifies file_io directory handling methods."""
+
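+  # Note: GCS is a flat object store with no native directories; file_io
+  # emulates them, so these calls exercise the GCS file system plugin's
+  # directory handling rather than a real hierarchy.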
+ # Test directory creation.
+ starttime_ms = int(round(time.time() * 1000))
+ dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
+ print("Creating dir %s" % dir_name)
+ file_io.create_dir(dir_name)
+ elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
+ print("Created directory in: %d milliseconds" % elapsed_ms)
+
+ # Check that the directory exists.
+ dir_exists = file_io.is_directory(dir_name)
+ assert dir_exists
+ print("%s directory exists: %s" % (dir_name, dir_exists))
+
+ # Test recursive directory creation.
+ starttime_ms = int(round(time.time() * 1000))
+ recursive_dir_name = "%s/%s/%s" % (dir_name,
+ "nested_dir1",
+ "nested_dir2")
+ print("Creating recursive dir %s" % recursive_dir_name)
+ file_io.recursive_create_dir(recursive_dir_name)
+ elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
+ print("Created directory recursively in: %d milliseconds" % elapsed_ms)
+
+ # Check that the directory exists.
+ recursive_dir_exists = file_io.is_directory(recursive_dir_name)
+ assert recursive_dir_exists
+ print("%s directory exists: %s" % (recursive_dir_name, recursive_dir_exists))
+
+ # Create some contents in the just created directory and list the contents.
+ num_files = 10
+  files_to_create = ["file_%d.txt" % n for n in range(num_files)]
+  for file_basename in files_to_create:
+    file_name = "%s/%s" % (dir_name, file_basename)
+    print("Creating file %s." % file_name)
+    file_io.write_string_to_file(file_name, "test file.")
+
+ print("Listing directory %s." % dir_name)
+ starttime_ms = int(round(time.time() * 1000))
+ directory_contents = file_io.list_directory(dir_name)
+ print(directory_contents)
+ elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
+ print("Listed directory %s in %s milliseconds" % (dir_name, elapsed_ms))
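+  # list_directory returns only the immediate children of dir_name; the
+  # nested directory created above appears with a trailing slash.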
+ assert set(directory_contents) == set(files_to_create + ["nested_dir1/"])
+
+ # Test directory renaming.
+ dir_to_rename = "%s/old_dir" % dir_name
+ new_dir_name = "%s/new_dir" % dir_name
+ file_io.create_dir(dir_to_rename)
+ assert file_io.is_directory(dir_to_rename)
+ assert not file_io.is_directory(new_dir_name)
+
+ starttime_ms = int(round(time.time() * 1000))
+ print("Will try renaming directory %s to %s" % (dir_to_rename, new_dir_name))
+ file_io.rename(dir_to_rename, new_dir_name)
+ elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
+ print("Renamed directory %s to %s in %s milliseconds" % (
+ dir_to_rename, new_dir_name, elapsed_ms))
+ assert not file_io.is_directory(dir_to_rename)
+ assert file_io.is_directory(new_dir_name)
+
+  # Test recursive directory deletion.
+ print("Deleting directory recursively %s." % dir_name)
+ starttime_ms = int(round(time.time() * 1000))
+ file_io.delete_recursively(dir_name)
+ elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
+ dir_exists = file_io.is_directory(dir_name)
+ assert not dir_exists
+ print("Deleted directory recursively %s in %s milliseconds" % (
+ dir_name, elapsed_ms))
+
+
+def create_object_test():
+  """Verifies file_io's object manipulation methods."""
+ starttime_ms = int(round(time.time() * 1000))
+ dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms)
+ print("Creating dir %s." % dir_name)
+ file_io.create_dir(dir_name)
+
+ num_files = 5
+  # Create files matching two different naming patterns in this directory.
+ files_pattern_1 = ["%s/test_file_%d.txt" % (dir_name, n)
+ for n in range(num_files)]
+ files_pattern_2 = ["%s/testfile%d.txt" % (dir_name, n)
+ for n in range(num_files)]
+
+ starttime_ms = int(round(time.time() * 1000))
+ files_to_create = files_pattern_1 + files_pattern_2
+ for file_name in files_to_create:
+ print("Creating file %s." % file_name)
+ file_io.write_string_to_file(file_name, "test file creation.")
+ elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
+ print("Created %d files in %s milliseconds" %
+ (len(files_to_create), elapsed_ms))
+
+  # List the files matching pattern 1.
+ list_files_pattern = "%s/test_file*.txt" % dir_name
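+  # get_matching_files expands glob patterns, so "test_file*.txt" should
+  # match only the underscore-named files from the first pattern.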
+ print("Getting files matching pattern %s." % list_files_pattern)
+ starttime_ms = int(round(time.time() * 1000))
+ files_list = file_io.get_matching_files(list_files_pattern)
+ elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
+ print("Listed files in %s milliseconds" % elapsed_ms)
+ print(files_list)
+ assert set(files_list) == set(files_pattern_1)
+
+  # List the files matching pattern 2.
+ list_files_pattern = "%s/testfile*.txt" % dir_name
+ print("Getting files matching pattern %s." % list_files_pattern)
+ starttime_ms = int(round(time.time() * 1000))
+ files_list = file_io.get_matching_files(list_files_pattern)
+ elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
+ print("Listed files in %s milliseconds" % elapsed_ms)
+ print(files_list)
+ assert set(files_list) == set(files_pattern_2)
+
+ # Test renaming file.
+ file_to_rename = "%s/oldname.txt" % dir_name
+ file_new_name = "%s/newname.txt" % dir_name
+ file_io.write_string_to_file(file_to_rename, "test file.")
+ assert file_io.file_exists(file_to_rename)
+ assert not file_io.file_exists(file_new_name)
+
+ print("Will try renaming file %s to %s" % (file_to_rename, file_new_name))
+ starttime_ms = int(round(time.time() * 1000))
+ file_io.rename(file_to_rename, file_new_name)
+ elapsed_ms = int(round(time.time() * 1000)) - starttime_ms
+ print("File %s renamed to %s in %s milliseconds" % (
+ file_to_rename, file_new_name, elapsed_ms))
+ assert not file_io.file_exists(file_to_rename)
+ assert file_io.file_exists(file_new_name)
+
+ # Delete directory.
+ print("Deleting directory %s." % dir_name)
+ file_io.delete_recursively(dir_name)
+
+
+def main(argv):
+ del argv # Unused.
+ # Sanity check on the GCS bucket URL.
+ if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"):
+ print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url)
+ sys.exit(1)
+
+ # Verify that writing to the records file in GCS works.
+ print("\n=== Testing writing and reading of GCS record file... ===")
+ example_data = create_examples(FLAGS.num_examples, 5)
+  with tf.python_io.TFRecordWriter(FLAGS.gcs_bucket_url) as writer:
+    for example in example_data:
+      writer.write(example.SerializeToString())
+
+ print("Data written to: %s" % FLAGS.gcs_bucket_url)
+
+ # Verify that reading from the tfrecord file works and that
+ # tf_record_iterator works.
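+  # tf_record_iterator streams the records back through the same GCS file
+  # system layer, without needing a graph or a session.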
+ record_iter = tf.python_io.tf_record_iterator(FLAGS.gcs_bucket_url)
+ read_count = 0
+ for _ in record_iter:
+ read_count += 1
+ print("Read %d records using tf_record_iterator" % read_count)
+
+ if read_count != FLAGS.num_examples:
+ print("FAIL: The number of records read from tf_record_iterator (%d) "
+ "differs from the expected number (%d)" % (read_count,
+ FLAGS.num_examples))
+ sys.exit(1)
+
+ # Verify that running the read op in a session works.
+ print("\n=== Testing TFRecordReader.read op in a session... ===")
+  with tf.Graph().as_default():
+ filename_queue = tf.train.string_input_producer([FLAGS.gcs_bucket_url],
+ num_epochs=1)
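+    # string_input_producer with num_epochs set creates a local epoch
+    # counter variable, which is why local_variables_initializer is run
+    # in the session below.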
+ reader = tf.TFRecordReader()
+ _, serialized_example = reader.read(filename_queue)
+
+ with tf.Session() as sess:
+ sess.run(tf.global_variables_initializer())
+ sess.run(tf.local_variables_initializer())
+      tf.train.start_queue_runners(sess)
+      for index in range(FLAGS.num_examples):
+        print("Read record: %d" % index)
+        sess.run(serialized_example)
+
+      # With num_epochs=1, the input producer closes the queue after one
+      # pass, so reading one more record should raise OutOfRangeError.
+ try:
+ sess.run(serialized_example)
+ print("FAIL: Failed to catch the expected OutOfRangeError while "
+ "reading one more record than is available")
+ sys.exit(1)
+ except tf.errors.OutOfRangeError:
+ print("Successfully caught the expected OutOfRangeError while "
+ "reading one more record than is available")
+
+ create_dir_test()
+ create_object_test()
+
+
+if __name__ == "__main__":
+ tf.app.run(main)