Support Transpose in TFLite.
authorNupur Garg <nupurgarg@google.com>
Mon, 8 Jan 2018 22:45:31 +0000 (14:45 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 8 Jan 2018 22:50:46 +0000 (14:50 -0800)
The internal implementation supports 1D-4D tensors.

PiperOrigin-RevId: 181221674

20 files changed:
tensorflow/contrib/lite/builtin_op_data.h
tensorflow/contrib/lite/kernels/BUILD
tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
tensorflow/contrib/lite/kernels/register.cc
tensorflow/contrib/lite/kernels/transpose.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/transpose_test.cc
tensorflow/contrib/lite/model.cc
tensorflow/contrib/lite/nnapi_delegate.cc
tensorflow/contrib/lite/schema/schema.fbs
tensorflow/contrib/lite/schema/schema_generated.h
tensorflow/contrib/lite/testing/BUILD
tensorflow/contrib/lite/testing/generate_examples.py
tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
tensorflow/contrib/lite/toco/BUILD
tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h
tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc [new file with mode: 0644]
tensorflow/contrib/lite/toco/model.h
tensorflow/contrib/lite/toco/tflite/operator.cc
tensorflow/contrib/lite/toco/tflite/operator_test.cc
tensorflow/contrib/lite/toco/toco_tooling.cc

index 0c6e7f93a63f2cb768aece993cbd4f1a66e6713e..347e46b83c4d528f2ca67c622770c77653c19071 100644 (file)
@@ -191,6 +191,13 @@ typedef struct {
   int axis;
 } TfLiteGatherParams;
 
+typedef struct {
+  // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+  // For now we will fix the maximum possible number of dimensions.
+  int perm[8];
+  int num_dimensions;
+} TfLiteTransposeParams;
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif  // __cplusplus
index 759de4490b267030e04f75b099ac45e905526a10..d7f4b36f948e778e379909d303f2713c477d209e 100644 (file)
@@ -101,6 +101,7 @@ cc_library(
         "space_to_batch_nd.cc",
         "space_to_depth.cc",
         "svdf.cc",
+        "transpose.cc",
         "unidirectional_sequence_rnn.cc",
     ],
     hdrs = [
index 93087cda57a341be5b73629863d0738495b86216..96158489961e2b279e0c63c8910e9b5bb603530f 100644 (file)
@@ -2485,8 +2485,8 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims,
 }
 
 template <typename T>
-void Transpose(const T* input, Dims<4>& input_dims, T* output,
-               Dims<4>& output_dims, int* permuted_axes) {
+void Transpose(const T* input, const Dims<4>& input_dims, T* output,
+               const Dims<4>& output_dims, int* permuted_axes) {
   int out_sizes[4];
   // Compute the inverse permutation array so we can do an output centered
   // transpose. Also, check to make sure output_dims is matching input_dims.
index 70b5d99eeabf99b333ba8f71f46e9d47fd4bc746..9e4bacbf78cdb160147996bf12a420cbc015ce61 100644 (file)
@@ -52,6 +52,7 @@ TfLiteRegistration* Register_RESIZE_BILINEAR();
 TfLiteRegistration* Register_SKIP_GRAM();
 TfLiteRegistration* Register_SPACE_TO_DEPTH();
 TfLiteRegistration* Register_GATHER();
+TfLiteRegistration* Register_TRANSPOSE();
 
 BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -90,6 +91,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
   AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH());
   AddBuiltin(BuiltinOperator_GATHER, Register_GATHER());
+  AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE());
 }
 
 TfLiteRegistration* BuiltinOpResolver::FindOp(
diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc
new file mode 100644 (file)
index 0000000..75d8136
--- /dev/null
@@ -0,0 +1,142 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace transpose {
+
+// This file has two implementations of Transpose.
+enum KernelType {
+  kReference,
+};
+
+// TODO(nupurgarg): Permutation arrays represented as a tensor are ignored. Only
+// use the `perm` specified in `params`.
+struct TransposeContext {
+  TransposeContext(TfLiteContext* context, TfLiteNode* node) {
+    params = reinterpret_cast<TfLiteTransposeParams*>(node->builtin_data);
+    input = GetInput(context, node, 0);
+    output = GetOutput(context, node, 0);
+  }
+  TfLiteTransposeParams* params;
+  TfLiteTensor* input;
+  TfLiteTensor* output;
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  TransposeContext op_context(context, node);
+  int dims = NumDimensions(op_context.input);
+
+  // Ensure validity of input tensor and permutation array.
+  TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+  TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions);
+  TF_LITE_ENSURE_MSG(context, dims <= 4,
+                     "Transpose op only supports 1D-4D input arrays.");
+  for (int idx = 0; idx < dims; ++idx) {
+    TF_LITE_ENSURE_MSG(context,
+                       op_context.params->perm[idx] >= 0 &&
+                           op_context.params->perm[idx] < dims,
+                       "Transpose op permutations array is out of bounds.");
+  }
+
+  // Determine size of output tensor.
+  const TfLiteIntArray* input_size = op_context.input->dims;
+  TfLiteIntArray* output_size = TfLiteIntArrayCreate(dims);
+  for (int idx = 0; idx < dims; ++idx) {
+    output_size->data[idx] = input_size->data[op_context.params->perm[idx]];
+  }
+
+  return context->ResizeTensor(context, op_context.output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TransposeContext op_context(context, node);
+
+  // Reverse the permuted axes and convert to 4D due to the way Dims are
+  // constructed in GetTensorDims.
+  const int kOutputDimensionNum = 4;
+  int reversed_perm[kOutputDimensionNum];
+  int size = op_context.params->num_dimensions;
+  for (int output_k = 0, input_k = size - 1; output_k < size;
+       ++output_k, --input_k) {
+    reversed_perm[output_k] = size - op_context.params->perm[input_k] - 1;
+  }
+  for (int k = size; k < kOutputDimensionNum; ++k) {
+    reversed_perm[k] = k;
+  }
+
+#define TF_LITE_TRANSPOSE(type, scalar)                     \
+  type::Transpose(GetTensorData<scalar>(op_context.input),  \
+                  GetTensorDims(op_context.input),          \
+                  GetTensorData<scalar>(op_context.output), \
+                  GetTensorDims(op_context.output), reversed_perm)
+
+  switch (op_context.input->type) {
+    case kTfLiteFloat32:
+      if (kernel_type == kReference) {
+        TF_LITE_TRANSPOSE(reference_ops, float);
+      }
+      break;
+    case kTfLiteUInt8:
+      if (kernel_type == kReference) {
+        TF_LITE_TRANSPOSE(reference_ops, uint8_t);
+      }
+      break;
+    case kTfLiteInt32:
+      if (kernel_type == kReference) {
+        TF_LITE_TRANSPOSE(reference_ops, int32_t);
+      }
+      break;
+    case kTfLiteInt64:
+      if (kernel_type == kReference) {
+        TF_LITE_TRANSPOSE(reference_ops, int64_t);
+      }
+      break;
+    default:
+      context->ReportError(context,
+                           "Type is currently not supported by Transpose.");
+      return kTfLiteError;
+  }
+#undef TF_LITE_TRANSPOSE
+
+  return kTfLiteOk;
+}
+
+}  // namespace transpose
+
+TfLiteRegistration* Register_TRANSPOSE_REF() {
+  static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
+                                 transpose::Eval<transpose::kReference>};
+  return &r;
+}
+
+TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); }
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
index 0b5b60a8165abf0aa281365599b2e5d28616df9b..7f5832cd5fa3d502b52bf5554111b45136b588ae 100644 (file)
@@ -16,11 +16,15 @@ limitations under the License.
 #include "tensorflow/contrib/lite/interpreter.h"
 #include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
 #include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
 #include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
 
 namespace tflite {
 namespace {
 
+using ::testing::ElementsAreArray;
+
 void RunTestPermutation(const std::vector<int>& shape,
                         const std::vector<int>& perms,
                         std::vector<float>* input_transposed) {
@@ -64,14 +68,14 @@ void RunTestPermutation(const std::vector<int>& shape,
                                   reversed_perms);
 }
 
-TEST(TransposeTest, Test1D) {
+TEST(TransposeTest, TestRefOps1D) {
   // Basic 1D identity.
   std::vector<float> out;
   RunTestPermutation({3}, {0}, &out);
   ASSERT_EQ(out, std::vector<float>({0, 1, 2}));
 }
 
-TEST(TransposeTest, Test2D) {
+TEST(TransposeTest, TestRefOps2D) {
   std::vector<float> out;
   // Basic 2D.
   RunTestPermutation({3, 2}, {1, 0}, &out);
@@ -81,7 +85,7 @@ TEST(TransposeTest, Test2D) {
   ASSERT_EQ(out, std::vector<float>({0, 1, 2, 3, 4, 5}));
 }
 
-TEST(TransposeTest, Test3D) {
+TEST(TransposeTest, TestRefOps3D) {
   std::vector<float> out;
   // Test 3 dimensional
   {
@@ -99,7 +103,7 @@ TEST(TransposeTest, Test3D) {
   }
 }
 
-TEST(TransposeTest, Test4D) {
+TEST(TransposeTest, TestRefOps4D) {
   std::vector<float> out;
   // Basic 4d.
   RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &out);
@@ -121,6 +125,118 @@ TEST(TransposeTest, Test4D) {
   ASSERT_EQ(out, ref);
 }
 
+class TransposeOpModel : public SingleOpModel {
+ public:
+  TransposeOpModel(std::initializer_list<int> input_shape,
+                   std::initializer_list<int> perm) {
+    input_ = AddInput(TensorType_FLOAT32);
+    output_ = AddOutput(TensorType_FLOAT32);
+    SetBuiltinOp(
+        BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
+        CreateTransposeOptions(builder_, builder_.CreateVector<int>(perm))
+            .Union());
+    BuildInterpreter({input_shape});
+  }
+
+  void SetInput(std::initializer_list<float> data) {
+    PopulateTensor<float>(input_, data);
+  }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+  int input_;
+  int output_;
+};
+
+TEST(TransposeTest, TestUnequalPermSize) {
+  EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {2, 2}),
+               "dims != op_context.params->num_dimensions");
+}
+
+TEST(TransposeTest, TestPermOutOfBounds) {
+  EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {0, -1, -2, -3}),
+               "Transpose op permutations array is out of bounds.");
+  EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {0, 1, 2, 4}),
+               "Transpose op permutations array is out of bounds.");
+}
+
+TEST(TransposeTest, Test1DInputTensor) {
+  TransposeOpModel m({3}, {0});
+  m.SetInput({1, 2, 3});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
+}
+
+TEST(TransposeTest, Test2DInputTensor) {
+  TransposeOpModel m({3, 2}, {1, 0});
+  m.SetInput({0, 1, 2, 3, 4, 5});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5}));
+}
+
+TEST(TransposeTest, Test3DInputTensor) {
+  TransposeOpModel m({2, 3, 4}, {2, 0, 1});
+  m.SetInput({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
+              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3}));
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray({0, 4, 8,  12, 16, 20, 1, 5, 9,  13, 17, 21,
+                                2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
+}
+
+TEST(TransposeTest, Test5DInputTensor) {
+  EXPECT_DEATH(TransposeOpModel({1, 2, 3, 4, 5}, {0, 1, 2, 3, 4}),
+               "Transpose op only supports 1D-4D input arrays.");
+}
+
+TEST(TransposeTest, SimpleTestNoReorder) {
+  TransposeOpModel m({1, 2, 3, 1}, {0, 1, 2, 3});
+  m.SetInput({1, 2, 3, 4, 5, 6});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(TransposeTest, SimpleTestWithReorder) {
+  TransposeOpModel m({1, 2, 3, 1}, {2, 1, 3, 0});
+  m.SetInput({1, 2, 3, 4, 5, 6});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2, 1, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
+}
+
+TEST(TransposeTest, ComplexTestWithReorder) {
+  TransposeOpModel m({2, 3, 4, 5}, {2, 0, 1, 3});
+  m.SetInput({0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,
+              12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,
+              24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35,
+              36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
+              48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,
+              60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
+              72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
+              84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
+              96,  97,  98,  99,  100, 101, 102, 103, 104, 105, 106, 107,
+              108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3, 5}));
+  auto result = ElementsAreArray(
+      {0,  1,  2,  3,  4,  20, 21, 22, 23, 24, 40,  41,  42,  43,  44,
+       60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
+       5,  6,  7,  8,  9,  25, 26, 27, 28, 29, 45,  46,  47,  48,  49,
+       65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
+       10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50,  51,  52,  53,  54,
+       70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
+       15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55,  56,  57,  58,  59,
+       75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119});
+  EXPECT_THAT(m.GetOutput(), result);
+}
+
 }  // namespace
 }  // namespace tflite
 
