int axis;
} TfLiteGatherParams;
+typedef struct {
+ // TODO(ahentz): We can't have dynamic data in this struct, at least not yet.
+ // For now we will fix the maximum possible number of dimensions.
+ int perm[8];
+ int num_dimensions;
+} TfLiteTransposeParams;
+
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
"space_to_batch_nd.cc",
"space_to_depth.cc",
"svdf.cc",
+ "transpose.cc",
"unidirectional_sequence_rnn.cc",
],
hdrs = [
}
template <typename T>
-void Transpose(const T* input, Dims<4>& input_dims, T* output,
- Dims<4>& output_dims, int* permuted_axes) {
+void Transpose(const T* input, const Dims<4>& input_dims, T* output,
+ const Dims<4>& output_dims, int* permuted_axes) {
int out_sizes[4];
// Compute the inverse permutation array so we can do an output centered
// transpose. Also, check to make sure output_dims is matching input_dims.
TfLiteRegistration* Register_SKIP_GRAM();
TfLiteRegistration* Register_SPACE_TO_DEPTH();
TfLiteRegistration* Register_GATHER();
+TfLiteRegistration* Register_TRANSPOSE();
BuiltinOpResolver::BuiltinOpResolver() {
AddBuiltin(BuiltinOperator_RELU, Register_RELU());
AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH());
AddBuiltin(BuiltinOperator_GATHER, Register_GATHER());
+ AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE());
}
TfLiteRegistration* BuiltinOpResolver::FindOp(
--- /dev/null
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace transpose {
+
+// This file has two implementations of Transpose.
+enum KernelType {
+ kReference,
+};
+
+// TODO(nupurgarg): Permutation arrays represented as a tensor are ignored. Only
+// use the `perm` specified in `params`.
+struct TransposeContext {
+ TransposeContext(TfLiteContext* context, TfLiteNode* node) {
+ params = reinterpret_cast<TfLiteTransposeParams*>(node->builtin_data);
+ input = GetInput(context, node, 0);
+ output = GetOutput(context, node, 0);
+ }
+ TfLiteTransposeParams* params;
+ TfLiteTensor* input;
+ TfLiteTensor* output;
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ TransposeContext op_context(context, node);
+ int dims = NumDimensions(op_context.input);
+
+ // Ensure validity of input tensor and permutation array.
+ TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type);
+ TF_LITE_ENSURE_EQ(context, dims, op_context.params->num_dimensions);
+ TF_LITE_ENSURE_MSG(context, dims <= 4,
+ "Transpose op only supports 1D-4D input arrays.");
+ for (int idx = 0; idx < dims; ++idx) {
+ TF_LITE_ENSURE_MSG(context,
+ op_context.params->perm[idx] >= 0 &&
+ op_context.params->perm[idx] < dims,
+ "Transpose op permutations array is out of bounds.");
+ }
+
+ // Determine size of output tensor.
+ const TfLiteIntArray* input_size = op_context.input->dims;
+ TfLiteIntArray* output_size = TfLiteIntArrayCreate(dims);
+ for (int idx = 0; idx < dims; ++idx) {
+ output_size->data[idx] = input_size->data[op_context.params->perm[idx]];
+ }
+
+ return context->ResizeTensor(context, op_context.output, output_size);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ TransposeContext op_context(context, node);
+
+ // Reverse the permuted axes and convert to 4D due to the way Dims are
+ // constructed in GetTensorDims.
+ const int kOutputDimensionNum = 4;
+ int reversed_perm[kOutputDimensionNum];
+ int size = op_context.params->num_dimensions;
+ for (int output_k = 0, input_k = size - 1; output_k < size;
+ ++output_k, --input_k) {
+ reversed_perm[output_k] = size - op_context.params->perm[input_k] - 1;
+ }
+ for (int k = size; k < kOutputDimensionNum; ++k) {
+ reversed_perm[k] = k;
+ }
+
+#define TF_LITE_TRANSPOSE(type, scalar) \
+ type::Transpose(GetTensorData<scalar>(op_context.input), \
+ GetTensorDims(op_context.input), \
+ GetTensorData<scalar>(op_context.output), \
+ GetTensorDims(op_context.output), reversed_perm)
+
+ switch (op_context.input->type) {
+ case kTfLiteFloat32:
+ if (kernel_type == kReference) {
+ TF_LITE_TRANSPOSE(reference_ops, float);
+ }
+ break;
+ case kTfLiteUInt8:
+ if (kernel_type == kReference) {
+ TF_LITE_TRANSPOSE(reference_ops, uint8_t);
+ }
+ break;
+ case kTfLiteInt32:
+ if (kernel_type == kReference) {
+ TF_LITE_TRANSPOSE(reference_ops, int32_t);
+ }
+ break;
+ case kTfLiteInt64:
+ if (kernel_type == kReference) {
+ TF_LITE_TRANSPOSE(reference_ops, int64_t);
+ }
+ break;
+ default:
+ context->ReportError(context,
+ "Type is currently not supported by Transpose.");
+ return kTfLiteError;
+ }
+#undef TF_LITE_TRANSPOSE
+
+ return kTfLiteOk;
+}
+
+} // namespace transpose
+
+TfLiteRegistration* Register_TRANSPOSE_REF() {
+ static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare,
+ transpose::Eval<transpose::kReference>};
+ return &r;
+}
+
+TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); }
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
namespace tflite {
namespace {
+using ::testing::ElementsAreArray;
+
void RunTestPermutation(const std::vector<int>& shape,
const std::vector<int>& perms,
std::vector<float>* input_transposed) {
reversed_perms);
}
-TEST(TransposeTest, Test1D) {
+TEST(TransposeTest, TestRefOps1D) {
// Basic 1D identity.
std::vector<float> out;
RunTestPermutation({3}, {0}, &out);
ASSERT_EQ(out, std::vector<float>({0, 1, 2}));
}
-TEST(TransposeTest, Test2D) {
+TEST(TransposeTest, TestRefOps2D) {
std::vector<float> out;
// Basic 2D.
RunTestPermutation({3, 2}, {1, 0}, &out);
ASSERT_EQ(out, std::vector<float>({0, 1, 2, 3, 4, 5}));
}
-TEST(TransposeTest, Test3D) {
+TEST(TransposeTest, TestRefOps3D) {
std::vector<float> out;
// Test 3 dimensional
{
}
}
-TEST(TransposeTest, Test4D) {
+TEST(TransposeTest, TestRefOps4D) {
std::vector<float> out;
// Basic 4d.
RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &out);
ASSERT_EQ(out, ref);
}
+class TransposeOpModel : public SingleOpModel {
+ public:
+ TransposeOpModel(std::initializer_list<int> input_shape,
+ std::initializer_list<int> perm) {
+ input_ = AddInput(TensorType_FLOAT32);
+ output_ = AddOutput(TensorType_FLOAT32);
+ SetBuiltinOp(
+ BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
+ CreateTransposeOptions(builder_, builder_.CreateVector<int>(perm))
+ .Union());
+ BuildInterpreter({input_shape});
+ }
+
+ void SetInput(std::initializer_list<float> data) {
+ PopulateTensor<float>(input_, data);
+ }
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+ int input_;
+ int output_;
+};
+
+TEST(TransposeTest, TestUnequalPermSize) {
+ EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {2, 2}),
+ "dims != op_context.params->num_dimensions");
+}
+
+TEST(TransposeTest, TestPermOutOfBounds) {
+ EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {0, -1, -2, -3}),
+ "Transpose op permutations array is out of bounds.");
+ EXPECT_DEATH(TransposeOpModel({1, 3, 3, 1}, {0, 1, 2, 4}),
+ "Transpose op permutations array is out of bounds.");
+}
+
+TEST(TransposeTest, Test1DInputTensor) {
+ TransposeOpModel m({3}, {0});
+ m.SetInput({1, 2, 3});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
+}
+
+TEST(TransposeTest, Test2DInputTensor) {
+ TransposeOpModel m({3, 2}, {1, 0});
+ m.SetInput({0, 1, 2, 3, 4, 5});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5}));
+}
+
+TEST(TransposeTest, Test3DInputTensor) {
+ TransposeOpModel m({2, 3, 4}, {2, 0, 1});
+ m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3}));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21,
+ 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
+}
+
+TEST(TransposeTest, Test5DInputTensor) {
+ EXPECT_DEATH(TransposeOpModel({1, 2, 3, 4, 5}, {0, 1, 2, 3, 4}),
+ "Transpose op only supports 1D-4D input arrays.");
+}
+
+TEST(TransposeTest, SimpleTestNoReorder) {
+ TransposeOpModel m({1, 2, 3, 1}, {0, 1, 2, 3});
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(TransposeTest, SimpleTestWithReorder) {
+ TransposeOpModel m({1, 2, 3, 1}, {2, 1, 3, 0});
+ m.SetInput({1, 2, 3, 4, 5, 6});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2, 1, 1}));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
+}
+
+TEST(TransposeTest, ComplexTestWithReorder) {
+ TransposeOpModel m({2, 3, 4, 5}, {2, 0, 1, 3});
+ m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+ 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
+ 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+ 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+ 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
+ 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
+ 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
+ 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
+ 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3, 5}));
+ auto result = ElementsAreArray(
+ {0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44,
+ 60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
+ 5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49,
+ 65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
+ 10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50, 51, 52, 53, 54,
+ 70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
+ 15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55, 56, 57, 58, 59,
+ 75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119});
+ EXPECT_THAT(m.GetOutput(), result);
+}
+
} // namespace
} // namespace tflite
builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_TRANSPOSE: {
+ auto* params = MallocPOD<TfLiteTransposeParams>();
+ if (auto* schema_params = op->builtin_options_as_TransposeOptions()) {
+ const auto& perm = schema_params->perm();
+ FlatBufferIntVectorToArray(sizeof(params->perm), perm, params->perm,
+ error_reporter);
+ params->num_dimensions = perm->Length();
+ }
+ builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
}
return builtin_data;
}
case tflite::BuiltinOperator_GATHER:
case tflite::BuiltinOperator_SPACE_TO_BATCH_ND:
case tflite::BuiltinOperator_BATCH_TO_SPACE_ND:
+ case tflite::BuiltinOperator_TRANSPOSE:
FATAL("Op code %d is currently not delegated to NNAPI", builtin);
nn_op_type = -1; // set to invalid
break;
GATHER = 36,
BATCH_TO_SPACE_ND = 37,
SPACE_TO_BATCH_ND = 38,
+ TRANSPOSE = 39,
}
// Options for the builtin operators.
GatherOptions,
BatchToSpaceNDOptions,
SpaceToBatchNDOptions,
+ TransposeOptions,
}
enum Padding : byte { SAME, VALID }
axis: int;
}
+table TransposeOptions {
+ perm:[int];
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-
// automatically generated by the FlatBuffers compiler, do not modify
#ifndef FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_
struct GatherOptions;
struct GatherOptionsT;
+struct TransposeOptions;
+struct TransposeOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
BuiltinOperator_GATHER = 36,
BuiltinOperator_BATCH_TO_SPACE_ND = 37,
BuiltinOperator_SPACE_TO_BATCH_ND = 38,
+ BuiltinOperator_TRANSPOSE = 39,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_SPACE_TO_BATCH_ND
+ BuiltinOperator_MAX = BuiltinOperator_TRANSPOSE
};
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[36] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[37] {
static BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN,
BuiltinOperator_GATHER,
BuiltinOperator_BATCH_TO_SPACE_ND,
- BuiltinOperator_SPACE_TO_BATCH_ND};
+ BuiltinOperator_SPACE_TO_BATCH_ND,
+ BuiltinOperator_TRANSPOSE};
return values;
}
"GATHER",
"BATCH_TO_SPACE_ND",
"SPACE_TO_BATCH_ND",
+ "TRANSPOSE",
nullptr};
return names;
}
BuiltinOptions_GatherOptions = 23,
BuiltinOptions_BatchToSpaceNDOptions = 24,
BuiltinOptions_SpaceToBatchNDOptions = 25,
+ BuiltinOptions_TransposeOptions = 26,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_SpaceToBatchNDOptions
+ BuiltinOptions_MAX = BuiltinOptions_TransposeOptions
};
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[26] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[27] {
static BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
BuiltinOptions_PadOptions,
BuiltinOptions_GatherOptions,
BuiltinOptions_BatchToSpaceNDOptions,
- BuiltinOptions_SpaceToBatchNDOptions};
+ BuiltinOptions_SpaceToBatchNDOptions,
+ BuiltinOptions_TransposeOptions};
return values;
}
"GatherOptions",
"BatchToSpaceNDOptions",
"SpaceToBatchNDOptions",
+ "TransposeOptions",
nullptr};
return names;
}
static const BuiltinOptions enum_value = BuiltinOptions_SpaceToBatchNDOptions;
};
+template <>
+struct BuiltinOptionsTraits<TransposeOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
? reinterpret_cast<const SpaceToBatchNDOptionsT *>(value)
: nullptr;
}
+ TransposeOptionsT *AsTransposeOptions() {
+ return type == BuiltinOptions_TransposeOptions
+ ? reinterpret_cast<TransposeOptionsT *>(value)
+ : nullptr;
+ }
+ const TransposeOptionsT *AsTransposeOptions() const {
+ return type == BuiltinOptions_TransposeOptions
+ ? reinterpret_cast<const TransposeOptionsT *>(value)
+ : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj,
flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o,
const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct TransposeOptionsT : public flatbuffers::NativeTable {
+ typedef TransposeOptions TableType;
+ std::vector<int32_t> perm;
+ TransposeOptionsT() {}
+};
+
+struct TransposeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef TransposeOptionsT NativeTableType;
+ enum { VT_PERM = 4 };
+ const flatbuffers::Vector<int32_t> *perm() const {
+ return GetPointer<const flatbuffers::Vector<int32_t> *>(VT_PERM);
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_PERM) &&
+ verifier.Verify(perm()) && verifier.EndTable();
+ }
+ TransposeOptionsT *UnPack(
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(
+ TransposeOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<TransposeOptions> Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct TransposeOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_perm(flatbuffers::Offset<flatbuffers::Vector<int32_t>> perm) {
+ fbb_.AddOffset(TransposeOptions::VT_PERM, perm);
+ }
+ explicit TransposeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ TransposeOptionsBuilder &operator=(const TransposeOptionsBuilder &);
+ flatbuffers::Offset<TransposeOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<TransposeOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ flatbuffers::Offset<flatbuffers::Vector<int32_t>> perm = 0) {
+ TransposeOptionsBuilder builder_(_fbb);
+ builder_.add_perm(perm);
+ return builder_.Finish();
+}
+
+inline flatbuffers::Offset<TransposeOptions> CreateTransposeOptionsDirect(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ const std::vector<int32_t> *perm = nullptr) {
+ return tflite::CreateTransposeOptions(
+ _fbb, perm ? _fbb.CreateVector<int32_t>(*perm) : 0);
+}
+
+flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
? static_cast<const SpaceToBatchNDOptions *>(builtin_options())
: nullptr;
}
+ const TransposeOptions *builtin_options_as_TransposeOptions() const {
+ return builtin_options_type() == BuiltinOptions_TransposeOptions
+ ? static_cast<const TransposeOptions *>(builtin_options())
+ : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
return builtin_options_as_SpaceToBatchNDOptions();
}
+template <>
+inline const TransposeOptions *Operator::builtin_options_as<TransposeOptions>()
+ const {
+ return builtin_options_as_TransposeOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
return tflite::CreateGatherOptions(_fbb, _axis);
}
+inline TransposeOptionsT *TransposeOptions::UnPack(
+ const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new TransposeOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void TransposeOptions::UnPackTo(
+ TransposeOptionsT *_o,
+ const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ {
+ auto _e = perm();
+ if (_e) {
+ _o->perm.resize(_e->size());
+ for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) {
+ _o->perm[_i] = _e->Get(_i);
+ }
+ }
+ };
+}
+
+inline flatbuffers::Offset<TransposeOptions> TransposeOptions::Pack(
+ flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateTransposeOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
+ flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
+ const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs {
+ flatbuffers::FlatBufferBuilder *__fbb;
+ const TransposeOptionsT *__o;
+ const flatbuffers::rehasher_function_t *__rehasher;
+ } _va = {&_fbb, _o, _rehasher};
+ (void)_va;
+ auto _perm = _o->perm.size() ? _fbb.CreateVector(_o->perm) : 0;
+ return tflite::CreateTransposeOptions(_fbb, _perm);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(
const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
auto ptr = reinterpret_cast<const SpaceToBatchNDOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_TransposeOptions: {
+ auto ptr = reinterpret_cast<const TransposeOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default:
return false;
}
auto ptr = reinterpret_cast<const SpaceToBatchNDOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_TransposeOptions: {
+ auto ptr = reinterpret_cast<const TransposeOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default:
return nullptr;
}
auto ptr = reinterpret_cast<const SpaceToBatchNDOptionsT *>(value);
return CreateSpaceToBatchNDOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_TransposeOptions: {
+ auto ptr = reinterpret_cast<const TransposeOptionsT *>(value);
+ return CreateTransposeOptions(_fbb, ptr, _rehasher).Union();
+ }
default:
return 0;
}
*reinterpret_cast<SpaceToBatchNDOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_TransposeOptions: {
+ value = new TransposeOptionsT(
+ *reinterpret_cast<TransposeOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
delete ptr;
break;
}
+ case BuiltinOptions_TransposeOptions: {
+ auto ptr = reinterpret_cast<TransposeOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default:
break;
}
"softmax.zip",
"space_to_batch_nd.zip",
"space_to_depth.zip",
+ "transpose.zip",
],
)
make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+def make_transpose_tests(zip_path):
+ """Make a set of tests to do transpose."""
+
+ # TODO(nupurgarg): Add test for uint8.
+ test_parameters = [{
+ "dtype": [tf.int32, tf.int64, tf.float32],
+ "input_shape": [[2, 2, 3]],
+ "perm": [[0, 1, 2], [0, 2, 1]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape": [[1, 2, 3, 4]],
+ "perm": [[0, 1, 2, 3], [3, 0, 1, 2]],
+ }, {
+ "dtype": [tf.float32],
+ "input_shape": [[1, 2, 3, 4, 5]],
+ "perm": [[0, 1, 2, 3, 4]],
+ }]
+
+ def build_graph(parameters):
+ input_tensor = tf.placeholder(
+ dtype=parameters["dtype"],
+ name="input",
+ shape=parameters["input_shape"])
+ out = tf.transpose(input_tensor, perm=parameters["perm"])
+ return [input_tensor], [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ input_values = create_tensor_data(parameters["dtype"],
+ parameters["input_shape"])
+ return [input_values], sess.run(
+ outputs, feed_dict=dict(zip(inputs, [input_values])))
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_l2_pool(input_tensor, ksize, strides, padding, data_format):
"""Given an input perform a sequence of TensorFlow ops to produce l2pool."""
return tf.sqrt(tf.nn.avg_pool(
"sigmoid.zip": make_sigmoid_tests,
"softmax.zip": make_softmax_tests,
"space_to_depth.zip": make_space_to_depth_tests,
+ "transpose.zip": make_transpose_tests,
}
out = FLAGS.zip_to_output
bin_path = FLAGS.toco
// ResizeBilinear looks completely incompatible with Tensorflow
{R"(resize_bilinear)", "67964336"},
+
+ // Transpose only supports 1D-4D input tensors.
+ {R"(transposedtype=.*,input_shape=\[.,.,.,.,.\],perm=.*)", "71545879"},
};
// Allows test data to be unzipped into a temporary directory and makes
INSTANTIATE_TESTS(sigmoid)
INSTANTIATE_TESTS(softmax)
INSTANTIATE_TESTS(space_to_depth)
+INSTANTIATE_TESTS(transpose)
} // namespace testing
} // namespace tflite
"graph_transformations/resolve_tensorflow_squeeze.cc",
"graph_transformations/resolve_tensorflow_switch.cc",
"graph_transformations/resolve_tensorflow_tile.cc",
+ "graph_transformations/resolve_transpose_attributes.cc",
"graph_transformations/unfuse_activation_functions.cc",
],
hdrs = [
DECLARE_GRAPH_TRANSFORMATION(ResolveStridedSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveSliceAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveMeanAttributes)
+DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantTensorFlowShape)
DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill)
DECLARE_GRAPH_TRANSFORMATION(Dequantize)
--- /dev/null
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
+#include "tensorflow/contrib/lite/toco/model.h"
+#include "tensorflow/contrib/lite/toco/tooling_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace toco {
+
+bool ResolveTransposeAttributes::Run(Model* model, std::size_t op_index) {
+ const auto op_it = model->operators.begin() + op_index;
+ if (op_it->get()->type != OperatorType::kTranspose) return false;
+
+ auto* op = static_cast<TransposeOperator*>(op_it->get());
+ if (!op->perm.empty()) return false;
+
+ CHECK_EQ(op->inputs.size(), 2);
+ if (!IsConstantParameterArray(*model, op->inputs[1])) return false;
+
+ // Handling perm.
+ const auto& perm_array = *model->arrays[op->inputs[1]];
+ if (!perm_array.has_shape()) return false;
+
+ const std::vector<int>& perm_dims = perm_array.shape().dims();
+ CHECK_EQ(perm_dims.size(), 1);
+
+ std::vector<int> perm_buffer =
+ perm_array.GetBuffer<ArrayDataType::kInt32>().data;
+ for (int i = 0; i < perm_dims[0]; ++i) {
+ op->perm.push_back(perm_buffer[i]);
+ }
+
+ return true;
+}
+
+} // namespace toco
// TensorFlow equivalent: Transpose
struct TransposeOperator : Operator {
TransposeOperator() : Operator(OperatorType::kTranspose) {}
+ std::vector<int> perm;
};
// Element-wise subtraction operator.
}
};
+class Transpose
+ : public BuiltinOperator<TransposeOperator, ::tflite::TransposeOptions,
+ ::tflite::BuiltinOptions_TransposeOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateTransposeOptions(*builder,
+ builder->CreateVector(op.perm));
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->perm.insert(op->perm.end(), options.perm()->begin(),
+ options.perm()->end());
+ }
+};
+
class Split : public CustomOperator<TensorFlowSplitOperator> {
public:
using CustomOperator::CustomOperator;
OperatorType::kSpaceToDepth));
ops.emplace_back(
new Svdf(::tflite::BuiltinOperator_SVDF, OperatorType::kSvdf));
+ ops.emplace_back(new Transpose(::tflite::BuiltinOperator_TRANSPOSE,
+ OperatorType::kTranspose));
// Custom Operators.
ops.emplace_back(new Cast("CAST", OperatorType::kCast));
EXPECT_EQ(op.rank, output_toco_op->rank);
}
+TEST_F(OperatorTest, Transpose) {
+ TransposeOperator op;
+ op.perm = {0, 1, 2, 3};
+
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("TRANSPOSE", OperatorType::kTranspose), op);
+ EXPECT_EQ(op.perm, output_toco_op->perm);
+}
+
TEST_F(OperatorTest, TensorFlowUnsupported) {
TensorFlowUnsupportedOperator op;
op.tensorflow_op = "MyCustomUnsupportedOp";
transformations->Add(new ResolveStridedSliceAttributes);
transformations->Add(new ResolveSliceAttributes);
transformations->Add(new ResolveMeanAttributes);
+ transformations->Add(new ResolveTransposeAttributes);
transformations->Add(new ResolveConstantTensorFlowShape);
transformations->Add(new MakeInitialDequantizeOperator);
}