From 7e2e57410eb40c0512dc573955fd256a6c787741 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 31 May 2018 06:05:04 -0700 Subject: [PATCH] implementation of sparse_to_dense PiperOrigin-RevId: 198710452 --- tensorflow/contrib/lite/build_def.bzl | 1 + tensorflow/contrib/lite/builtin_op_data.h | 4 + tensorflow/contrib/lite/builtin_ops.h | 1 + .../contrib/lite/g3doc/tf_ops_compatibility.md | 15 ++ tensorflow/contrib/lite/kernels/BUILD | 14 ++ .../kernels/internal/reference/reference_ops.h | 36 +++ tensorflow/contrib/lite/kernels/register.cc | 2 + tensorflow/contrib/lite/kernels/sparse_to_dense.cc | 275 +++++++++++++++++++++ .../contrib/lite/kernels/sparse_to_dense_test.cc | 155 ++++++++++++ tensorflow/contrib/lite/model.cc | 10 + tensorflow/contrib/lite/nnapi_delegate.cc | 1 + tensorflow/contrib/lite/schema/schema.fbs | 6 + tensorflow/contrib/lite/schema/schema_generated.h | 141 ++++++++++- .../contrib/lite/testing/generate_examples.py | 77 +++++- tensorflow/contrib/lite/toco/export_tensorflow.cc | 19 ++ .../propagate_array_data_types.cc | 10 + .../graph_transformations/propagate_fixed_sizes.cc | 32 +++ tensorflow/contrib/lite/toco/import_tensorflow.cc | 20 ++ tensorflow/contrib/lite/toco/model.h | 14 ++ tensorflow/contrib/lite/toco/tflite/operator.cc | 23 ++ .../contrib/lite/toco/tflite/operator_test.cc | 9 + tensorflow/contrib/lite/toco/tooling_util.cc | 1 + 22 files changed, 859 insertions(+), 7 deletions(-) create mode 100644 tensorflow/contrib/lite/kernels/sparse_to_dense.cc create mode 100644 tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index c8820ab..b9e40cc 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -239,6 +239,7 @@ def generated_test_models(): "softmax", "space_to_batch_nd", "space_to_depth", + "sparse_to_dense", "split", "squeeze", "strided_slice", diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 8660c65..52ab9ee 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -236,6 +236,10 @@ typedef struct { int stride_height; } TfLiteTransposeConvParams; +typedef struct { + bool validate_indices; +} TfLiteSparseToDenseParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 24a9b0f..c797e35 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -93,6 +93,7 @@ typedef enum { kTfLiteBuiltinSlice = 65, kTfLiteBuiltinSin = 66, kTfLiteBuiltinTransposeConv = 67, + kTfLiteBuiltinSparseToDense = 68, } TfLiteBuiltinOperator; #ifdef __cplusplus diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 244919b..27e7d25 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -595,6 +595,21 @@ Outputs { } ``` +**SPARSE_TO_DENSE** + +``` +Inputs { + 0: 0D or 1D or 2D tensor + 1: 1D tensor + 2: 0D or 1D tensor + 3: 0D tensor + 4: a boolean value +} +Outputs { + 0: Dense Tensor of shape output_shape. Has the same type as sparse_values. +} +``` + **SPLIT** ``` diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index b7291dd..0af659b 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -170,6 +170,7 @@ cc_library( "slice.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "sparse_to_dense.cc", "split.cc", "squeeze.cc", "strided_slice.cc", @@ -934,6 +935,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "sparse_to_dense_test", + size = "small", + srcs = ["sparse_to_dense_test.cc"], + tags = ["tflite_not_portable_ios"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index 62d6fe0..c43c5f9 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -4000,6 +4000,42 @@ inline void RankOneSelect(const D* input_condition_data, } } +// For easy implementation, the indices is always a vector of size-4 vectors. +template +inline void SparseToDense(const std::vector>& indices, + const T* values, T default_value, T* output_data, + const Dims<4>& output_dims, bool value_is_scalar) { + const int value_count = indices.size(); + + // First fill the output_data with default value. + const int num_elements = FlatSize(output_dims); + for (int i = 0; i < num_elements; ++i) { + output_data[i] = default_value; + } + + // Special handle for value is scalar case to avoid checking the boolean + // condition within the loop every time. + if (value_is_scalar) { + for (int i = 0; i < value_count; ++i) { + const std::vector& index = indices[i]; + TFLITE_DCHECK_EQ(index.size(), 4); + const T value = *values; // just use the first value. + output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = + value; + } + return; + } + + // Go through the values and indices to fill the sparse values. + for (int i = 0; i < value_count; ++i) { + const std::vector& index = indices[i]; + TFLITE_DCHECK_EQ(index.size(), 4); + const T value = values[i]; + output_data[Offset(output_dims, index[3], index[2], index[1], index[0])] = + value; + } +} + } // namespace reference_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 21cc185..4eea992 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -90,6 +90,7 @@ TfLiteRegistration* Register_SELECT(); TfLiteRegistration* Register_SLICE(); TfLiteRegistration* Register_SIN(); TfLiteRegistration* Register_TRANSPOSE_CONV(); +TfLiteRegistration* Register_SPARSE_TO_DENSE(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -161,6 +162,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SLICE, Register_SLICE()); AddBuiltin(BuiltinOperator_SIN, Register_SIN()); AddBuiltin(BuiltinOperator_TRANSPOSE_CONV, Register_TRANSPOSE_CONV()); + AddBuiltin(BuiltinOperator_SPARSE_TO_DENSE, Register_SPARSE_TO_DENSE()); // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that // custom ops aren't always included by default. diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc new file mode 100644 index 0000000..404c32a --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense.cc @@ -0,0 +1,275 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/kernels/padding.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace sparse_to_dense { + +constexpr int kIndicesTensor = 0; +constexpr int kOutputShapeTensor = 1; +constexpr int kValueInputTensor = 2; +constexpr int kDefaultValueTensor = 3; +constexpr int kOutputTensor = 0; + +constexpr int kMaxDimensions = 4; + +template +TfLiteStatus Resize(TfLiteContext* context, const TfLiteTensor* output_shape, + TfLiteTensor* output) { + const int output_dimensions = NumElements(output_shape); + TfLiteIntArray* output_shape_array = TfLiteIntArrayCreate(output_dimensions); + for (int i = 0; i < output_dimensions; ++i) { + output_shape_array->data[i] = GetTensorData(output_shape)[i]; + } + + return context->ResizeTensor(context, output, output_shape_array); +} + +TfLiteStatus CheckDimensionsMatch(TfLiteContext* context, + const TfLiteTensor* indices, + const TfLiteTensor* output_shape, + const TfLiteTensor* values) { + switch (NumDimensions(indices)) { + case 0: + case 1: { + if (NumDimensions(values) == 0) { + TF_LITE_ENSURE_EQ(context, NumElements(indices), NumElements(values)); + } + TF_LITE_ENSURE_EQ(context, NumElements(output_shape), 1); + break; + } + case 2: { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 1), + NumElements(output_shape)); + if (NumDimensions(values) == 0) + TF_LITE_ENSURE_EQ(context, SizeOfDimension(indices, 0), + NumElements(values)); + break; + } + default: + context->ReportError( + context, "Wrong indices dimensions %d, should be less than 3.", + NumDimensions(indices)); + return kTfLiteError; + } + return kTfLiteOk; +} + +// Convert indices into a vector of 4-d vectors. +// TODO(renjieliu): Revisit here to improve the performance, since multiple +// allocations of std::vectors will be quite slow on phones. +template +TfLiteStatus GetIndicesVector(TfLiteContext* context, + const TfLiteTensor* indices, + const int num_indices, + std::vector>* indices_vector) { + // Note because TfLite will reverse the dimensions, so pad zeros upfront. + switch (NumDimensions(indices)) { + case 0: + case 1: { + const auto indices_data = GetTensorData(indices); + for (int i = 0; i < num_indices; ++i) { + std::vector index({0, 0, 0, indices_data[i]}); + indices_vector->push_back(index); + } + break; + } + case 2: { + const int true_dimensions = SizeOfDimension(indices, 1); + TF_LITE_ENSURE(context, true_dimensions <= kMaxDimensions); + for (int i = 0; i < num_indices; ++i) { + std::vector index; + index.reserve(kMaxDimensions); + // Fill the index with 1 up to kMaxDimensions - true_dimensions to + // satisfy the needs for 4-dimension index. + for (int j = 0; j < kMaxDimensions - true_dimensions; ++j) { + index.push_back(0); + } + for (int j = 0; j < true_dimensions; ++j) { + index.push_back(GetTensorData(indices)[i * true_dimensions + j]); + } + + indices_vector->push_back(index); + } + break; + } + default: + context->ReportError(context, + "Indices dimensions problem, got %d dimensions", + NumDimensions(indices)); + return kTfLiteError; + } + return kTfLiteOk; +} + +TfLiteStatus ResizeOutputShape(TfLiteContext* context, + const TfLiteTensor* output_shape, + TfLiteTensor* output) { + if (output_shape->type == kTfLiteInt32) { + return Resize(context, output_shape, output); + } else if (output_shape->type == kTfLiteInt64) { + return Resize(context, output_shape, output); + } else { + context->ReportError(context, "Dense shape type %d not supported.", + output_shape->type); + return kTfLiteError; + } +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + const TfLiteTensor* default_value = + GetInput(context, node, kDefaultValueTensor); + + // TODO(renjieliu): Handle validate_indices. + + // Indices can be 0-D, 1-D or 2-D. + TF_LITE_ASSERT(NumDimensions(indices) >= 0); + TF_LITE_ENSURE(context, NumDimensions(indices) < 3); + TF_LITE_ASSERT(NumDimensions(output_shape) >= 0); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + // Values can be 0-D or 1-D. + TF_LITE_ASSERT(NumDimensions(values) >= 0); + TF_LITE_ENSURE(context, NumDimensions(values) < 2); + + TF_LITE_ENSURE_EQ(context, NumElements(default_value), 1); + + TF_LITE_ENSURE( + context, indices->type == kTfLiteInt32 || indices->type == kTfLiteInt64); + TF_LITE_ENSURE(context, output_shape->type == kTfLiteInt32 || + output_shape->type == kTfLiteInt64); + TF_LITE_ENSURE_EQ(context, values->type, default_value->type); + + // Ensure dimensions match. + TF_LITE_ENSURE_OK( + context, CheckDimensionsMatch(context, indices, output_shape, values)); + + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TF_LITE_ENSURE_EQ(context, NumDimensions(output_shape), 1); + + if (!IsConstantTensor(output_shape)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputShape(context, output_shape, output); +} + +template +TfLiteStatus SparseToDenseImpl(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* output_shape = + GetInput(context, node, kOutputShapeTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + const TfLiteTensor* default_value = + GetInput(context, node, kDefaultValueTensor); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputShape(context, output_shape, output)); + } + + const int num_indices = SizeOfDimension(indices, 0); + const bool value_is_scalar = NumDimensions(values) == 0; + std::vector> indices_vector; + indices_vector.reserve(num_indices); + TF_LITE_ENSURE_OK(context, GetIndicesVector(context, indices, num_indices, + &indices_vector)); + reference_ops::SparseToDense(indices_vector, GetTensorData(values), + *GetTensorData(default_value), + GetTensorData(output), GetTensorDims(output), + value_is_scalar); + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + const TfLiteTensor* indices = GetInput(context, node, kIndicesTensor); + const TfLiteTensor* values = GetInput(context, node, kValueInputTensor); + + // Currently only supports float32 and int32. + switch (values->type) { + case kTfLiteFloat32: { + switch (indices->type) { + case kTfLiteInt32: { + return SparseToDenseImpl(context, node); + } + case kTfLiteInt64: { + return SparseToDenseImpl(context, node); + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + indices->type); + return kTfLiteError; + } + break; + } + case kTfLiteInt32: { + switch (indices->type) { + case kTfLiteInt32: { + return SparseToDenseImpl(context, node); + } + case kTfLiteInt64: { + return SparseToDenseImpl(context, node); + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + indices->type); + return kTfLiteError; + } + break; + } + default: + context->ReportError( + context, "Type %d is currently not supported by sparse to dense.", + values->type); + return kTfLiteError; + } +} + +} // namespace sparse_to_dense + +TfLiteRegistration* Register_SPARSE_TO_DENSE() { + static TfLiteRegistration r = {nullptr, nullptr, sparse_to_dense::Prepare, + sparse_to_dense::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc new file mode 100644 index 0000000..a51ec17 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sparse_to_dense_test.cc @@ -0,0 +1,155 @@ + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +template +class SparseToDenseOpModel : public SingleOpModel { + public: + SparseToDenseOpModel(std::initializer_list indices_shape, + std::initializer_list output_shape_shape, + std::initializer_list values_shape, T default_value, + TensorType tensor_index_type, + TensorType tensor_input_type) { + indices_ = AddInput(tensor_index_type); + output_shape_ = AddInput(TensorType_INT32); + values_ = AddInput(tensor_input_type); + default_value_ = AddInput(tensor_input_type); + output_ = AddOutput(tensor_input_type); + + SetBuiltinOp(BuiltinOperator_SPARSE_TO_DENSE, + BuiltinOptions_SparseToDenseOptions, + CreateSparseToDenseOptions(builder_, false).Union()); + BuildInterpreter({indices_shape, output_shape_shape, values_shape, {1}}); + + PopulateTensor(default_value_, {default_value}); + } + + int indices() { return indices_; } + int output_shape() { return output_shape_; } + int values() { return values_; } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int indices_; + int output_shape_; + int values_; + int default_value_; + int output_; +}; + +TEST(SparseToDenseOpModelTest, ZeroDimensionTest) { + SparseToDenseOpModel m({1}, {1}, {1}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {3}); + m.PopulateTensor(m.output_shape(), {5}); + m.PopulateTensor(m.values(), {7}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 7, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({5})); +} + +TEST(SparseToDenseOpModelTest, OneDimensionTest) { + SparseToDenseOpModel m({3}, {1}, {3}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {1, 3, 5}); + m.PopulateTensor(m.output_shape(), {7}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 0, 4, 0, 6, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({7})); +} + +TEST(SparseToDenseOpModelTest, TwoDimensionsTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, 0, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 4, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, DefaultValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, IntegerValueTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT32, + TensorType_INT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +TEST(SparseToDenseOpModelTest, Int64IndexTest) { + SparseToDenseOpModel m({3, 3}, {3}, {3}, -1, TensorType_INT64, + TensorType_FLOAT32); + m.PopulateTensor(m.indices(), {0, 0, 0, 1, 2, 1, 2, 0, 1}); + m.PopulateTensor(m.output_shape(), {3, 3, 3}); + m.PopulateTensor(m.values(), {2, 4, 6}); + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray({2, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, 4, -1, -1, 6, -1, -1, -1, -1, -1, -1, -1})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 3, 3})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 80fcb28..6ac41a9 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -699,6 +699,16 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_SPARSE_TO_DENSE: { + TfLiteSparseToDenseParams* params = + MallocPOD(); + if (auto* sparse_to_dense_params = + op->builtin_options_as_SparseToDenseOptions()) { + params->validate_indices = sparse_to_dense_params->validate_indices(); + } + *builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_DELEGATE: { // TODO(ycling): Revisit when supporting saving delegated models. error_reporter->Report("DELEGATE op shouldn't exist in model."); diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index eed57d4..fad08bb 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -491,6 +491,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_SLICE: case tflite::BuiltinOperator_SIN: case tflite::BuiltinOperator_TRANSPOSE_CONV: + case tflite::BuiltinOperator_SPARSE_TO_DENSE: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 8bdeb03..522eac2 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -145,6 +145,7 @@ enum BuiltinOperator : byte { SLICE = 65, SIN = 66, TRANSPOSE_CONV = 67, + SPARSE_TO_DENSE = 68, } // Options for the builtin operators. @@ -198,6 +199,7 @@ union BuiltinOptions { SelectOptions, SliceOptions, TransposeConvOptions, + SparseToDenseOptions, } enum Padding : byte { SAME, VALID } @@ -450,6 +452,10 @@ table TransposeConvOptions { stride_h:int; } +table SparseToDenseOptions { + validate_indices:bool; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 35c34f5..746dd26 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -178,6 +178,9 @@ struct SliceOptionsT; struct TransposeConvOptions; struct TransposeConvOptionsT; +struct SparseToDenseOptions; +struct SparseToDenseOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -305,11 +308,12 @@ enum BuiltinOperator { BuiltinOperator_SLICE = 65, BuiltinOperator_SIN = 66, BuiltinOperator_TRANSPOSE_CONV = 67, + BuiltinOperator_SPARSE_TO_DENSE = 68, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_TRANSPOSE_CONV + BuiltinOperator_MAX = BuiltinOperator_SPARSE_TO_DENSE }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[67] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[68] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -377,7 +381,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[67] { BuiltinOperator_SELECT, BuiltinOperator_SLICE, BuiltinOperator_SIN, - BuiltinOperator_TRANSPOSE_CONV + BuiltinOperator_TRANSPOSE_CONV, + BuiltinOperator_SPARSE_TO_DENSE }; return values; } @@ -452,6 +457,7 @@ inline const char **EnumNamesBuiltinOperator() { "SLICE", "SIN", "TRANSPOSE_CONV", + "SPARSE_TO_DENSE", nullptr }; return names; @@ -513,11 +519,12 @@ enum BuiltinOptions { BuiltinOptions_SelectOptions = 47, BuiltinOptions_SliceOptions = 48, BuiltinOptions_TransposeConvOptions = 49, + BuiltinOptions_SparseToDenseOptions = 50, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_TransposeConvOptions + BuiltinOptions_MAX = BuiltinOptions_SparseToDenseOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[51] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -568,7 +575,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[50] { BuiltinOptions_LessEqualOptions, BuiltinOptions_SelectOptions, BuiltinOptions_SliceOptions, - BuiltinOptions_TransposeConvOptions + BuiltinOptions_TransposeConvOptions, + BuiltinOptions_SparseToDenseOptions }; return values; } @@ -625,6 +633,7 @@ inline const char **EnumNamesBuiltinOptions() { "SelectOptions", "SliceOptions", "TransposeConvOptions", + "SparseToDenseOptions", nullptr }; return names; @@ -835,6 +844,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeConvOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SparseToDenseOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1258,6 +1271,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_TransposeConvOptions ? reinterpret_cast(value) : nullptr; } + SparseToDenseOptionsT *AsSparseToDenseOptions() { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } + const SparseToDenseOptionsT *AsSparseToDenseOptions() const { + return type == BuiltinOptions_SparseToDenseOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -4543,6 +4564,60 @@ inline flatbuffers::Offset CreateTransposeConvOptions( flatbuffers::Offset CreateTransposeConvOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeConvOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct SparseToDenseOptionsT : public flatbuffers::NativeTable { + typedef SparseToDenseOptions TableType; + bool validate_indices; + SparseToDenseOptionsT() + : validate_indices(false) { + } +}; + +struct SparseToDenseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SparseToDenseOptionsT NativeTableType; + enum { + VT_VALIDATE_INDICES = 4 + }; + bool validate_indices() const { + return GetField(VT_VALIDATE_INDICES, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_VALIDATE_INDICES) && + verifier.EndTable(); + } + SparseToDenseOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SparseToDenseOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_validate_indices(bool validate_indices) { + fbb_.AddElement(SparseToDenseOptions::VT_VALIDATE_INDICES, static_cast(validate_indices), 0); + } + explicit SparseToDenseOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SparseToDenseOptionsBuilder &operator=(const SparseToDenseOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSparseToDenseOptions( + flatbuffers::FlatBufferBuilder &_fbb, + bool validate_indices = false) { + SparseToDenseOptionsBuilder builder_(_fbb); + builder_.add_validate_indices(validate_indices); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -4821,6 +4896,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const TransposeConvOptions *builtin_options_as_TransposeConvOptions() const { return builtin_options_type() == BuiltinOptions_TransposeConvOptions ? static_cast(builtin_options()) : nullptr; } + const SparseToDenseOptions *builtin_options_as_SparseToDenseOptions() const { + return builtin_options_type() == BuiltinOptions_SparseToDenseOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -5043,6 +5121,10 @@ template<> inline const TransposeConvOptions *Operator::builtin_options_as inline const SparseToDenseOptions *Operator::builtin_options_as() const { + return builtin_options_as_SparseToDenseOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -6862,6 +6944,32 @@ inline flatbuffers::Offset CreateTransposeConvOptions(flat _stride_h); } +inline SparseToDenseOptionsT *SparseToDenseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SparseToDenseOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SparseToDenseOptions::UnPackTo(SparseToDenseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = validate_indices(); _o->validate_indices = _e; }; +} + +inline flatbuffers::Offset SparseToDenseOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSparseToDenseOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSparseToDenseOptions(flatbuffers::FlatBufferBuilder &_fbb, const SparseToDenseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SparseToDenseOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _validate_indices = _o->validate_indices; + return tflite::CreateSparseToDenseOptions( + _fbb, + _validate_indices); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -7244,6 +7352,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -7458,6 +7570,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -7660,6 +7776,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateTransposeConvOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + return CreateSparseToDenseOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -7862,6 +7982,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new TransposeConvOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_SparseToDenseOptions: { + value = new SparseToDenseOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -8114,6 +8238,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_SparseToDenseOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 13fafeb..ae66bd8 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -146,8 +146,9 @@ def toco_options(data_types, " --inference_type=%s" % inference_type + " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" + " --input_arrays=%s" % ",".join(input_arrays) + - " --input_shapes=%s" % shape_str + " --output_arrays=%s" % ",".join(output_arrays)) + if shape_str: + s += (" --input_shapes=%s" % shape_str) if extra_toco_options.drop_control_dependency: s += " --drop_control_dependency" if extra_toco_options.allow_custom_ops: @@ -238,6 +239,19 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): return value.astype(dtype) +def create_scalar_data(dtype, min_value=-100, max_value=100): + """Build scalar tensor data range from min_value to max_value exclusively.""" + + if dtype in _TF_TYPE_INFO: + dtype = _TF_TYPE_INFO[dtype][0] + + if dtype in (tf.float32, tf.float16): + value = (max_value - min_value) * np.random.random() + min_value + elif dtype in (tf.int32, tf.uint8, tf.int64): + value = np.random.randint(min_value, max_value + 1) + return np.array(value, dtype=dtype) + + def freeze_graph(session, outputs): """Freeze the current graph. @@ -2485,6 +2499,67 @@ def make_transpose_conv_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_sparse_to_dense_tests(zip_path): + """Make a set of tests to do sparse to dense.""" + + test_parameters = [{ + "value_dtype": [tf.float32, tf.int32], + "index_dtype": [tf.int32, tf.int64], + "value_count": [1, 3, 6, 8], + "dense_shape": [[15], [3, 10], [4, 4, 4, 4], [7, 10, 9]], + "default_value": [0, -1], + "value_is_scalar": [True, False], + }] + + # Return a single value for 1-D dense shape, but a tuple for other shapes. + def generate_index(dense_shape): + if len(dense_shape) == 1: + return np.random.randint(dense_shape[0]) + else: + index = [] + for shape in dense_shape: + index.append(np.random.randint(shape)) + return tuple(index) + + def build_graph(parameters): + """Build the sparse_to_dense op testing graph.""" + dense_shape = parameters["dense_shape"] + + # Special handle for value_is_scalar case. + # value_count must be 1. + if parameters["value_is_scalar"] and parameters["value_count"] == 1: + value = tf.placeholder( + name="value", dtype=parameters["value_dtype"], shape=()) + else: + value = tf.placeholder( + name="value", + dtype=parameters["value_dtype"], + shape=[parameters["value_count"]]) + indices = set() + while len(indices) < parameters["value_count"]: + indices.add(generate_index(dense_shape)) + indices = tf.constant(tuple(indices), dtype=parameters["index_dtype"]) + # TODO(renjieliu): Add test for validate_indices case. + out = tf.sparse_to_dense( + indices, + dense_shape, + value, + parameters["default_value"], + validate_indices=False) + + return [value], [out] + + def build_inputs(parameters, sess, inputs, outputs): + if parameters["value_is_scalar"] and parameters["value_count"] == 1: + input_value = create_scalar_data(parameters["value_dtype"]) + else: + input_value = create_tensor_data(parameters["value_dtype"], + [parameters["value_count"]]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + # Toco binary path provided by the generate rule. bin_path = None diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index f515714..99f0c81 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -1728,6 +1728,25 @@ void ConvertComparisonOperator(const Model& model, const Operator& src_op, (*comparison_op->mutable_attr())["T"].set_type(data_type); } +void ConvertSparseToDenseOperator(const Model& model, + const SparseToDenseOperator& src_op, + const char* op_name, + GraphDef* tensorflow_graph) { + auto* sparse_to_dense_op = tensorflow_graph->add_node(); + sparse_to_dense_op->set_op(op_name); + sparse_to_dense_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 4); + for (int i = 0; i < 4; ++i) { + *sparse_to_dense_op->add_input() = src_op.inputs[i]; + } + const auto data_type = GetTensorFlowDataType(model, src_op.inputs[3]); + (*sparse_to_dense_op->mutable_attr())["T"].set_type(data_type); + const auto index_type = GetTensorFlowDataType(model, src_op.inputs[0]); + (*sparse_to_dense_op->mutable_attr())["Tindices"].set_type(index_type); + (*sparse_to_dense_op->mutable_attr())["Tindices"].set_b( + src_op.validate_indices); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc index 6342cf3..64096fb 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_array_data_types.cc @@ -163,6 +163,16 @@ bool PropagateArrayDataTypes::Run(Model* model, std::size_t op_index) { SetDataTypeForAllOutputs(model, op, data_type_x); break; } + case OperatorType::kSparseToDense: { + // Select produces outputs with the same type as their 3rd input + CHECK_EQ(op->inputs.size(), 4); + const ArrayDataType data_type = model->GetArray(op->inputs[2]).data_type; + const ArrayDataType data_type_default = + model->GetArray(op->inputs[3]).data_type; + CHECK(data_type == data_type_default); + SetDataTypeForAllOutputs(model, op, data_type); + break; + } default: { // These operators produce outputs with the same type as their 1st input CHECK_GT(op->inputs.size(), 0); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 9d1d27f..adb241d 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -1477,6 +1477,34 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) { *output_array.mutable_shape()->mutable_dims() = output_dims; } +void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) { + CHECK_EQ(op->inputs.size(), 4); + + const Array& output_shape_array = model->GetArray(op->inputs[1]); + if (!output_shape_array.has_shape()) return; + CHECK_EQ(output_shape_array.shape().dimensions_count(), 1); + + // Output should not go over four dimensions. + CHECK_LE(output_shape_array.shape().dims(0), 4); + + const string& output_name = op->outputs[0]; + Array& output_array = model->GetArray(output_name); + if (output_array.has_shape()) return; + + CHECK(output_shape_array.data_type == ArrayDataType::kInt32 || + output_shape_array.data_type == ArrayDataType::kInt64); + if (output_shape_array.data_type == ArrayDataType::kInt32) { + *output_array.mutable_shape()->mutable_dims() = + output_shape_array.GetBuffer().data; + } else { + const std::vector& output_shape_data = + output_shape_array.GetBuffer().data; + std::copy( + output_shape_data.begin(), output_shape_data.end(), + std::back_inserter(*output_array.mutable_shape()->mutable_dims())); + } +} + } // namespace bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { @@ -1700,6 +1728,10 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { CHECK_EQ(op->inputs.size(), 1); ProcessOpWithShapeInput(model, op); break; + case OperatorType::kSparseToDense: + ProcessSparseToDenseOperator(model, + static_cast(op)); + break; default: // Unimplemented, another graph transformation should drop it. LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type); diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 27e9d1a..94ec7c2 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -2133,6 +2133,24 @@ void ConvertDynamicStitchOperator(const NodeDef& node, model->operators.emplace_back(op.release()); } +void ConvertSparseToDenseOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "SparseToDense"); + CheckInputsCount(node, tf_import_flags, 4); + + auto* op = new SparseToDenseOperator; + for (const string& input : node.input()) { + op->inputs.push_back(input); + } + op->outputs.push_back(node.name()); + + op->validate_indices = HasAttr(node, "validate_indices") + ? GetBoolAttr(node, "validate_indices") + : true; + model->operators.emplace_back(op); +} + } // namespace namespace internal { @@ -2314,6 +2332,8 @@ Status ImportTensorFlowNode(const tensorflow::NodeDef& node, ConvertSinOperator(node, tf_import_flags, model); } else if (node.op() == "Select") { ConvertSelectOperator(node, tf_import_flags, model); + } else if (node.op() == "SparseToDense") { + ConvertSparseToDenseOperator(node, tf_import_flags, model); } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index d878ac5..9062c03 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -135,6 +135,7 @@ enum class OperatorType { // special nodes in the graph to shuffle axes. kReorderAxes, kSelect, + kSparseToDense, }; // Helper to deal with TensorFlow arrays using a different ordering of @@ -1598,6 +1599,19 @@ struct DynamicStitchOperator : Operator { int num_partitions; }; +// SparseToDense operator: +// +// Inputs: +// Inputs[0]: required: sparse_indices. +// Inputs[1]: required: output_shape. +// Inputs[2]: required: sparse_values. +// +// TensorFlow equivalent: SparseToDense. +struct SparseToDenseOperator : Operator { + SparseToDenseOperator() : Operator(OperatorType::kSparseToDense) {} + bool validate_indices; +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 6922e50..8f0f2e2 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -794,6 +794,27 @@ class TransposeConv int GetVersion(const Operator& op) const override { return 1; } }; +class SparseToDense + : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSparseToDenseOptions(*builder, op.validate_indices); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->validate_indices = options.validate_indices(); + } + + int GetVersion(const Operator& op) const override { return 1; } +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -978,6 +999,8 @@ std::vector> BuildOperatorList() { new ArgMax(::tflite::BuiltinOperator_ARG_MAX, OperatorType::kArgMax)); ops.emplace_back(new TransposeConv(::tflite::BuiltinOperator_TRANSPOSE_CONV, OperatorType::kTransposeConv)); + ops.emplace_back(new SparseToDense(::tflite::BuiltinOperator_SPARSE_TO_DENSE, + OperatorType::kSparseToDense)); // Custom Operators. ops.emplace_back( diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index fe594c6..d63c99a 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -420,6 +420,15 @@ TEST_F(OperatorTest, BuiltinTransposeConv) { EXPECT_EQ(op.padding.type, output_toco_op->padding.type); } +TEST_F(OperatorTest, BuiltinSparseToDense) { + SparseToDenseOperator op; + op.validate_indices = false; + std::unique_ptr output_toco_op = + SerializeAndDeserialize( + GetOperator("SPARSE_TO_DENSE", OperatorType::kSparseToDense), op); + EXPECT_EQ(op.validate_indices, output_toco_op->validate_indices); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index 1e6314f..fe7bed8 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -393,6 +393,7 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(DynamicPartition) HANDLE_OPERATORTYPENAME_CASE(DynamicStitch) HANDLE_OPERATORTYPENAME_CASE(Select) + HANDLE_OPERATORTYPENAME_CASE(SparseToDense) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE -- 2.7.4