index bc3a73ad29727a9842ef6b90375794dc201ad00e..8a7b6b5c72f01bc2ab2c66ffb503be23010195fd 100644 (file)
@@ -555,6 +555,17 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
       builtin_data = reinterpret_cast<void*>(params);
       break;
     }
+    case BuiltinOperator_TRANSPOSE: {
+      auto* params = MallocPOD<TfLiteTransposeParams>();
+      if (auto* schema_params = op->builtin_options_as_TransposeOptions()) {
+        const auto& perm = schema_params->perm();
+        FlatBufferIntVectorToArray(sizeof(params->perm), perm, params->perm,
+                                   error_reporter);
+        params->num_dimensions = perm->Length();
+      }
+      builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
   }
   return builtin_data;
 }
index 8b6eec281020fa9469e5a366689a8a76946d7d10..faed5b193c1c7a522d67eb36bc3edf2b313d613b 100644 (file)
@@ -309,6 +309,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
       case tflite::BuiltinOperator_GATHER:
       case tflite::BuiltinOperator_SPACE_TO_BATCH_ND:
       case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
+      case tflite::BuiltinOperator_TRANSPOSE:
         FATAL("Op code %d is currently not delegated to NNAPI", builtin);
         nn_op_type = -1;  // set to invalid
         break;
