Resolve name collisions with assets in SavedModels by deduplicating names that
point to distinct files.

author     Karmel Allison <karmel@google.com>
           Thu, 24 May 2018 03:53:15 +0000 (20:53 -0700)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Thu, 24 May 2018 03:56:01 +0000 (20:56 -0700)

PiperOrigin-RevId: 197835288

tensorflow/python/lib/io/file_io.py
tensorflow/python/lib/io/file_io_test.py
tensorflow/python/saved_model/builder_impl.py
tensorflow/python/saved_model/saved_model_test.py

diff --git a/tensorflow/python/lib/io/file_io.py b/tensorflow/python/lib/io/file_io.py
index 59f5075..f22fb25 100644
@@ -21,6 +21,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import binascii
 import os
 import uuid
 
@@ -33,6 +34,10 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import deprecation
 from tensorflow.python.util.tf_export import tf_export
 
+# A good default block size depends on the system in question; a somewhat
+# conservative default is chosen here.
+_DEFAULT_BLOCK_SIZE = 16 * 1024 * 1024
+
 
 class FileIO(object):
   """FileIO class that exposes methods to read / write to / from files.
@@ -551,3 +556,56 @@ def stat(filename):
   with errors.raise_exception_on_not_ok_status() as status:
     pywrap_tensorflow.Stat(compat.as_bytes(filename), file_statistics, status)
     return file_statistics
+
+
+def filecmp(filename_a, filename_b):
+  """Compare two files, returning True if they are the same, False otherwise.
+
+  We check size first and return False quickly if the files are different sizes.
+  If they are the same size, we continue by computing a crc32 for the whole file.
+
+  You might wonder: why not use Python's filecmp.cmp() instead? The answer is
+  that the built-in library is not robust to the many different filesystems
+  TensorFlow runs on, so we perform a similar comparison here using the more
+  robust FileIO.
+
+  Args:
+    filename_a: string path to the first file.
+    filename_b: string path to the second file.
+
+  Returns:
+    True if the files are the same, False otherwise.
+  """
+  size_a = FileIO(filename_a, "rb").size()
+  size_b = FileIO(filename_b, "rb").size()
+  if size_a != size_b:
+    return False
+
+  # Size is the same. Do a full check.
+  crc_a = file_crc32(filename_a)
+  crc_b = file_crc32(filename_b)
+  return crc_a == crc_b
+
+
+def file_crc32(filename, block_size=_DEFAULT_BLOCK_SIZE):
+  """Get the crc32 of the passed file.
+
+  The crc32 of a file can be used for error checking; two files with the same
+  crc32 are considered equivalent. Note that the entire file must be read
+  to produce the crc32.
+
+  Args:
+    filename: string, path to a file
+    block_size: Integer, process the file by reading blocks of `block_size`
+      bytes. Use -1 to read the file all at once.
+
+  Returns:
+    hexadecimal string, the crc32 of the passed file.
+  """
+  crc = 0
+  with FileIO(filename, mode="rb") as f:
+    chunk = f.read(n=block_size)
+    while chunk:
+      crc = binascii.crc32(chunk, crc)
+      chunk = f.read(n=block_size)
+  return hex(crc & 0xFFFFFFFF)
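
As a quick illustration of the comparison strategy introduced above (size check
first, then a block-wise crc32), the following is a minimal stand-alone sketch
using only the Python standard library. `local_filecmp` and `_local_crc32` are
hypothetical helper names, not part of this change; the real `filecmp` goes
through `FileIO` so that the same logic also works on the remote filesystems
TensorFlow supports.

    import binascii
    import os

    _BLOCK = 16 * 1024 * 1024  # mirrors _DEFAULT_BLOCK_SIZE above


    def _local_crc32(path, block_size=_BLOCK):
      # Accumulate the crc32 block by block so large files are never held
      # in memory all at once.
      crc = 0
      with open(path, "rb") as f:
        chunk = f.read(block_size)
        while chunk:
          crc = binascii.crc32(chunk, crc)
          chunk = f.read(block_size)
      return hex(crc & 0xFFFFFFFF)


    def local_filecmp(path_a, path_b):
      # Cheap rejection: files of different sizes can never be identical.
      if os.path.getsize(path_a) != os.path.getsize(path_b):
        return False
      return _local_crc32(path_a) == _local_crc32(path_b)
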
diff --git a/tensorflow/python/lib/io/file_io_test.py b/tensorflow/python/lib/io/file_io_test.py
index 223858e..c21eb93 100644
@@ -491,5 +491,96 @@ class FileIoTest(test.TestCase):
     v = file_io.file_exists(file_path)
     self.assertEqual(v, True)
 
+  def testFilecmp(self):
+    file1 = os.path.join(self._base_dir, "file1")
+    file_io.write_string_to_file(file1, "This is a sentence\n" * 100)
+
+    file2 = os.path.join(self._base_dir, "file2")
+    file_io.write_string_to_file(file2, "This is another sentence\n" * 100)
+
+    file3 = os.path.join(self._base_dir, "file3")
+    file_io.write_string_to_file(file3, u"This is another sentence\n" * 100)
+
+    self.assertFalse(file_io.filecmp(file1, file2))
+    self.assertTrue(file_io.filecmp(file2, file3))
+
+  def testFilecmpSameSize(self):
+    file1 = os.path.join(self._base_dir, "file1")
+    file_io.write_string_to_file(file1, "This is a sentence\n" * 100)
+
+    file2 = os.path.join(self._base_dir, "file2")
+    file_io.write_string_to_file(file2, "This is b sentence\n" * 100)
+
+    file3 = os.path.join(self._base_dir, "file3")
+    file_io.write_string_to_file(file3, u"This is b sentence\n" * 100)
+
+    self.assertFalse(file_io.filecmp(file1, file2))
+    self.assertTrue(file_io.filecmp(file2, file3))
+
+  def testFilecmpBinary(self):
+    file1 = os.path.join(self._base_dir, "file1")
+    file_io.FileIO(file1, "wb").write("testing\n\na")
+
+    file2 = os.path.join(self._base_dir, "file2")
+    file_io.FileIO(file2, "wb").write("testing\n\nb")
+
+    file3 = os.path.join(self._base_dir, "file3")
+    file_io.FileIO(file3, "wb").write("testing\n\nb")
+
+    file4 = os.path.join(self._base_dir, "file4")
+    file_io.FileIO(file4, "wb").write("testing\n\ntesting")
+
+    self.assertFalse(file_io.filecmp(file1, file2))
+    self.assertFalse(file_io.filecmp(file1, file4))
+    self.assertTrue(file_io.filecmp(file2, file3))
+
+  def testFileCrc32(self):
+    file1 = os.path.join(self._base_dir, "file1")
+    file_io.write_string_to_file(file1, "This is a sentence\n" * 100)
+    crc1 = file_io.file_crc32(file1)
+
+    file2 = os.path.join(self._base_dir, "file2")
+    file_io.write_string_to_file(file2, "This is another sentence\n" * 100)
+    crc2 = file_io.file_crc32(file2)
+
+    file3 = os.path.join(self._base_dir, "file3")
+    file_io.write_string_to_file(file3, "This is another sentence\n" * 100)
+    crc3 = file_io.file_crc32(file3)
+
+    self.assertTrue(crc1 != crc2)
+    self.assertEqual(crc2, crc3)
+
+  def testFileCrc32WithBytes(self):
+    file1 = os.path.join(self._base_dir, "file1")
+    file_io.write_string_to_file(file1, "This is a sentence\n" * 100)
+    crc1 = file_io.file_crc32(file1, block_size=24)
+
+    file2 = os.path.join(self._base_dir, "file2")
+    file_io.write_string_to_file(file2, "This is another sentence\n" * 100)
+    crc2 = file_io.file_crc32(file2, block_size=24)
+
+    file3 = os.path.join(self._base_dir, "file3")
+    file_io.write_string_to_file(file3, "This is another sentence\n" * 100)
+    crc3 = file_io.file_crc32(file3, block_size=-1)
+
+    self.assertTrue(crc1 != crc2)
+    self.assertEqual(crc2, crc3)
+
+  def testFileCrc32Binary(self):
+    file1 = os.path.join(self._base_dir, "file1")
+    file_io.FileIO(file1, "wb").write("testing\n\n")
+    crc1 = file_io.file_crc32(file1)
+
+    file2 = os.path.join(self._base_dir, "file2")
+    file_io.FileIO(file2, "wb").write("testing\n\n\n")
+    crc2 = file_io.file_crc32(file2)
+
+    file3 = os.path.join(self._base_dir, "file3")
+    file_io.FileIO(file3, "wb").write("testing\n\n\n")
+    crc3 = file_io.file_crc32(file3)
+
+    self.assertTrue(crc1 != crc2)
+    self.assertEqual(crc2, crc3)
+
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 071033b..4b39826 100644
@@ -104,10 +104,10 @@ class SavedModelBuilder(object):
     Args:
       assets_collection_to_add: The collection where the asset paths are setup.
     """
-    asset_source_filepath_list = _maybe_save_assets(assets_collection_to_add)
+    asset_filename_map = _maybe_save_assets(assets_collection_to_add)
 
     # Return if there are no assets to write.
-    if len(asset_source_filepath_list) is 0:
+    if not asset_filename_map:
       tf_logging.info("No assets to write.")
       return
 
@@ -119,12 +119,10 @@ class SavedModelBuilder(object):
       file_io.recursive_create_dir(assets_destination_dir)
 
     # Copy each asset from source path to destination path.
-    for asset_source_filepath in asset_source_filepath_list:
-      asset_source_filename = os.path.basename(asset_source_filepath)
-
+    for asset_basename, asset_source_filepath in asset_filename_map.items():
       asset_destination_filepath = os.path.join(
           compat.as_bytes(assets_destination_dir),
-          compat.as_bytes(asset_source_filename))
+          compat.as_bytes(asset_basename))
 
       # Only copy the asset file to the destination if it does not already
       # exist. This is to ensure that an asset with the same name defined as
@@ -475,16 +473,17 @@ def _maybe_save_assets(assets_collection_to_add=None):
     assets_collection_to_add: The collection where the asset paths are setup.
 
   Returns:
-    The list of filepaths to the assets in the assets collection.
+    A dict mapping the basename used for saving each asset to the original
+    full path of that asset.
 
   Raises:
     ValueError: Indicating an invalid filepath tensor.
   """
-  asset_source_filepath_list = []
+  # Map of target filenames in the SavedModel to the original full filepaths.
+  asset_filename_map = {}
 
   if assets_collection_to_add is None:
     tf_logging.info("No assets to save.")
-    return asset_source_filepath_list
+    return asset_filename_map
 
   # Iterate over the supplied asset collection, build the `AssetFile` proto
   # and add them to the collection with key `constants.ASSETS_KEY`, in the
@@ -494,15 +493,71 @@ def _maybe_save_assets(assets_collection_to_add=None):
     if not asset_source_filepath:
       raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)
 
-    asset_source_filename = os.path.basename(asset_source_filepath)
+    asset_filename = _get_asset_filename_to_add(
+        asset_source_filepath, asset_filename_map)
 
     # Build `AssetFile` proto and add it to the asset collection in the graph.
-    _add_asset_to_collection(asset_source_filename, asset_tensor)
+    # Note that this should be done even when the file is a duplicate of an
+    # already-added file, as the tensor reference should still exist.
+    _add_asset_to_collection(asset_filename, asset_tensor)
 
-    asset_source_filepath_list.append(asset_source_filepath)
+    # In the cases where we are adding a duplicate, this will result in the
+    # last of the filepaths being the one used for copying the file to the
+    # SavedModel. Since the files in question are the same, it doesn't matter
+    # either way.
+    asset_filename_map[asset_filename] = asset_source_filepath
 
   tf_logging.info("Assets added to graph.")
-  return asset_source_filepath_list
+  return asset_filename_map
+
+
+def _get_asset_filename_to_add(asset_filepath, asset_filename_map):
+  """Get a unique basename to add to the SavedModel if this file is unseen.
+
+  Assets come from users as full paths, and we save them out to the
+  SavedModel as basenames. In some cases, the basenames collide. Here,
+  we dedupe asset basenames by first checking whether the files are the same,
+  and, if they are different, generating and returning an index-suffixed
+  basename that can be used to add the asset to the SavedModel.
+
+  Args:
+    asset_filepath: the full path to the asset that is being saved
+    asset_filename_map: a dict mapping filenames used for saving assets in
+      the SavedModel to the full paths from which those filenames were derived.
+
+  Returns:
+    Uniquified filename string if the file is not a duplicate, or the original
+    filename if the file has already been seen and saved.
+  """
+  asset_filename = os.path.basename(asset_filepath)
+
+  if asset_filename not in asset_filename_map:
+    # This is an unseen asset. Safe to add.
+    return asset_filename
+
+  other_asset_filepath = asset_filename_map[asset_filename]
+  if other_asset_filepath == asset_filepath:
+    # This is the same file, stored twice in the collection list. No need
+    # to make unique.
+    return asset_filename
+
+  # Else, asset_filename is in the map, and the filepath is different. Dedupe.
+  if not file_io.filecmp(asset_filepath, other_asset_filepath):
+    # Files are different; dedupe filenames.
+    return _get_unique_asset_filename(asset_filename, asset_filename_map)
+
+  # Files are the same; don't make unique.
+  return asset_filename
+
+
+def _get_unique_asset_filename(asset_filename, asset_filename_map):
+  i = 1
+  unique_filename = asset_filename
+  while unique_filename in asset_filename_map:
+    unique_filename = compat.as_bytes("_").join(
+        [compat.as_bytes(asset_filename), compat.as_bytes(str(i))])
+    i += 1
+  return unique_filename
 
 
 def _asset_path_from_tensor(path_tensor):
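
For readers skimming the dedup logic, the suffixing loop in
`_get_unique_asset_filename` behaves like the hypothetical `unique_name` helper
below (the names and paths in this example are made up for illustration, and
the real code joins the pieces with compat.as_bytes rather than string
formatting): it appends "_1", "_2", ... to the basename until the result is not
already a key in the map.

    def unique_name(basename, taken):
      # Append "_1", "_2", ... until the candidate is not already in use.
      candidate = basename
      i = 1
      while candidate in taken:
        candidate = "%s_%d" % (basename, i)
        i += 1
      return candidate


    taken = {"hello42.txt": "/tmp/1/hello42.txt",
             "hello42.txt_1": "/tmp/2/hello42.txt"}
    print(unique_name("hello42.txt", taken))  # -> "hello42.txt_2"
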
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 1b83d60..7302c77 100644
@@ -64,9 +64,12 @@ class SavedModelTest(test.TestCase):
     self.assertEqual(variable_value, v.eval())
 
   def _build_asset_collection(self, asset_file_name, asset_file_contents,
-                              asset_file_tensor_name):
+                              asset_file_tensor_name, asset_subdir=""):
+    parent_dir = os.path.join(
+        compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_subdir))
+    file_io.recursive_create_dir(parent_dir)
     asset_filepath = os.path.join(
-        compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_file_name))
+        compat.as_bytes(parent_dir), compat.as_bytes(asset_file_name))
     file_io.write_string_to_file(asset_filepath, asset_file_contents)
     asset_file_tensor = constant_op.constant(
         asset_filepath, name=asset_file_tensor_name)
@@ -77,10 +80,11 @@ class SavedModelTest(test.TestCase):
   def _validate_asset_collection(self, export_dir, graph_collection_def,
                                  expected_asset_file_name,
                                  expected_asset_file_contents,
-                                 expected_asset_tensor_name):
+                                 expected_asset_tensor_name,
+                                 asset_id=0):
     assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value
     asset = meta_graph_pb2.AssetFileDef()
-    assets_any[0].Unpack(asset)
+    assets_any[asset_id].Unpack(asset)
     assets_path = os.path.join(
         compat.as_bytes(export_dir),
         compat.as_bytes(constants.ASSETS_DIRECTORY),
@@ -634,6 +638,141 @@ class SavedModelTest(test.TestCase):
           compat.as_bytes("ignored.txt"))
       self.assertFalse(file_io.file_exists(ignored_asset_path))
 
+  def testAssetsNameCollisionDiffFile(self):
+    export_dir = self._get_export_dir("test_assets_name_collision_diff_file")
+    builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      self._init_and_validate_variable(sess, "v", 42)
+
+      asset_collection = self._build_asset_collection(
+          "hello42.txt", "foo bar bak", "asset_file_tensor",
+          asset_subdir="1")
+
+      asset_collection = self._build_asset_collection(
+          "hello42.txt", "foo bar baz", "asset_file_tensor_1",
+          asset_subdir="2")
+
+      builder.add_meta_graph_and_variables(
+          sess, ["foo"], assets_collection=asset_collection)
+
+    # Save the SavedModel to disk.
+    builder.save()
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      foo_graph = loader.load(sess, ["foo"], export_dir)
+      self._validate_asset_collection(export_dir, foo_graph.collection_def,
+                                      "hello42.txt", "foo bar bak",
+                                      "asset_file_tensor:0")
+      self._validate_asset_collection(export_dir, foo_graph.collection_def,
+                                      "hello42.txt_1", "foo bar baz",
+                                      "asset_file_tensor_1:0",
+                                      asset_id=1)
+
+  def testAssetsNameCollisionSameFilepath(self):
+    export_dir = self._get_export_dir("test_assets_name_collision_same_path")
+    builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      self._init_and_validate_variable(sess, "v", 42)
+
+      asset_collection = self._build_asset_collection(
+          "hello42.txt", "foo bar baz", "asset_file_tensor")
+
+      asset_collection = self._build_asset_collection(
+          "hello42.txt", "foo bar baz", "asset_file_tensor_1")
+
+      builder.add_meta_graph_and_variables(
+          sess, ["foo"], assets_collection=asset_collection)
+
+    # Save the SavedModel to disk.
+    builder.save()
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      foo_graph = loader.load(sess, ["foo"], export_dir)
+      self._validate_asset_collection(export_dir, foo_graph.collection_def,
+                                      "hello42.txt", "foo bar baz",
+                                      "asset_file_tensor:0")
+      # The second tensor should be recorded, but point to the same asset.
+      self._validate_asset_collection(export_dir, foo_graph.collection_def,
+                                      "hello42.txt", "foo bar baz",
+                                      "asset_file_tensor_1:0",
+                                      asset_id=1)
+      ignored_asset_path = os.path.join(
+          compat.as_bytes(export_dir),
+          compat.as_bytes(constants.ASSETS_DIRECTORY),
+          compat.as_bytes("hello42.txt_1"))
+      self.assertFalse(file_io.file_exists(ignored_asset_path))
+
+  def testAssetsNameCollisionSameFile(self):
+    export_dir = self._get_export_dir("test_assets_name_collision_same_file")
+    builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      self._init_and_validate_variable(sess, "v", 42)
+
+      asset_collection = self._build_asset_collection(
+          "hello42.txt", "foo bar baz", "asset_file_tensor",
+          asset_subdir="1")
+
+      asset_collection = self._build_asset_collection(
+          "hello42.txt", "foo bar baz", "asset_file_tensor_1",
+          asset_subdir="2")
+
+      builder.add_meta_graph_and_variables(
+          sess, ["foo"], assets_collection=asset_collection)
+
+    # Save the SavedModel to disk.
+    builder.save()
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      foo_graph = loader.load(sess, ["foo"], export_dir)
+      self._validate_asset_collection(export_dir, foo_graph.collection_def,
+                                      "hello42.txt", "foo bar baz",
+                                      "asset_file_tensor:0")
+      # The second tensor should be recorded, but point to the same asset.
+      self._validate_asset_collection(export_dir, foo_graph.collection_def,
+                                      "hello42.txt", "foo bar baz",
+                                      "asset_file_tensor_1:0",
+                                      asset_id=1)
+      ignored_asset_path = os.path.join(
+          compat.as_bytes(export_dir),
+          compat.as_bytes(constants.ASSETS_DIRECTORY),
+          compat.as_bytes("hello42.txt_1"))
+      self.assertFalse(file_io.file_exists(ignored_asset_path))
+
+  def testAssetsNameCollisionManyFiles(self):
+    export_dir = self._get_export_dir("test_assets_name_collision_many_files")
+    builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      self._init_and_validate_variable(sess, "v", 42)
+
+      for i in range(5):
+        idx = str(i)
+        asset_collection = self._build_asset_collection(
+            "hello42.txt", "foo bar baz " + idx, "asset_file_tensor_" + idx,
+            asset_subdir=idx)
+
+      builder.add_meta_graph_and_variables(
+          sess, ["foo"], assets_collection=asset_collection)
+
+    # Save the SavedModel to disk.
+    builder.save()
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      foo_graph = loader.load(sess, ["foo"], export_dir)
+      for i in range(1, 5):
+        idx = str(i)
+        self._validate_asset_collection(
+            export_dir, foo_graph.collection_def, "hello42.txt_" + idx,
+            "foo bar baz " + idx, "asset_file_tensor_{}:0".format(idx),
+            asset_id=i)
+
+      self._validate_asset_collection(export_dir, foo_graph.collection_def,
+                                      "hello42.txt", "foo bar baz 0",
+                                      "asset_file_tensor_0:0")
+
   def testCustomMainOp(self):
     export_dir = self._get_export_dir("test_main_op")
     builder = saved_model_builder.SavedModelBuilder(export_dir)
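
Finally, a hedged end-to-end sketch of how a user hits the collision through
the public TF 1.x API (the paths, tag, and variable below are assumptions made
for illustration, not taken from this patch); with this change the export keeps
both files instead of silently overwriting one of them.

    import os
    import tensorflow as tf  # TF 1.x API, matching the code in this patch

    base = "/tmp/assets_demo"  # assumed scratch location
    for subdir, text in [("dir1", "a b c"), ("dir2", "x y z")]:
      os.makedirs(os.path.join(base, subdir))
      with open(os.path.join(base, subdir, "vocab.txt"), "w") as f:
        f.write(text)  # two different files sharing the basename "vocab.txt"

    with tf.Graph().as_default(), tf.Session() as sess:
      tf.Variable(42, name="v")
      sess.run(tf.global_variables_initializer())
      for subdir in ("dir1", "dir2"):
        asset = tf.constant(os.path.join(base, subdir, "vocab.txt"))
        tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset)

      builder = tf.saved_model.builder.SavedModelBuilder(
          os.path.join(base, "export"))
      builder.add_meta_graph_and_variables(
          sess, ["serve"],
          assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))
      builder.save()

    # Expected assets directory after this change:
    #   export/assets/vocab.txt    (contents "a b c")
    #   export/assets/vocab.txt_1  (contents "x y z")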