Improvements to util/nest.py and data/util/nest.py
authorAkshay Modi <nareshmodi@google.com>
Tue, 22 May 2018 01:53:54 +0000 (18:53 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 22 May 2018 01:56:14 +0000 (18:56 -0700)
Changes:
- Add a cache for type -> is_sequence to speed up Flatten/IsSequence
- Update data/util/nest.py flatten to use C Flatten

Before:
entry {
  name: "EagerLinearRegressionBenchmark.eager_train_cpu"
  iters: 2000
  wall_time: 1.91852378845
  extras {
    key: "examples_per_sec"
    value {
      double_value: 66717.9634521
    }
  }
}

After:
entry {
  name: "EagerLinearRegressionBenchmark.eager_train_cpu"
  iters: 2000
  wall_time: 1.74479198456
  extras {
    key: "examples_per_sec"
    value {
      double_value: 73361.1806638
    }
  }
}
PiperOrigin-RevId: 197497854

tensorflow/python/BUILD
tensorflow/python/data/util/nest.py
tensorflow/python/framework/sparse_tensor.py
tensorflow/python/util/util.cc
tensorflow/python/util/util.h
tensorflow/python/util/util.i

index f714d1f..7201e12 100644 (file)
@@ -313,8 +313,8 @@ cc_library(
     hdrs = ["util/util.h"],
     deps = [
         ":safe_ptr",
-        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
         "//util/python:python_headers",
     ],
 )
index eff6e02..9af2e9b 100644 (file)
 """## Functions for working with arbitrarily nested sequences of elements.
 
 NOTE(mrry): This fork of the `tensorflow.python.util.nest` module
-makes three changes:
+makes two changes:
 
-1. It adds support for dictionaries as a level of nesting in nested structures.
-2. It removes support for lists as a level of nesting in nested structures.
-3. It adds support for `SparseTensorValue` as an atomic element.
+1. It removes support for lists as a level of nesting in nested structures.
+2. It adds support for `SparseTensorValue` as an atomic element.
 
-The motivation for this change is threefold:
+The motivation for this change is twofold:
 
-1. Many input-processing functions (e.g. `tf.parse_example()`) return
-   dictionaries, and we would like to support them natively in datasets.
-2. It seems more natural for lists to be treated (e.g. in Dataset constructors)
+1. It seems more natural for lists to be treated (e.g. in Dataset constructors)
    as tensors, rather than lists of (lists of...) tensors.
-3. This is needed because `SparseTensorValue` is implemented as a `namedtuple`
+2. This is needed because `SparseTensorValue` is implemented as a `namedtuple`
    that would normally be flattened and we want to be able to create sparse
    tensor from `SparseTensorValue's similarly to creating tensors from numpy
    arrays.
@@ -43,6 +40,7 @@ import collections as _collections
 
 import six as _six
 
+from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
 from tensorflow.python.framework import sparse_tensor as _sparse_tensor
 
 
@@ -99,15 +97,6 @@ def _yield_value(iterable):
       yield value
 
 
-def _yield_flat_nest(nest):
-  for n in _yield_value(nest):
-    if is_sequence(n):
-      for ni in _yield_flat_nest(n):
-        yield ni
-    else:
-      yield n
-
-
 def is_sequence(seq):
   """Returns a true if `seq` is a Sequence or dict (except strings/lists).
 
@@ -123,9 +112,7 @@ def is_sequence(seq):
     True if the sequence is a not a string or list and is a
     collections.Sequence.
   """
-  return (isinstance(seq, (_collections.Sequence, dict)) and
-          not isinstance(seq, _sparse_tensor.SparseTensorValue) and
-          not isinstance(seq, (list, _six.string_types)))
+  return _pywrap_tensorflow.IsSequenceForData(seq)
 
 
 def flatten(nest):
@@ -140,7 +127,7 @@ def flatten(nest):
   Returns:
     A Python list, the flattened version of the input.
   """
-  return list(_yield_flat_nest(nest)) if is_sequence(nest) else [nest]
+  return _pywrap_tensorflow.FlattenForData(nest)
 
 
 def _recursive_assert_same_structure(nest1, nest2, check_types):
@@ -536,4 +523,3 @@ def map_structure_up_to(shallow_tree, func, *inputs):
 
   results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
   return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
-
index 1fe81e5..6a5c646 100644 (file)
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import collections
 
+from tensorflow.python import pywrap_tensorflow
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
@@ -225,6 +226,7 @@ class SparseTensor(_TensorLike):
 SparseTensorValue = collections.namedtuple(
     "SparseTensorValue", ["indices", "values", "dense_shape"])
 tf_export("SparseTensorValue")(SparseTensorValue)
+pywrap_tensorflow.RegisterSparseTensorValueClass(SparseTensorValue)
 
 
 @tf_export("convert_to_tensor_or_sparse_tensor")
index 70aee4a..386a6fb 100644 (file)
@@ -14,8 +14,13 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/python/util/util.h"
 
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/lib/gtl/map_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/python/lib/core/safe_ptr.h"
 
 namespace tensorflow {
@@ -25,6 +30,7 @@ namespace {
 
 // Type object for collections.Sequence. This is set by RegisterSequenceClass.
 PyObject* CollectionsSequenceType = nullptr;
+PyTypeObject* SparseTensorValueType = nullptr;
 
 bool WarnedThatSetIsNotSequence = false;
 
@@ -135,6 +141,12 @@ class ValIterator {
   Py_ssize_t index_;
 };
 
+mutex g_type_to_sequence_map(LINKER_INITIALIZED);
+std::unordered_map<PyTypeObject*, bool>* IsTypeSequenceMap() {
+  static auto* const m = new std::unordered_map<PyTypeObject*, bool>;
+  return m;
+}
+
 // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
 // Returns 0 otherwise.
 // Returns -1 if an error occurred.
@@ -155,64 +167,137 @@ int IsSequenceHelper(PyObject* o) {
             .c_str());
     return -1;
   }
+
+  // Try not to return to Python - see if the type has already been seen
+  // before.
+
+  // NOTE: It's not clear whether the lock is required (we should be holding the
+  // python GIL in this code already).
+  mutex_lock l(g_type_to_sequence_map);
+  auto* type_to_sequence_map = IsTypeSequenceMap();
+  auto* type = Py_TYPE(o);
+
+  auto it = type_to_sequence_map->find(type);
+  if (it != type_to_sequence_map->end()) {
+    return it->second;
+  }
+
   int is_instance = PyObject_IsInstance(o, CollectionsSequenceType);
+
+  // Don't cache a failed is_instance check.
   if (is_instance == -1) return -1;
-  return static_cast<int>(is_instance != 0 && !IsString(o));
+
+  bool is_sequence = static_cast<int>(is_instance != 0 && !IsString(o));
+
+  // NOTE: This is never decref'd, but we don't want the type to get deleted
+  // as long as it is in the map. This should not be too much of a
+  // leak, as there should only be a relatively small number of types in the
+  // map, and an even smaller number that are eligible for decref.
+  Py_INCREF(type);
+  type_to_sequence_map->insert({type, is_sequence});
+
+  return is_sequence;
 }
 
-bool FlattenHelper(PyObject* nested, PyObject* list) {
-  // if nested is not a sequence, append itself and exit
-  int is_seq = IsSequenceHelper(nested);
-  if (is_seq == -1) return false;
-  if (!is_seq) {
-    return PyList_Append(list, nested) != -1;
+bool IsSparseTensorValueType(PyObject* o) {
+  if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
+    return false;
   }
 
-  // if nested if dictionary, sort it by key and recurse on each value
-  if (PyDict_Check(nested)) {
-    PyObject* keys = PyDict_Keys(nested);
-    if (PyList_Sort(keys) == -1) return false;
-    Py_ssize_t size = PyList_Size(keys);
-    for (Py_ssize_t i = 0; i < size; ++i) {
-      // We know that key and val will not be deleted because nested owns
-      // a reference to them and callers of flatten must not modify nested
-      // while the method is running.
-      PyObject* key = PyList_GET_ITEM(keys, i);
-      PyObject* val = PyDict_GetItem(nested, key);
-      if (Py_EnterRecursiveCall(" in flatten")) {
-        Py_DECREF(keys);
-        return false;
-      }
-      const bool success = FlattenHelper(val, list);
-      Py_LeaveRecursiveCall();
-      if (!success) {
-        Py_DECREF(keys);
-        return false;
-      }
-    }
-    Py_DECREF(keys);
-    return true;
+  return PyObject_TypeCheck(o, SparseTensorValueType) == 1;
+}
+
+int IsSequenceForDataHelper(PyObject* o) {
+  return IsSequenceHelper(o) == 1 && !PyList_Check(o) &&
+         !IsSparseTensorValueType(o);
+}
+
+bool GetNextValuesForDict(PyObject* nested,
+                          std::vector<Safe_PyObjectPtr>* next_values) {
+  std::vector<PyObject*> result;
+
+  PyObject* keys = PyDict_Keys(nested);
+  if (PyList_Sort(keys) == -1) return false;
+  Py_ssize_t size = PyList_Size(keys);
+  for (Py_ssize_t i = 0; i < size; ++i) {
+    // We know that key and item will not be deleted because nested owns
+    // a reference to them and callers of flatten must not modify nested
+    // while the method is running.
+    PyObject* key = PyList_GET_ITEM(keys, i);
+    PyObject* item = PyDict_GetItem(nested, key);
+    Py_INCREF(item);
+    next_values->emplace_back(item);
   }
+  Py_DECREF(keys);
+  return true;
+}
 
-  // iterate and recurse
+bool GetNextValuesForIterable(PyObject* nested,
+                              std::vector<Safe_PyObjectPtr>* next_values) {
   PyObject* item;
   PyObject* iterator = PyObject_GetIter(nested);
   while ((item = PyIter_Next(iterator)) != nullptr) {
+    next_values->emplace_back(item);
+  }
+  Py_DECREF(iterator);
+  return true;
+}
+
+// GetNextValues returns the values that the FlattenHelper function will recurse
+// over next.
+bool GetNextValues(PyObject* nested,
+                   std::vector<Safe_PyObjectPtr>* next_values) {
+  if (PyDict_Check(nested)) {
+    // if nested is dictionary, sort it by key and recurse on each value
+    return GetNextValuesForDict(nested, next_values);
+  }
+  // iterate and recurse
+  return GetNextValuesForIterable(nested, next_values);
+}
+
+// Similar to above, just specialized for the functions in the data pacakage.
+bool GetNextValuesForData(PyObject* nested,
+                          std::vector<Safe_PyObjectPtr>* next_values) {
+  if (PyDict_Check(nested)) {
+    // if nested is dictionary, sort it by key and recurse on each value
+    return GetNextValuesForDict(nested, next_values);
+  } else if (IsSparseTensorValueType(nested)) {
+    // if nested is a SparseTensorValue, just return itself as a single item
+    Py_INCREF(nested);
+    next_values->emplace_back(nested);
+    return true;
+  }
+  // iterate and recurse
+  return GetNextValuesForIterable(nested, next_values);
+}
+
+bool FlattenHelper(
+    PyObject* nested, PyObject* list,
+    const std::function<int(PyObject*)>& is_sequence_helper,
+    const std::function<bool(PyObject*, std::vector<Safe_PyObjectPtr>*)>&
+        next_values_getter) {
+  // if nested is not a sequence, append itself and exit
+  int is_seq = is_sequence_helper(nested);
+  if (is_seq == -1) return false;
+  if (!is_seq) {
+    return PyList_Append(list, nested) != -1;
+  }
+
+  std::vector<Safe_PyObjectPtr> next_values;
+  // Get the next values to recurse over.
+  if (!next_values_getter(nested, &next_values)) return false;
+
+  for (const auto& item : next_values) {
     if (Py_EnterRecursiveCall(" in flatten")) {
-      Py_DECREF(iterator);
-      Py_DECREF(item);
       return false;
     }
-    bool success = FlattenHelper(item, list);
+    const bool success =
+        FlattenHelper(item.get(), list, is_sequence_helper, next_values_getter);
     Py_LeaveRecursiveCall();
     if (!success) {
-      Py_DECREF(iterator);
-      Py_DECREF(item);
       return false;
     }
-    Py_DECREF(item);
   }
-  Py_DECREF(iterator);
   return true;
 }
 
@@ -351,7 +436,7 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types,
   }
 }
 
-}  // anonymous namespace
+}  // namespace
 
 void RegisterSequenceClass(PyObject* sequence_class) {
   if (!PyType_Check(sequence_class)) {
@@ -366,11 +451,38 @@ void RegisterSequenceClass(PyObject* sequence_class) {
   CollectionsSequenceType = sequence_class;
 }
 
+void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
+  if (!PyType_Check(sparse_tensor_value_class)) {
+    PyErr_SetString(
+        PyExc_TypeError,
+        tensorflow::strings::StrCat(
+            "Expecting a class definition for `SparseTensorValue`. Got ",
+            Py_TYPE(sparse_tensor_value_class)->tp_name)
+            .c_str());
+    return;
+  }
+  SparseTensorValueType =
+      reinterpret_cast<PyTypeObject*>(sparse_tensor_value_class);
+}
+
 bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
 
 PyObject* Flatten(PyObject* nested) {
   PyObject* list = PyList_New(0);
-  if (FlattenHelper(nested, list)) {
+  if (FlattenHelper(nested, list, IsSequenceHelper, GetNextValues)) {
+    return list;
+  } else {
+    Py_DECREF(list);
+    return nullptr;
+  }
+}
+
+bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
+
+PyObject* FlattenForData(PyObject* nested) {
+  PyObject* list = PyList_New(0);
+  if (FlattenHelper(nested, list, IsSequenceForDataHelper,
+                    GetNextValuesForData)) {
     return list;
   } else {
     Py_DECREF(list);
index c325baa..9851c11 100644 (file)
@@ -118,6 +118,30 @@ PyObject* Flatten(PyObject* nested);
 // the type from the module. This approach also requires some trigger from
 // Python so that we know that Python interpreter had been initialzied.
 void RegisterSequenceClass(PyObject* sequence_class);
+// Similar to the above function, except for the
+// sparse_tensor.SparseTensorValue class.
+void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class);
+
+// The tensorflow.python.data package has its own nest utility that follows very
+// slightly different semantics for its functions than the tensorflow.python
+// nest utility. Returns a true if its input is a collections.Sequence (except
+// strings).
+//
+// Main differences are (this is copied from nest.py in the
+// tensorflow.data.util):
+//
+// 1. It removes support for lists as a level of nesting in nested structures.
+// 2. It adds support for `SparseTensorValue` as an atomic element.
+
+// IsSequence specialized for the data package. Additional comments about
+// difference in functionality can be found in nest.py in tensorflow.data.util
+// and in the comments for Flatten above.
+bool IsSequenceForData(PyObject* o);
+
+// IsSequence specialized for the data package. Additional comments about
+// difference in functionality can be found in nest.py in tensorflow.data.util
+// and in the comments for Flatten above.
+PyObject* FlattenForData(PyObject* nested);
 
 }  // namespace swig
 }  // namespace tensorflow
index b7f201b..9f3b11b 100644 (file)
@@ -31,6 +31,9 @@ limitations under the License.
 %unignore tensorflow::swig::RegisterSequenceClass;
 %noexception tensorflow::swig::RegisterSequenceClass;
 
+%unignore tensorflow::swig::RegisterSparseTensorValueClass;
+%noexception tensorflow::swig::RegisterSparseTensorValueClass;
+
 %unignore tensorflow::swig::IsSequence;
 %noexception tensorflow::swig::IsSequence;
 
@@ -46,6 +49,12 @@ limitations under the License.
 %unignore tensorflow::swig::Flatten;
 %noexception tensorflow::swig::Flatten;
 
+%unignore tensorflow::swig::IsSequenceForData;
+%noexception tensorflow::swig::IsSequenceForData;
+
+%unignore tensorflow::swig::FlattenForData;
+%noexception tensorflow::swig::FlattenForData;
+
 %include "tensorflow/python/util/util.h"
 
 %unignoreall