index 5c50c4f1b6e405b20bdabf7f9a71eb2fd7905d6f..34dc16d66165e61778daad4e993bf56074331fe3 100644 (file)
@@ -109,6 +109,7 @@ enum BuiltinOperator : byte {
   GATHER = 36,
   BATCH_TO_SPACE_ND = 37,
   SPACE_TO_BATCH_ND = 38,
+  TRANSPOSE = 39,
 }
 
 // Options for the builtin operators.
@@ -138,6 +139,7 @@ union BuiltinOptions {
   GatherOptions,
   BatchToSpaceNDOptions,
   SpaceToBatchNDOptions,
+  TransposeOptions,
 }
 
 enum Padding : byte { SAME, VALID }
@@ -298,6 +300,10 @@ table GatherOptions {
   axis: int;
 }
 
+table TransposeOptions {
+  perm:[int];
+}
+
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
 // builtin, or a string if the operator is custom.
 table OperatorCode {
index 4ae82f2184179890a33fee6376c4e5b9c23a1d96..00cd2b9e1bea1c47177dbfc93bc6a02b4ef4ad45 100755 (executable)
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-
 // automatically generated by the FlatBuffers compiler, do not modify
 
 #ifndef FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_
@@ -103,6 +102,9 @@ struct EmbeddingLookupSparseOptionsT;
 struct GatherOptions;
 struct GatherOptionsT;
 
+struct TransposeOptions;
+struct TransposeOptionsT;
+
 struct OperatorCode;
 struct OperatorCodeT;
 
@@ -184,11 +186,12 @@ enum BuiltinOperator {
   BuiltinOperator_GATHER = 36,
   BuiltinOperator_BATCH_TO_SPACE_ND = 37,
   BuiltinOperator_SPACE_TO_BATCH_ND = 38,
+  BuiltinOperator_TRANSPOSE = 39,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_SPACE_TO_BATCH_ND
+  BuiltinOperator_MAX = BuiltinOperator_TRANSPOSE
 };
 
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[36] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[37] {
   static BuiltinOperator values[] = {
       BuiltinOperator_ADD,
       BuiltinOperator_AVERAGE_POOL_2D,
@@ -225,7 +228,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[36] {
       BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
       BuiltinOperator_GATHER,
       BuiltinOperator_BATCH_TO_SPACE_ND,
-      BuiltinOperator_SPACE_TO_BATCH_ND};
+      BuiltinOperator_SPACE_TO_BATCH_ND,
+      BuiltinOperator_TRANSPOSE};
   return values;
 }
 
@@ -269,6 +273,7 @@ inline const char **EnumNamesBuiltinOperator() {
                                 "GATHER",
                                 "BATCH_TO_SPACE_ND",
                                 "SPACE_TO_BATCH_ND",
+                                "TRANSPOSE",
                                 nullptr};
   return names;
 }
