From: Andrei Shedko/AI Tools Lab/SRR/Assistant Engineer/Samsung Electronics
Date: Thu, 15 Nov 2018 13:04:26 +0000 (+0300)
Subject: [nnc] Tfl importer eltwise (#2208)
X-Git-Tag: nncc_backup~1311
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=89d76c2b121e8a7fbea04cd378a20973222248ee;p=platform%2Fcore%2Fml%2Fnnfw.git

[nnc] Tfl importer eltwise (#2208)

Added elementwise operations, Tanh and Transposed Conv to the TFLite importer.

Signed-off-by: Andrei Shedko
---

diff --git a/contrib/nnc/passes/tflite_frontend/schema/schema.fbs b/contrib/nnc/passes/tflite_frontend/schema/schema.fbs
index 7d2e00f..07b16f2 100644
--- a/contrib/nnc/passes/tflite_frontend/schema/schema.fbs
+++ b/contrib/nnc/passes/tflite_frontend/schema/schema.fbs
@@ -25,6 +25,9 @@ file_identifier "TFL3";
 // File extension of any written files.
 file_extension "tflite";
 
+// IMPORTANT: All new members of tables, enums and unions must be added at the
+// end to ensure backwards compatibility.
+
 // The type of data stored in a tensor.
 enum TensorType : byte {
   FLOAT32 = 0,
@@ -33,6 +36,9 @@ enum TensorType : byte {
   UINT8 = 3,
   INT64 = 4,
   STRING = 5,
+  BOOL = 6,
+  INT16 = 7,
+  COMPLEX64 = 8,
 }
 
 // Parameters for converting a quantized tensor back to float. Given a
@@ -41,7 +47,7 @@ enum TensorType : byte {
 table QuantizationParameters {
   min:[float];  // For importing back into tensorflow.
   max:[float];  // For importing back into tensorflow.
-  scale:[float];
+  scale:[float];  // For dequantizing the tensor's values.
   zero_point:[long];
 }
 
@@ -62,9 +68,11 @@ table Tensor {
   buffer:uint;
   name:string;  // For debugging and importing back into tensorflow.
   quantization:QuantizationParameters;  // Optional.
+
+  is_variable:bool = false;
 }
 
-// A list of builtin operators. Builtin operators a slighlty faster than custom
+// A list of builtin operators. Builtin operators are slightly faster than custom
 // ones, but not by much. Moreover, while custom operators accept an opaque
 // object containing configuration parameters, builtins have a predetermined
 // set of acceptable options.
@@ -77,7 +85,7 @@ enum BuiltinOperator : byte {
   // DEPTH_TO_SPACE = 5,
   DEQUANTIZE = 6,
   EMBEDDING_LOOKUP = 7,
-  // FLOOR = 8,
+  FLOOR = 8,
   FULLY_CONNECTED = 9,
   HASHTABLE_LOOKUP = 10,
   L2_NORMALIZATION = 11,
@@ -132,6 +140,48 @@ enum BuiltinOperator : byte {
   CAST = 53,
   PRELU = 54,
   MAXIMUM = 55,
+  ARG_MAX = 56,
+  MINIMUM = 57,
+  LESS = 58,
+  NEG = 59,
+  PADV2 = 60,
+  GREATER = 61,
+  GREATER_EQUAL = 62,
+  LESS_EQUAL = 63,
+  SELECT = 64,
+  SLICE = 65,
+  SIN = 66,
+  TRANSPOSE_CONV = 67,
+  SPARSE_TO_DENSE = 68,
+  TILE = 69,
+  EXPAND_DIMS = 70,
+  EQUAL = 71,
+  NOT_EQUAL = 72,
+  LOG = 73,
+  SUM = 74,
+  SQRT = 75,
+  RSQRT = 76,
+  SHAPE = 77,
+  POW = 78,
+  ARG_MIN = 79,
+  FAKE_QUANT = 80,
+  REDUCE_PROD = 81,
+  REDUCE_MAX = 82,
+  PACK = 83,
+  LOGICAL_OR = 84,
+  ONE_HOT = 85,
+  LOGICAL_AND = 86,
+  LOGICAL_NOT = 87,
+  UNPACK = 88,
+  REDUCE_MIN = 89,
+  FLOOR_DIV = 90,
+  REDUCE_ANY = 91,
+  SQUARE = 92,
+  ZEROS_LIKE = 93,
+  FILL = 94,
+  FLOOR_MOD = 95,
+  RANGE = 96,
+  RESIZE_NEAREST_NEIGHBOR = 97,
 }
 
 // Options for the builtin operators.
@@ -162,7 +212,7 @@ union BuiltinOptions {
   BatchToSpaceNDOptions,
   SpaceToBatchNDOptions,
   TransposeOptions,
-  MeanOptions,
+  ReducerOptions,
   SubOptions,
   DivOptions,
   SqueezeOptions,
@@ -174,7 +224,42 @@ union BuiltinOptions {
   LogSoftmaxOptions,
   CastOptions,
   DequantizeOptions,
-  MaximumOptions,
+  MaximumMinimumOptions,
+  ArgMaxOptions,
+  LessOptions,
+  NegOptions,
+  PadV2Options,
+  GreaterOptions,
+  GreaterEqualOptions,
+  LessEqualOptions,
+  SelectOptions,
+  SliceOptions,
+  TransposeConvOptions,
+  SparseToDenseOptions,
+  TileOptions,
+  ExpandDimsOptions,
+  EqualOptions,
+  NotEqualOptions,
+  ShapeOptions,
+  PowOptions,
+  ArgMinOptions,
+  FakeQuantOptions,
+  PackOptions,
+  LogicalOrOptions,
+  OneHotOptions,
+  LogicalAndOptions,
+  LogicalNotOptions,
+  UnpackOptions,
+  FloorDivOptions,
+  SquareOptions,
+  ZerosLikeOptions,
+  FillOptions,
+  BidirectionalSequenceLSTMOptions,
+  BidirectionalSequenceRNNOptions,
+  UnidirectionalSequenceLSTMOptions,
+  FloorModOptions,
+  RangeOptions,
+  ResizeNearestNeighborOptions,
 }
 
 enum Padding : byte { SAME, VALID }
@@ -193,6 +278,8 @@ table Conv2DOptions {
   stride_w:int;
   stride_h:int;
   fused_activation_function:ActivationFunctionType;
+  dilation_w_factor:int = 1;
+  dilation_h_factor:int = 1;
 }
 
 table Pool2DOptions {
@@ -205,11 +292,15 @@ table Pool2DOptions {
 }
 
 table DepthwiseConv2DOptions {
+  // Parameters for DepthwiseConv version 1 or above.
   padding:Padding;
   stride_w:int;
   stride_h:int;
   depth_multiplier:int;
   fused_activation_function:ActivationFunctionType;
+  // Parameters for DepthwiseConv version 2 or above.
+  dilation_w_factor:int = 1;
+  dilation_h_factor:int = 1;
 }
 
 table ConcatEmbeddingsOptions {
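A note on the dilation_w_factor/dilation_h_factor fields added to Conv2DOptions and DepthwiseConv2DOptions above: a dilation factor of d makes the kernel sample inputs d elements apart instead of adjacently, so the default of 1 keeps ordinary convolution behaviour. As a quick reference (standard convolution arithmetic, not code from this patch), the input extent covered by one kernel application per axis is:

// Effective input extent of a k-wide kernel with dilation d (per axis).
// d == 1 degenerates to the plain kernel width.
int effective_extent(int k, int d) {
  return k + (k - 1) * (d - 1);
}
// e.g. effective_extent(3, 2) == 5: a 3-tap kernel with dilation 2
// reads every second element across a 5-element window.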
@@ -248,11 +339,21 @@ table SequenceRNNOptions {
 table BidirectionalSequenceRNNOptions {
   time_major:bool;
   fused_activation_function:ActivationFunctionType;
+  merge_outputs: bool;
+}
+
+enum FullyConnectedOptionsWeightsFormat: byte {
+  DEFAULT = 0,
+  SHUFFLED4x16INT8 = 1,
 }
 
 // An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
 table FullyConnectedOptions {
+  // Parameters for FullyConnected version 1 or above.
   fused_activation_function:ActivationFunctionType;
+
+  // Parameters for FullyConnected version 2 or above.
+  weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT;
 }
 
 table SoftmaxOptions {
@@ -284,11 +385,42 @@ table LocalResponseNormalizationOptions {
   beta:float;
 }
 
+enum LSTMKernelType : byte {
+  // Full LSTM kernel which supports peephole and projection.
+  FULL = 0,
+  // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
+  BASIC = 1,
+}
+
 // An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
 table LSTMOptions {
+  // Parameters for LSTM version 1 or above.
+  fused_activation_function:ActivationFunctionType;
+  cell_clip: float; // Optional, 0.0 means no clipping
+  proj_clip: float; // Optional, 0.0 means no clipping
+
+  // Parameters for LSTM version 2 or above.
+  // Basic kernel is only supported in version 2 or above.
+  kernel_type: LSTMKernelType = FULL;
+}
+
+// An implementation of TensorFlow dynamic_rnn with LSTMCell.
+table UnidirectionalSequenceLSTMOptions {
+  fused_activation_function:ActivationFunctionType;
+  cell_clip: float; // Optional, 0.0 means no clipping
+  proj_clip: float; // Optional, 0.0 means no clipping
+
+  // If true then first dimension is sequence, otherwise batch.
+  time_major:bool;
+}
+
+table BidirectionalSequenceLSTMOptions {
   fused_activation_function:ActivationFunctionType;
   cell_clip: float; // Optional, 0.0 means no clipping
   proj_clip: float; // Optional, 0.0 means no clipping
+
+  // If true, store the outputs of both directions into the first output.
+  merge_outputs: bool;
 }
 
 table ResizeBilinearOptions {
@@ -297,6 +429,10 @@ table ResizeBilinearOptions {
   align_corners: bool;
 }
 
+table ResizeNearestNeighborOptions {
+  align_corners: bool;
+}
+
 // A call operation options
 table CallOptions {
   // The subgraph index that needs to be called.
@@ -306,6 +442,9 @@ table CallOptions {
 table PadOptions {
 }
 
+table PadV2Options {
+}
+
 table ReshapeOptions {
   new_shape:[int];
 }
@@ -357,7 +496,7 @@ table TransposeOptions {
 table ExpOptions {
 }
 
-table MeanOptions {
+table ReducerOptions {
   keep_dims: bool;
 }
 
@@ -381,12 +520,124 @@ table LogSoftmaxOptions {
 }
 
 table CastOptions {
+  in_data_type: TensorType;
+  out_data_type: TensorType;
 }
 
 table DequantizeOptions {
 }
 
-table MaximumOptions {
+table MaximumMinimumOptions {
+}
+
+table TileOptions {
+}
+
+table ArgMaxOptions {
+  output_type : TensorType;
+}
+
+table ArgMinOptions {
+  output_type : TensorType;
+}
+
+table GreaterOptions {
+}
+
+table GreaterEqualOptions {
+}
+
+table LessOptions {
+}
+
+table LessEqualOptions {
+}
+
+table NegOptions {
+}
+
+table SelectOptions {
+}
+
+table SliceOptions {
+}
+
+table TransposeConvOptions {
+  padding:Padding;
+  stride_w:int;
+  stride_h:int;
+}
+
+table ExpandDimsOptions {
+}
+
+table SparseToDenseOptions {
+  validate_indices:bool;
+}
+
+table EqualOptions {
+}
+
+table NotEqualOptions {
+}
+
+table ShapeOptions {
+  // Optional output type of the operation (int32 or int64). Defaults to int32.
+  out_type : TensorType;
+}
+
+table PowOptions {
+}
+
+table FakeQuantOptions {
+  // Parameters supported by version 1:
+  min:float;
+  max:float;
+  num_bits:int;
+
+  // Parameters supported by version 2:
+  narrow_range:bool;
+}
+
+table PackOptions {
+  values_count:int;
+  axis:int;
+}
+
+table LogicalOrOptions {
+}
+
+table OneHotOptions {
+  axis:int;
+}
+
+table LogicalAndOptions {
+}
+
+table LogicalNotOptions {
+}
+
+table UnpackOptions {
+  num:int;
+  axis:int;
+}
+
+table FloorDivOptions {
+}
+
+table SquareOptions {
+}
+
+table ZerosLikeOptions {
+}
+
+table FillOptions {
+}
+
+table FloorModOptions {
+}
+
+table RangeOptions {
 }
 
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
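Each option table above also gets generated accessors and a Create helper in the FlatBuffers C++ API. A minimal sketch of building one of the new tables (CastOptions) through that generated API; the header name schema_generated.h and the wrapper function are illustrative assumptions, not part of this patch:

#include "flatbuffers/flatbuffers.h"
#include "schema_generated.h"  // generated from the schema above

// Builds a CastOptions table; arguments follow the field declaration order
// (in_data_type, out_data_type), as with every generated CreateXxx helper.
flatbuffers::Offset<tflite::CastOptions>
makeCastOptions(flatbuffers::FlatBufferBuilder& fbb) {
  return tflite::CreateCastOptions(fbb, tflite::TensorType_FLOAT32,
                                   tflite::TensorType_INT32);
}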
@@ -394,6 +645,10 @@
 table OperatorCode {
   builtin_code:BuiltinOperator;
   custom_code:string;
+
+  // The version of the operator. The version needs to be bumped whenever new
+  // parameters are introduced into an op.
+  version:int = 1;
 }
 
 enum CustomOptionsFormat : byte {
@@ -416,30 +671,44 @@ table Operator {
   builtin_options:BuiltinOptions;
   custom_options:[ubyte];
   custom_options_format:CustomOptionsFormat;
+
+  // A list of booleans indicating the input tensors which are being mutated by
+  // this operator. (e.g. used by RNN and LSTM).
+  // For example, if the "inputs" array refers to 5 tensors and the second and
+  // fifth are mutable variables, then this list will contain
+  // [false, true, false, false, true].
+  //
+  // If the list is empty, no variable is mutated in this operator.
+  // The list either has the same length as `inputs`, or is empty.
+  mutating_variable_inputs:[bool];
 }
 
-// The root type, defining a model.
+// The root type, defining a subgraph, which typically represents an entire
+// model.
 table SubGraph {
-  // A list of all tensors used in this model.
+  // A list of all tensors used in this subgraph.
   tensors:[Tensor];
 
-  // Indices of the input tensors.
+  // Indices of the tensors that are inputs into this subgraph. Note this is
+  // the list of non-static tensors that feed into the subgraph for inference.
   inputs:[int];
 
-  // Indices of the output tensors.
+  // Indices of the tensors that are outputs out of this subgraph. Note this is
+  // the list of output tensors that are considered the product of the
+  // subgraph's inference.
   outputs:[int];
 
   // All operators, in execution order.
   operators:[Operator];
 
-  // Name of subgraph (used for debugging).
+  // Name of this subgraph (used for debugging).
   name:string;
 }
 
 // Table of raw data buffers (used for constant tensors). Referenced by tensors
-// by index.
+// by index. The generous alignment accommodates mmap-friendly data structures.
 table Buffer {
-  data:[ubyte];
+  data:[ubyte] (force_align: 16);
 }
 
 table Model {
@@ -458,9 +727,14 @@ table Model {
   // A description of the model.
   description:string;
 
-  // Buffers of the model
+  // Buffers of the model.
+  // Note the 0th entry of this array must be an empty buffer (sentinel).
+  // This is a convention so that tensors without a buffer can provide 0 as
+  // their buffer.
   buffers:[Buffer];
 
+  // Metadata about the model. Indirects into the existing buffers list.
+  metadata_buffer:[int];
 }
 
 root_type Model;
diff --git a/contrib/nnc/passes/tflite_frontend/schema/schema.meta b/contrib/nnc/passes/tflite_frontend/schema/schema.meta
index 74668ab..fd3eec0 100644
--- a/contrib/nnc/passes/tflite_frontend/schema/schema.meta
+++ b/contrib/nnc/passes/tflite_frontend/schema/schema.meta
@@ -1,2 +1,2 @@
 REPO=https://github.com/tensorflow/tensorflow.git
-COMMIT=c7a04561fb8
+COMMIT=61c6c84
diff --git a/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp b/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
index 698237c..318658a 100644
--- a/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
+++ b/contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
@@ -93,6 +93,10 @@ void TfliteImporter::processUnsupportedOp(const Operator* op) {
     case BuiltinOperator_SOFTMAX:
     case BuiltinOperator_RESHAPE:
     case BuiltinOperator_SQUEEZE:
+    case BuiltinOperator_ADD:
+    case BuiltinOperator_MUL:
+    case BuiltinOperator_MAXIMUM:
+    case BuiltinOperator_TRANSPOSE_CONV:
       // No checks
       break;
     default:
@@ -184,6 +188,19 @@ void TfliteImporter::walkOperator(const Operator* op) {
       break;
    case BuiltinOperator_SQUEEZE:
      outputs = _opCreator->createSqueeze(inputs, params, op->builtin_options_as<SqueezeOptions>());
+      break;
+    case BuiltinOperator_ADD:
+      outputs = _opCreator->createAdd(inputs, params, op->builtin_options_as<AddOptions>());
+      break;
+    case BuiltinOperator_MUL:
+      outputs = _opCreator->createMul(inputs, params, op->builtin_options_as<MulOptions>());
+      break;
+    case BuiltinOperator_MAXIMUM:
+      outputs = _opCreator->createMax(inputs, params, op->builtin_options_as<MaximumMinimumOptions>());
+      break;
+    case BuiltinOperator_TRANSPOSE_CONV:
+      outputs = _opCreator->createTransposeConv(
+          inputs, params, op->builtin_options_as<TransposeConvOptions>());
       break;
     default:
       assert(false && "All unsupported types should have been found before this pass.");
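For context on the dispatch above: builtin_options_as<T>() is the checked accessor FlatBuffers generates for the Operator.builtin_options union; it returns nullptr when the union currently holds a different member, so each case must name the option type recorded for that builtin. A standalone sketch (the function and header name are illustrative, not nnc code):

#include <cstdio>
#include "schema_generated.h"  // generated from the schema.fbs above

// Prints the fused activation of an ADD operator, if that is what op carries.
void dumpAddOptions(const tflite::Operator* op) {
  if (const auto* opts = op->builtin_options_as<tflite::AddOptions>()) {
    std::printf("ADD, fused activation: %s\n",
                tflite::EnumNameActivationFunctionType(
                    opts->fused_activation_function()));
  }
}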
"core/modelIR/operations/ReluOp.h" #include "core/modelIR/operations/CappedReluOp.h" +#include "core/modelIR/operations/TanhOp.h" +#include "core/modelIR/operations/ElementwiseOp.h" +#include "core/modelIR/operations/Deconv2DOp.h" #include "core/modelIR/operations/SoftmaxOp.h" #include "core/modelIR/operations/PoolOp.h" #include "core/modelIR/operations/BiasAddOp.h" @@ -126,6 +129,38 @@ std::vector TFLiteOpCreator::convertReshape(InputOps inputs, In return outputs; } +std::vector TFLiteOpCreator::createTransposeConv( + InputOps& inputs, + InputParams& params, + const ::tflite::TransposeConvOptions* opts) {// first param is output shape + return createOp(inputs, ActivationFunctionType_NONE, std::move(*params[1]), + Shape{static_cast(opts->stride_h()), + static_cast(opts->stride_w()), 1}, + paddingMap[opts->padding()]); +} + +std::vector +TFLiteOpCreator::createAdd(InputOps& inputs, + InputParams&, const ::tflite::AddOptions* opts) { + return createOp( + inputs, opts->fused_activation_function(), ops::ElementwiseOp::OpType::sum, inputs.size()); +} + +std::vector +TFLiteOpCreator::createMul(InputOps& inputs, + InputParams&, const ::tflite::MulOptions* opts) { + return createOp( + inputs, opts->fused_activation_function(), ops::ElementwiseOp::OpType::prod, inputs.size()); +} + + +std::vector +TFLiteOpCreator::createMax(InputOps& inputs, + InputParams&, const ::tflite::MaximumMinimumOptions* opts) { + return createOp( + inputs, ActivationFunctionType_NONE, ops::ElementwiseOp::OpType::max, inputs.size()); +} + void TFLiteOpCreator::checkFullyConnected(const FullyConnectedOptions* opts, std::set& problems_op_set) { checkActivationType(opts->fused_activation_function(), problems_op_set); @@ -150,7 +185,8 @@ void TFLiteOpCreator::checkActivationType(ActivationFunctionType activation_type std::set& problems_op_set) { if (activation_type != ActivationFunctionType_NONE && activation_type != ActivationFunctionType_RELU - && activation_type != ActivationFunctionType_RELU6) + && activation_type != ActivationFunctionType_RELU6 + && activation_type != ActivationFunctionType_TANH) problems_op_set.insert(std::string("Unsupported activation type: ") + EnumNamesActivationFunctionType()[activation_type]); } @@ -168,6 +204,9 @@ mir::Operation* TFLiteOpCreator::addFusedActivation(mir::Operation* input, case ActivationFunctionType_RELU6: activation = graph->create("", 6); break; + case ActivationFunctionType_TANH: + activation = graph->create(""); + break; default: assert(false && "Unsupported activation types must be detected before this pass"); } diff --git a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h index fa9e0ad..bceed15 100644 --- a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h +++ b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h @@ -71,6 +71,21 @@ public: std::vector createSqueeze(InputOps& inputs, InputParams& params, const ::tflite::SqueezeOptions* opts); + /** @brief Elementwise Add */ + std::vector createAdd(InputOps&, InputParams&, const ::tflite::AddOptions*); + /** @brief Elementwise product */ + std::vector createMul(InputOps&, InputParams&, const ::tflite::MulOptions*); + /** @brief Elementwise maximum */ + std::vector createMax(InputOps&, InputParams&, const ::tflite::MaximumMinimumOptions*); + + /** + * @brief Creates a Transposed convolution + * @param params 0 - output shape (unused), 1 - kernel, 2- input + */ + std::vector createTransposeConv( + InputOps&, InputParams&, + const 
diff --git a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h
index fa9e0ad..bceed15 100644
--- a/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h
+++ b/contrib/nnc/passes/tflite_frontend/tflite_op_creator.h
@@ -71,6 +71,21 @@ public:
   std::vector<mir::Operation*> createSqueeze(InputOps& inputs, InputParams& params,
                                              const ::tflite::SqueezeOptions* opts);
 
+  /** @brief Elementwise Add */
+  std::vector<mir::Operation*> createAdd(InputOps&, InputParams&, const ::tflite::AddOptions*);
+  /** @brief Elementwise product */
+  std::vector<mir::Operation*> createMul(InputOps&, InputParams&, const ::tflite::MulOptions*);
+  /** @brief Elementwise maximum */
+  std::vector<mir::Operation*> createMax(InputOps&, InputParams&, const ::tflite::MaximumMinimumOptions*);
+
+  /**
+   * @brief Creates a Transposed convolution
+   * @param params 0 - output shape (unused), 1 - kernel, 2 - input
+   */
+  std::vector<mir::Operation*> createTransposeConv(
+      InputOps&, InputParams&,
+      const ::tflite::TransposeConvOptions*);
+
   void checkPool2D(const ::tflite::Pool2DOptions*, std::set<std::string>&);
 
   void checkConcatenation(const ::tflite::ConcatenationOptions*, std::set<std::string>&);
diff --git a/contrib/nnc/unittests/tflite_frontend/CMakeLists.txt b/contrib/nnc/unittests/tflite_frontend/CMakeLists.txt
index bb4cb76..8e47a7b 100644
--- a/contrib/nnc/unittests/tflite_frontend/CMakeLists.txt
+++ b/contrib/nnc/unittests/tflite_frontend/CMakeLists.txt
@@ -1,6 +1,13 @@
 file(GLOB_RECURSE TESTS "*.cpp")
 
-if (NNC_FRONTEND_TFLITE_ENABLED)
+# Feature detect:
+execute_process(
+  COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/test_data/gen_test.py
+  WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/test_data/
+  RESULT_VARIABLE test_create_failed
+)
+
+if (NNC_FRONTEND_TFLITE_ENABLED AND NOT ${test_create_failed})
   add_definitions(-DTFLITE_TEST_DIR="${CMAKE_CURRENT_SOURCE_DIR}/test_data/")
   add_nnc_unit_test(nnc_tflite_frontend_test ${TESTS} ${OPTIONS_SRC})
   if (TARGET nnc_tflite_frontend_test)
diff --git a/contrib/nnc/unittests/tflite_frontend/test_data/gen_test.py b/contrib/nnc/unittests/tflite_frontend/test_data/gen_test.py
new file mode 100755
index 0000000..30032a2
--- /dev/null
+++ b/contrib/nnc/unittests/tflite_frontend/test_data/gen_test.py
@@ -0,0 +1,20 @@
+#!/usr/bin/python3
+import tensorflow as tf
+import numpy as np
+
+# Build a graph whose only op (sin) is unsupported by the nnc importer.
+X = tf.placeholder(shape=[1, 28, 28, 1], name='input', dtype=tf.float32)
+Y = tf.sin(X)
+out0 = tf.identity(Y, name="out")
+
+with tf.Session() as sess:
+    print('Evaluating...')
+    out = sess.run(out0, feed_dict={"input:0": np.ones((1, 28, 28, 1)).astype(np.float32)})
+
+    # Freeze the graph and convert it to a TFLite flatbuffer.
+    frozen_graphdef = tf.graph_util.convert_variables_to_constants(
+        sess, sess.graph_def, ["out"])
+    tflite_model = tf.contrib.lite.TocoConverter(
+        frozen_graphdef, [X], [out0]).convert()
+
+    open("unsupported.tflite", "wb").write(tflite_model)
diff --git a/contrib/nnc/unittests/tflite_frontend/test_data/unsupported.tflite b/contrib/nnc/unittests/tflite_frontend/test_data/unsupported.tflite
deleted file mode 100644
index 496c0e2..0000000
Binary files a/contrib/nnc/unittests/tflite_frontend/test_data/unsupported.tflite and /dev/null differ
diff --git a/contrib/nnc/unittests/tflite_frontend/unsupported_tflite_model.cpp b/contrib/nnc/unittests/tflite_frontend/unsupported_tflite_model.cpp
index d1f0b85..d6df6f2 100644
--- a/contrib/nnc/unittests/tflite_frontend/unsupported_tflite_model.cpp
+++ b/contrib/nnc/unittests/tflite_frontend/unsupported_tflite_model.cpp
@@ -5,8 +5,7 @@
 #include
 
 const char *ErrorMsg = "Detected problems:\n"
-                       "ADD: unsupported operator\n"
-                       "TANH: unsupported operator\n";
+                       "SIN: unsupported operator\n";
 
 // When adding support for new layers, change the model, not the test
 TEST(TFLITE_IMPORT_UNSUPPORTED, ImportModelWithUnsupportedLayers) {