From e4dcf28ad1c56c3a8e41ca52e7d87169eb7f93d5 Mon Sep 17 00:00:00 2001 From: Akshay Modi Date: Mon, 21 May 2018 18:53:54 -0700 Subject: [PATCH] Improvements to util/nest.py and data/util/nest.py Changes: - Add a cache for type -> is_sequence to speed up Flatten/IsSequence - Update data/util/nest.py flatten to use C Flatten Before: entry { name: "EagerLinearRegressionBenchmark.eager_train_cpu" iters: 2000 wall_time: 1.91852378845 extras { key: "examples_per_sec" value { double_value: 66717.9634521 } } } After: entry { name: "EagerLinearRegressionBenchmark.eager_train_cpu" iters: 2000 wall_time: 1.74479198456 extras { key: "examples_per_sec" value { double_value: 73361.1806638 } } } PiperOrigin-RevId: 197497854 --- tensorflow/python/BUILD | 2 +- tensorflow/python/data/util/nest.py | 32 ++--- tensorflow/python/framework/sparse_tensor.py | 2 + tensorflow/python/util/util.cc | 194 +++++++++++++++++++++------ tensorflow/python/util/util.h | 24 ++++ tensorflow/python/util/util.i | 9 ++ 6 files changed, 198 insertions(+), 65 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f714d1f..7201e12 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -313,8 +313,8 @@ cc_library( hdrs = ["util/util.h"], deps = [ ":safe_ptr", - "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//util/python:python_headers", ], ) diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py index eff6e02..9af2e9b 100644 --- a/tensorflow/python/data/util/nest.py +++ b/tensorflow/python/data/util/nest.py @@ -17,19 +17,16 @@ """## Functions for working with arbitrarily nested sequences of elements. NOTE(mrry): This fork of the `tensorflow.python.util.nest` module -makes three changes: +makes two changes: -1. It adds support for dictionaries as a level of nesting in nested structures. -2. It removes support for lists as a level of nesting in nested structures. -3. It adds support for `SparseTensorValue` as an atomic element. +1. It removes support for lists as a level of nesting in nested structures. +2. It adds support for `SparseTensorValue` as an atomic element. -The motivation for this change is threefold: +The motivation for this change is twofold: -1. Many input-processing functions (e.g. `tf.parse_example()`) return - dictionaries, and we would like to support them natively in datasets. -2. It seems more natural for lists to be treated (e.g. in Dataset constructors) +1. It seems more natural for lists to be treated (e.g. in Dataset constructors) as tensors, rather than lists of (lists of...) tensors. -3. This is needed because `SparseTensorValue` is implemented as a `namedtuple` +2. This is needed because `SparseTensorValue` is implemented as a `namedtuple` that would normally be flattened and we want to be able to create sparse tensor from `SparseTensorValue's similarly to creating tensors from numpy arrays. @@ -43,6 +40,7 @@ import collections as _collections import six as _six +from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow from tensorflow.python.framework import sparse_tensor as _sparse_tensor @@ -99,15 +97,6 @@ def _yield_value(iterable): yield value -def _yield_flat_nest(nest): - for n in _yield_value(nest): - if is_sequence(n): - for ni in _yield_flat_nest(n): - yield ni - else: - yield n - - def is_sequence(seq): """Returns a true if `seq` is a Sequence or dict (except strings/lists). @@ -123,9 +112,7 @@ def is_sequence(seq): True if the sequence is a not a string or list and is a collections.Sequence. """ - return (isinstance(seq, (_collections.Sequence, dict)) and - not isinstance(seq, _sparse_tensor.SparseTensorValue) and - not isinstance(seq, (list, _six.string_types))) + return _pywrap_tensorflow.IsSequenceForData(seq) def flatten(nest): @@ -140,7 +127,7 @@ def flatten(nest): Returns: A Python list, the flattened version of the input. """ - return list(_yield_flat_nest(nest)) if is_sequence(nest) else [nest] + return _pywrap_tensorflow.FlattenForData(nest) def _recursive_assert_same_structure(nest1, nest2, check_types): @@ -536,4 +523,3 @@ def map_structure_up_to(shallow_tree, func, *inputs): results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] return pack_sequence_as(structure=shallow_tree, flat_sequence=results) - diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py index 1fe81e5..6a5c646 100644 --- a/tensorflow/python/framework/sparse_tensor.py +++ b/tensorflow/python/framework/sparse_tensor.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections +from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util @@ -225,6 +226,7 @@ class SparseTensor(_TensorLike): SparseTensorValue = collections.namedtuple( "SparseTensorValue", ["indices", "values", "dense_shape"]) tf_export("SparseTensorValue")(SparseTensorValue) +pywrap_tensorflow.RegisterSparseTensorValueClass(SparseTensorValue) @tf_export("convert_to_tensor_or_sparse_tensor") diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc index 70aee4a..386a6fb 100644 --- a/tensorflow/python/util/util.cc +++ b/tensorflow/python/util/util.cc @@ -14,8 +14,13 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/python/util/util.h" +#include +#include + +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/python/lib/core/safe_ptr.h" namespace tensorflow { @@ -25,6 +30,7 @@ namespace { // Type object for collections.Sequence. This is set by RegisterSequenceClass. PyObject* CollectionsSequenceType = nullptr; +PyTypeObject* SparseTensorValueType = nullptr; bool WarnedThatSetIsNotSequence = false; @@ -135,6 +141,12 @@ class ValIterator { Py_ssize_t index_; }; +mutex g_type_to_sequence_map(LINKER_INITIALIZED); +std::unordered_map* IsTypeSequenceMap() { + static auto* const m = new std::unordered_map; + return m; +} + // Returns 1 if `o` is considered a sequence for the purposes of Flatten(). // Returns 0 otherwise. // Returns -1 if an error occurred. @@ -155,64 +167,137 @@ int IsSequenceHelper(PyObject* o) { .c_str()); return -1; } + + // Try not to return to Python - see if the type has already been seen + // before. + + // NOTE: It's not clear whether the lock is required (we should be holding the + // python GIL in this code already). + mutex_lock l(g_type_to_sequence_map); + auto* type_to_sequence_map = IsTypeSequenceMap(); + auto* type = Py_TYPE(o); + + auto it = type_to_sequence_map->find(type); + if (it != type_to_sequence_map->end()) { + return it->second; + } + int is_instance = PyObject_IsInstance(o, CollectionsSequenceType); + + // Don't cache a failed is_instance check. if (is_instance == -1) return -1; - return static_cast(is_instance != 0 && !IsString(o)); + + bool is_sequence = static_cast(is_instance != 0 && !IsString(o)); + + // NOTE: This is never decref'd, but we don't want the type to get deleted + // as long as it is in the map. This should not be too much of a + // leak, as there should only be a relatively small number of types in the + // map, and an even smaller number that are eligible for decref. + Py_INCREF(type); + type_to_sequence_map->insert({type, is_sequence}); + + return is_sequence; } -bool FlattenHelper(PyObject* nested, PyObject* list) { - // if nested is not a sequence, append itself and exit - int is_seq = IsSequenceHelper(nested); - if (is_seq == -1) return false; - if (!is_seq) { - return PyList_Append(list, nested) != -1; +bool IsSparseTensorValueType(PyObject* o) { + if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) { + return false; } - // if nested if dictionary, sort it by key and recurse on each value - if (PyDict_Check(nested)) { - PyObject* keys = PyDict_Keys(nested); - if (PyList_Sort(keys) == -1) return false; - Py_ssize_t size = PyList_Size(keys); - for (Py_ssize_t i = 0; i < size; ++i) { - // We know that key and val will not be deleted because nested owns - // a reference to them and callers of flatten must not modify nested - // while the method is running. - PyObject* key = PyList_GET_ITEM(keys, i); - PyObject* val = PyDict_GetItem(nested, key); - if (Py_EnterRecursiveCall(" in flatten")) { - Py_DECREF(keys); - return false; - } - const bool success = FlattenHelper(val, list); - Py_LeaveRecursiveCall(); - if (!success) { - Py_DECREF(keys); - return false; - } - } - Py_DECREF(keys); - return true; + return PyObject_TypeCheck(o, SparseTensorValueType) == 1; +} + +int IsSequenceForDataHelper(PyObject* o) { + return IsSequenceHelper(o) == 1 && !PyList_Check(o) && + !IsSparseTensorValueType(o); +} + +bool GetNextValuesForDict(PyObject* nested, + std::vector* next_values) { + std::vector result; + + PyObject* keys = PyDict_Keys(nested); + if (PyList_Sort(keys) == -1) return false; + Py_ssize_t size = PyList_Size(keys); + for (Py_ssize_t i = 0; i < size; ++i) { + // We know that key and item will not be deleted because nested owns + // a reference to them and callers of flatten must not modify nested + // while the method is running. + PyObject* key = PyList_GET_ITEM(keys, i); + PyObject* item = PyDict_GetItem(nested, key); + Py_INCREF(item); + next_values->emplace_back(item); } + Py_DECREF(keys); + return true; +} - // iterate and recurse +bool GetNextValuesForIterable(PyObject* nested, + std::vector* next_values) { PyObject* item; PyObject* iterator = PyObject_GetIter(nested); while ((item = PyIter_Next(iterator)) != nullptr) { + next_values->emplace_back(item); + } + Py_DECREF(iterator); + return true; +} + +// GetNextValues returns the values that the FlattenHelper function will recurse +// over next. +bool GetNextValues(PyObject* nested, + std::vector* next_values) { + if (PyDict_Check(nested)) { + // if nested is dictionary, sort it by key and recurse on each value + return GetNextValuesForDict(nested, next_values); + } + // iterate and recurse + return GetNextValuesForIterable(nested, next_values); +} + +// Similar to above, just specialized for the functions in the data pacakage. +bool GetNextValuesForData(PyObject* nested, + std::vector* next_values) { + if (PyDict_Check(nested)) { + // if nested is dictionary, sort it by key and recurse on each value + return GetNextValuesForDict(nested, next_values); + } else if (IsSparseTensorValueType(nested)) { + // if nested is a SparseTensorValue, just return itself as a single item + Py_INCREF(nested); + next_values->emplace_back(nested); + return true; + } + // iterate and recurse + return GetNextValuesForIterable(nested, next_values); +} + +bool FlattenHelper( + PyObject* nested, PyObject* list, + const std::function& is_sequence_helper, + const std::function*)>& + next_values_getter) { + // if nested is not a sequence, append itself and exit + int is_seq = is_sequence_helper(nested); + if (is_seq == -1) return false; + if (!is_seq) { + return PyList_Append(list, nested) != -1; + } + + std::vector next_values; + // Get the next values to recurse over. + if (!next_values_getter(nested, &next_values)) return false; + + for (const auto& item : next_values) { if (Py_EnterRecursiveCall(" in flatten")) { - Py_DECREF(iterator); - Py_DECREF(item); return false; } - bool success = FlattenHelper(item, list); + const bool success = + FlattenHelper(item.get(), list, is_sequence_helper, next_values_getter); Py_LeaveRecursiveCall(); if (!success) { - Py_DECREF(iterator); - Py_DECREF(item); return false; } - Py_DECREF(item); } - Py_DECREF(iterator); return true; } @@ -351,7 +436,7 @@ bool AssertSameStructureHelper(PyObject* o1, PyObject* o2, bool check_types, } } -} // anonymous namespace +} // namespace void RegisterSequenceClass(PyObject* sequence_class) { if (!PyType_Check(sequence_class)) { @@ -366,11 +451,38 @@ void RegisterSequenceClass(PyObject* sequence_class) { CollectionsSequenceType = sequence_class; } +void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) { + if (!PyType_Check(sparse_tensor_value_class)) { + PyErr_SetString( + PyExc_TypeError, + tensorflow::strings::StrCat( + "Expecting a class definition for `SparseTensorValue`. Got ", + Py_TYPE(sparse_tensor_value_class)->tp_name) + .c_str()); + return; + } + SparseTensorValueType = + reinterpret_cast(sparse_tensor_value_class); +} + bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; } PyObject* Flatten(PyObject* nested) { PyObject* list = PyList_New(0); - if (FlattenHelper(nested, list)) { + if (FlattenHelper(nested, list, IsSequenceHelper, GetNextValues)) { + return list; + } else { + Py_DECREF(list); + return nullptr; + } +} + +bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; } + +PyObject* FlattenForData(PyObject* nested) { + PyObject* list = PyList_New(0); + if (FlattenHelper(nested, list, IsSequenceForDataHelper, + GetNextValuesForData)) { return list; } else { Py_DECREF(list); diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h index c325baa..9851c11 100644 --- a/tensorflow/python/util/util.h +++ b/tensorflow/python/util/util.h @@ -118,6 +118,30 @@ PyObject* Flatten(PyObject* nested); // the type from the module. This approach also requires some trigger from // Python so that we know that Python interpreter had been initialzied. void RegisterSequenceClass(PyObject* sequence_class); +// Similar to the above function, except for the +// sparse_tensor.SparseTensorValue class. +void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class); + +// The tensorflow.python.data package has its own nest utility that follows very +// slightly different semantics for its functions than the tensorflow.python +// nest utility. Returns a true if its input is a collections.Sequence (except +// strings). +// +// Main differences are (this is copied from nest.py in the +// tensorflow.data.util): +// +// 1. It removes support for lists as a level of nesting in nested structures. +// 2. It adds support for `SparseTensorValue` as an atomic element. + +// IsSequence specialized for the data package. Additional comments about +// difference in functionality can be found in nest.py in tensorflow.data.util +// and in the comments for Flatten above. +bool IsSequenceForData(PyObject* o); + +// IsSequence specialized for the data package. Additional comments about +// difference in functionality can be found in nest.py in tensorflow.data.util +// and in the comments for Flatten above. +PyObject* FlattenForData(PyObject* nested); } // namespace swig } // namespace tensorflow diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i index b7f201b..9f3b11b 100644 --- a/tensorflow/python/util/util.i +++ b/tensorflow/python/util/util.i @@ -31,6 +31,9 @@ limitations under the License. %unignore tensorflow::swig::RegisterSequenceClass; %noexception tensorflow::swig::RegisterSequenceClass; +%unignore tensorflow::swig::RegisterSparseTensorValueClass; +%noexception tensorflow::swig::RegisterSparseTensorValueClass; + %unignore tensorflow::swig::IsSequence; %noexception tensorflow::swig::IsSequence; @@ -46,6 +49,12 @@ limitations under the License. %unignore tensorflow::swig::Flatten; %noexception tensorflow::swig::Flatten; +%unignore tensorflow::swig::IsSequenceForData; +%noexception tensorflow::swig::IsSequenceForData; + +%unignore tensorflow::swig::FlattenForData; +%noexception tensorflow::swig::FlattenForData; + %include "tensorflow/python/util/util.h" %unignoreall -- 2.7.4