@@ -305,11 +310,12 @@ enum BuiltinOptions {
   BuiltinOptions_GatherOptions = 23,
   BuiltinOptions_BatchToSpaceNDOptions = 24,
   BuiltinOptions_SpaceToBatchNDOptions = 25,
+  BuiltinOptions_TransposeOptions = 26,
   BuiltinOptions_MIN = BuiltinOptions_NONE,
-  BuiltinOptions_MAX = BuiltinOptions_SpaceToBatchNDOptions
+  BuiltinOptions_MAX = BuiltinOptions_TransposeOptions
 };
 
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[26] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[27] {
   static BuiltinOptions values[] = {
       BuiltinOptions_NONE,
       BuiltinOptions_Conv2DOptions,
@@ -336,7 +342,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[26] {
       BuiltinOptions_PadOptions,
       BuiltinOptions_GatherOptions,
       BuiltinOptions_BatchToSpaceNDOptions,
-      BuiltinOptions_SpaceToBatchNDOptions};
+      BuiltinOptions_SpaceToBatchNDOptions,
+      BuiltinOptions_TransposeOptions};
   return values;
 }
 
@@ -367,6 +374,7 @@ inline const char **EnumNamesBuiltinOptions() {
                                 "GatherOptions",
                                 "BatchToSpaceNDOptions",
                                 "SpaceToBatchNDOptions",
+                                "TransposeOptions",
                                 nullptr};
   return names;
 }
@@ -510,6 +518,11 @@ struct BuiltinOptionsTraits<SpaceToBatchNDOptions> {
   static const BuiltinOptions enum_value = BuiltinOptions_SpaceToBatchNDOptions;
 };
 
+template <>
+struct BuiltinOptionsTraits<TransposeOptions> {
+  static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions;
+};
+
 struct BuiltinOptionsUnion {
   BuiltinOptions type;
   void *value;
@@ -807,6 +820,16 @@ struct BuiltinOptionsUnion {
                ? reinterpret_cast<const SpaceToBatchNDOptionsT *>(value)
                : nullptr;
   }
+  TransposeOptionsT *AsTransposeOptions() {
+    return type == BuiltinOptions_TransposeOptions
+               ? reinterpret_cast<TransposeOptionsT *>(value)
+               : nullptr;
+  }
+  const TransposeOptionsT *AsTransposeOptions() const {
+    return type == BuiltinOptions_TransposeOptions
+               ? reinterpret_cast<const TransposeOptionsT *>(value)
+               : nullptr;
+  }
 };
 
 bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj,
@@ -2996,6 +3019,69 @@ flatbuffers::Offset<GatherOptions> CreateGatherOptions(
     flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o,
     const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+struct TransposeOptionsT : public flatbuffers::NativeTable {
+  typedef TransposeOptions TableType;
+  std::vector<int32_t> perm;
+  TransposeOptionsT() {}
+};
+
+struct TransposeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef TransposeOptionsT NativeTableType;
+  enum { VT_PERM = 4 };
+  const flatbuffers::Vector<int32_t> *perm() const {
+    return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PERM);
+  }
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PERM) &&
+           verifier.Verify(perm()) && verifier.EndTable();
+  }
+  TransposeOptionsT *UnPack(
+      const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(
+      TransposeOptionsT *_o,
+      const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<TransposeOptions> Pack(
+      flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
+      const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct TransposeOptionsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  void add_perm(flatbuffers::Offset<flatbuffers::Vector<int32_t>> perm) {
+    fbb_.AddOffset(TransposeOptions::VT_PERM, perm);
+  }
+  explicit TransposeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+      : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  TransposeOptionsBuilder &operator=(const TransposeOptionsBuilder &);
+  flatbuffers::Offset<TransposeOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<TransposeOptions>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    flatbuffers::Offset<flatbuffers::Vector<int32_t>> perm = 0) {
+  TransposeOptionsBuilder builder_(_fbb);
+  builder_.add_perm(perm);
+  return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TransposeOptions> CreateTransposeOptionsDirect(
+    flatbuffers::FlatBufferBuilder &_fbb,
+    const std::vector<int32_t> *perm = nullptr) {
+  return tflite::CreateTransposeOptions(
+      _fbb, perm ? _fbb.CreateVector<int32_t>(*perm) : 0);
+}
+
+flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
+    flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
+    const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct OperatorCodeT : public flatbuffers::NativeTable {
   typedef OperatorCode TableType;
   BuiltinOperator builtin_code;
@@ -3250,6 +3336,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
                ? static_cast<const SpaceToBatchNDOptions *>(builtin_options())
                : nullptr;
   }
+  const TransposeOptions *builtin_options_as_TransposeOptions() const {
+    return builtin_options_type() == BuiltinOptions_TransposeOptions
+               ? static_cast<const TransposeOptions *>(builtin_options())
+               : nullptr;
+  }
   const flatbuffers::Vector<uint8_t> *custom_options() const {
     return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
   }
@@ -3424,6 +3515,12 @@ Operator::builtin_options_as<SpaceToBatchNDOptions>() const {
   return builtin_options_as_SpaceToBatchNDOptions();
 }
 
+template <>
+inline const TransposeOptions *Operator::builtin_options_as<TransposeOptions>()
+    const {
+  return builtin_options_as_TransposeOptions();
+}
+
 struct OperatorBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -5183,6 +5280,50 @@ inline flatbuffers::Offset<GatherOptions> CreateGatherOptions(
   return tflite::CreateGatherOptions(_fbb, _axis);
 }
 
+inline TransposeOptionsT *TransposeOptions::UnPack(
+    const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new TransposeOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void TransposeOptions::UnPackTo(
+    TransposeOptionsT *_o,
+    const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+  {
+    auto _e = perm();
+    if (_e) {
+      _o->perm.resize(_e->size());
+      for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
+        _o->perm[_i] = _e->Get(_i);
+      }
+    }
+  };
+}
+
+inline flatbuffers::Offset<TransposeOptions> TransposeOptions::Pack(
+    flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
+    const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateTransposeOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
+    flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
+    const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs {
+    flatbuffers::FlatBufferBuilder *__fbb;
+    const TransposeOptionsT *__o;
+    const flatbuffers::rehasher_function_t *__rehasher;
+  } _va = {&_fbb, _o, _rehasher};
+  (void)_va;
+  auto _perm = _o->perm.size() ? _fbb.CreateVector(_o->perm) : 0;
+  return tflite::CreateTransposeOptions(_fbb, _perm);
+}
+
 inline OperatorCodeT *OperatorCode::UnPack(
     const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new OperatorCodeT();
@@ -5671,6 +5812,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier,
       auto ptr = reinterpret_cast<const SpaceToBatchNDOptions *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case BuiltinOptions_TransposeOptions: {
+      auto ptr = reinterpret_cast<const TransposeOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default:
       return false;
   }
@@ -5795,6 +5940,10 @@ inline void *BuiltinOptionsUnion::UnPack(
       auto ptr = reinterpret_cast<const SpaceToBatchNDOptions *>(obj);
       return ptr->UnPack(resolver);
     }
+    case BuiltinOptions_TransposeOptions: {
+      auto ptr = reinterpret_cast<const TransposeOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
     default:
       return nullptr;
   }
@@ -5906,6 +6055,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(
       auto ptr = reinterpret_cast<const SpaceToBatchNDOptionsT *>(value);
       return CreateSpaceToBatchNDOptions(_fbb, ptr, _rehasher).Union();
     }
+    case BuiltinOptions_TransposeOptions: {
+      auto ptr = reinterpret_cast<const TransposeOptionsT *>(value);
+      return CreateTransposeOptions(_fbb, ptr, _rehasher).Union();
+    }
     default:
       return 0;
   }
@@ -6029,6 +6182,11 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u)
           *reinterpret_cast<SpaceToBatchNDOptionsT *>(u.value));
       break;
     }
+    case BuiltinOptions_TransposeOptions: {
+      value = new TransposeOptionsT(
+          *reinterpret_cast<TransposeOptionsT *>(u.value));
+      break;
+    }
     default:
       break;
   }
@@ -6161,6 +6319,11 @@ inline void BuiltinOptionsUnion::Reset() {
       delete ptr;
       break;
     }
+    case BuiltinOptions_TransposeOptions: {
+      auto ptr = reinterpret_cast<TransposeOptionsT *>(value);
+      delete ptr;
+      break;
+    }
     default:
       break;
   }
index 0f856a57a0f00a1c3c5d1577b162c8bbad9bdac4..f43b09ef48e838482f42fdd2df860173a66b5092 100644 (file)
@@ -43,6 +43,7 @@ gen_zipped_test_files(
         "softmax.zip",
         "space_to_batch_nd.zip",
         "space_to_depth.zip",
+        "transpose.zip",
     ],
 )
 
index 9b3dc7d29a0a34ca843d4124ee0f268fce4c8f75..6f38229e8dec299f04696e5ded7755741cd6440a 100644 (file)
@@ -1283,6 +1283,41 @@ def make_batch_to_space_nd_tests(zip_path):
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
 
+def make_transpose_tests(zip_path):
+  """Make a set of tests to do transpose."""
+
+  # TODO(nupurgarg): Add test for uint8.
+  test_parameters = [{
+      "dtype": [tf.int32, tf.int64, tf.float32],
+      "input_shape": [[2, 2, 3]],
+      "perm": [[0, 1, 2], [0, 2, 1]],
+  }, {
+      "dtype": [tf.float32],
+      "input_shape": [[1, 2, 3, 4]],
+      "perm": [[0, 1, 2, 3], [3, 0, 1, 2]],
+  }, {
+      "dtype": [tf.float32],
+      "input_shape": [[1, 2, 3, 4, 5]],
+      "perm": [[0, 1, 2, 3, 4]],
+  }]
+
+  def build_graph(parameters):
+    input_tensor = tf.placeholder(
+        dtype=parameters["dtype"],
+        name="input",
+        shape=parameters["input_shape"])
+    out = tf.transpose(input_tensor, perm=parameters["perm"])
+    return [input_tensor], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    input_values = create_tensor_data(parameters["dtype"],
+                                      parameters["input_shape"])
+    return [input_values], sess.run(
+        outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
 def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
   """Given an input perform a sequence of TensorFlow ops to produce l2pool."""
   return tf.sqrt(tf.nn.avg_pool(
@@ -1336,6 +1371,7 @@ def main(unused_args):
         "sigmoid.zip": make_sigmoid_tests,
         "softmax.zip": make_softmax_tests,
         "space_to_depth.zip": make_space_to_depth_tests,
+        "transpose.zip": make_transpose_tests,
     }
     out = FLAGS.zip_to_output
     bin_path = FLAGS.toco
index 81df5177d9a55e4c8e9acf94c2964a6a9aea6d93..0885c6eddbea41aa5d97f6b34f03b2ffe501e2de 100644 (file)
@@ -85,6 +85,9 @@ std::map<string, string> kBrokenTests = {
 
     // ResizeBilinear looks completely incompatible with Tensorflow
     {R"(resize_bilinear)", "67964336"},
+
+    // Transpose only supports 1D-4D input tensors.
+    {R"(transposedtype=.*,input_shape=\[.,.,.,.,.\],perm=.*)", "71545879"},
 };
 
 // Allows test data to be unzipped into a temporary directory and makes
@@ -270,6 +273,7 @@ INSTANTIATE_TESTS(resize_bilinear)
 INSTANTIATE_TESTS(sigmoid)
 INSTANTIATE_TESTS(softmax)
 INSTANTIATE_TESTS(space_to_depth)
+INSTANTIATE_TESTS(transpose)
 
 }  // namespace testing
 }  // namespace tflite
index 9d86920e5fc6232d11e1c48295b6da95a9b3d5a5..741f9d4bfb5e1951b2e4b383d7abbc83bd456695 100644 (file)
@@ -222,6 +222,7 @@ cc_library(
         "graph_transformations/resolve_tensorflow_squeeze.cc",
         "graph_transformations/resolve_tensorflow_switch.cc",
         "graph_transformations/resolve_tensorflow_tile.cc",
+        "graph_transformations/resolve_transpose_attributes.cc",
         "graph_transformations/unfuse_activation_functions.cc",
     ],
     hdrs = [
index 2f583a4e16488f23ad6543be40b9e87de5547cd3..785dad85965bf9d0c1360d71da73eb8a4eefd28a 100644 (file)
@@ -158,6 +158,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolvePadAttributes)
 DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
 DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
 DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape)
 DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
 DECLARE_GRAPH_TRANSFORMATION(Dequantize)
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_transpose_attributes.cc
new file mode 100644 (file)
index 0000000..12d966b
--- /dev/null
@@ -0,0 +1,53 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
+  const auto op_it = model->operators.begin() + op_index;
+  if (op_it->get()->type != OperatorType::kTranspose) return false;
+
+  auto* op = static_cast<TransposeOperator*>(op_it->get());
+  if (!op->perm.empty()) return false;
+
+  CHECK_EQ(op->inputs.size(), 2);
+  if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+
+  // Handling perm.
+  const auto& perm_array = *model->arrays[op->inputs[1]];
+  if (!perm_array.has_shape()) return false;
+
+  const std::vector<int>& perm_dims = perm_array.shape().dims();
+  CHECK_EQ(perm_dims.size(), 1);
+
+  std::vector<int> perm_buffer =
+      perm_array.GetBuffer<ArrayDataType::kInt32>().data;
+  for (int i = 0; i < perm_dims[0]; ++i) {
+    op->perm.push_back(perm_buffer[i]);
+  }
+
+  return true;
+}
+
+}  // namespace toco
index eb66213f38287d8a5552a3f9dec5c844913a95a6..7b2235e2751e1bb359195a3d69f91725a5463434 100644 (file)
@@ -957,6 +957,7 @@ struct TensorFlowSquareOperator : Operator {
 // TensorFlow equivalent: Transpose
 struct TransposeOperator : Operator {
   TransposeOperator() : Operator(OperatorType::kTranspose) {}
+  std::vector<int> perm;
 };
 
 // Element-wise subtraction operator.
index 202e3e949d9f45e238763c007762f26c4b230d74..d15346144409f9da8d8fa5f61c8b2aac2bb03ca3 100644 (file)
@@ -506,6 +506,25 @@ class SpaceToDepth
   }
 };
 
+class Transpose
+    : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
+                             ::tflite::BuiltinOptions_TransposeOptions> {
+ public:
+  using BuiltinOperator::BuiltinOperator;
+  flatbuffers::Offset<TfLiteOptions> WriteOptions(
+      const TocoOperator& op,
+      flatbuffers::FlatBufferBuilder* builder) const override {
+    return ::tflite::CreateTransposeOptions(*builder,
+                                            builder->CreateVector(op.perm));
+  }
+
+  void ReadOptions(const TfLiteOptions& options,
+                   TocoOperator* op) const override {
+    op->perm.insert(op->perm.end(), options.perm()->begin(),
+                    options.perm()->end());
+  }
+};
+
 class Split : public CustomOperator<TensorFlowSplitOperator> {
  public:
   using CustomOperator::CustomOperator;
@@ -670,6 +689,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
                                     OperatorType::kSpaceToDepth));
   ops.emplace_back(
       new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
+  ops.emplace_back(new Transpose(::tflite::BuiltinOperator_TRANSPOSE,
+                                 OperatorType::kTranspose));
 
   // Custom Operators.
   ops.emplace_back(new Cast("CAST", OperatorType::kCast));
index d126bb496c35864488077ad7dd032b69c3ac7dca..bcf8ac04efccaa61063edd3789e1a54d78b031e5 100644 (file)
@@ -369,6 +369,15 @@ TEST_F(OperatorTest, Svdf) {
   EXPECT_EQ(op.rank, output_toco_op->rank);
 }
 
+TEST_F(OperatorTest, Transpose) {
+  TransposeOperator op;
+  op.perm = {0, 1, 2, 3};
+
+  auto output_toco_op = SerializeAndDeserialize(
+      GetOperator("TRANSPOSE", OperatorType::kTranspose), op);
+  EXPECT_EQ(op.perm, output_toco_op->perm);
+}
+
 TEST_F(OperatorTest, TensorFlowUnsupported) {
   TensorFlowUnsupportedOperator op;
   op.tensorflow_op = "MyCustomUnsupportedOp";
index 0eae148cc2df778fbcb9e159e6e6b3bd485b2dd9..01806040c80610b339a1a955b665d12bb5109a4a 100644 (file)
@@ -85,6 +85,7 @@ void MakeGeneralGraphTransformationsSet(
   transformations->Add(new ResolveStridedSliceAttributes);
   transformations->Add(new ResolveSliceAttributes);
   transformations->Add(new ResolveMeanAttributes);
+  transformations->Add(new ResolveTransposeAttributes);
   transformations->Add(new ResolveConstantTensorFlowShape);
   transformations->Add(new MakeInitialDequantizeOperator);
 }