[nnc] Tfl importer eltwise (#2208)
authorАндрей Шедько/AI Tools Lab /SRR/Assistant Engineer/삼성전자 <a.shedko@partner.samsung.com>
Thu, 15 Nov 2018 13:04:26 +0000 (16:04 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Thu, 15 Nov 2018 13:04:26 +0000 (16:04 +0300)
Added elementwise operations, Tanh and Transposed Conv to tfLite importer

Signed-off-by: Andrei Shedko <a.shedko@partner.samsung.com>
contrib/nnc/passes/tflite_frontend/schema/schema.fbs
contrib/nnc/passes/tflite_frontend/schema/schema.meta
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.h
contrib/nnc/unittests/tflite_frontend/CMakeLists.txt
contrib/nnc/unittests/tflite_frontend/test_data/gen_test.py [new file with mode: 0755]
contrib/nnc/unittests/tflite_frontend/test_data/unsupported.tflite [deleted file]
contrib/nnc/unittests/tflite_frontend/unsupported_tflite_model.cpp

index 7d2e00f..07b16f2 100644 (file)
@@ -25,6 +25,9 @@ file_identifier "TFL3";
 // File extension of any written files.
 file_extension "tflite";
 
+// IMPORTANT: All new members of tables, enums and unions must be added at the
+// end to ensure backwards compatibility.
+
 // The type of data stored in a tensor.
 enum TensorType : byte {
   FLOAT32 = 0,
@@ -33,6 +36,9 @@ enum TensorType : byte {
   UINT8 = 3,
   INT64 = 4,
   STRING = 5,
+  BOOL = 6,
+  INT16 = 7,
+  COMPLEX64 = 8,
 }
 
 // Parameters for converting a quantized tensor back to float. Given a
@@ -41,7 +47,7 @@ enum TensorType : byte {
 table QuantizationParameters {
   min:[float];  // For importing back into tensorflow.
   max:[float];  // For importing back into tensorflow.
-  scale:[float];
+  scale:[float];  // For dequantizing the tensor's values.
   zero_point:[long];
 }
 
@@ -62,9 +68,11 @@ table Tensor {
   buffer:uint;
   name:string;  // For debugging and importing back into tensorflow.
   quantization:QuantizationParameters;  // Optional.
+
+  is_variable:bool = false;
 }
 
-// A list of builtin operators. Builtin operators a slighlty faster than custom
+// A list of builtin operators. Builtin operators are slightly faster than custom
 // ones, but not by much. Moreover, while custom operators accept an opaque
 // object containing configuration parameters, builtins have a predetermined
 // set of acceptable options.
@@ -77,7 +85,7 @@ enum BuiltinOperator : byte {
   // DEPTH_TO_SPACE = 5,
   DEQUANTIZE = 6,
   EMBEDDING_LOOKUP = 7,
-  // FLOOR = 8,
+  FLOOR = 8,
   FULLY_CONNECTED = 9,
   HASHTABLE_LOOKUP = 10,
   L2_NORMALIZATION = 11,
@@ -132,6 +140,48 @@ enum BuiltinOperator : byte {
   CAST = 53,
   PRELU = 54,
   MAXIMUM = 55,
+  ARG_MAX = 56,
+  MINIMUM = 57,
+  LESS = 58,
+  NEG = 59,
+  PADV2 = 60,
+  GREATER = 61,
+  GREATER_EQUAL = 62,
+  LESS_EQUAL = 63,
+  SELECT = 64,
+  SLICE = 65,
+  SIN = 66,
+  TRANSPOSE_CONV = 67,
+  SPARSE_TO_DENSE = 68,
+  TILE = 69,
+  EXPAND_DIMS = 70,
+  EQUAL = 71,
+  NOT_EQUAL = 72,
+  LOG = 73,
+  SUM = 74,
+  SQRT = 75,
+  RSQRT = 76,
+  SHAPE = 77,
+  POW = 78,
+  ARG_MIN = 79,
+  FAKE_QUANT = 80,
+  REDUCE_PROD = 81,
+  REDUCE_MAX = 82,
+  PACK = 83,
+  LOGICAL_OR = 84,
+  ONE_HOT = 85,
+  LOGICAL_AND = 86,
+  LOGICAL_NOT = 87,
+  UNPACK = 88,
+  REDUCE_MIN = 89,
+  FLOOR_DIV = 90,
+  REDUCE_ANY = 91,
+  SQUARE = 92,
+  ZEROS_LIKE = 93,
+  FILL = 94,
+  FLOOR_MOD = 95,
+  RANGE = 96,
+  RESIZE_NEAREST_NEIGHBOR = 97,
 }
 
 // Options for the builtin operators.
@@ -162,7 +212,7 @@ union BuiltinOptions {
   BatchToSpaceNDOptions,
   SpaceToBatchNDOptions,
   TransposeOptions,
-  MeanOptions,
+  ReducerOptions,
   SubOptions,
   DivOptions,
   SqueezeOptions,
@@ -174,7 +224,42 @@ union BuiltinOptions {
   LogSoftmaxOptions,
   CastOptions,
   DequantizeOptions,
-  MaximumOptions,
+  MaximumMinimumOptions,
+  ArgMaxOptions,
+  LessOptions,
+  NegOptions,
+  PadV2Options,
+  GreaterOptions,
+  GreaterEqualOptions,
+  LessEqualOptions,
+  SelectOptions,
+  SliceOptions,
+  TransposeConvOptions,
+  SparseToDenseOptions,
+  TileOptions,
+  ExpandDimsOptions,
+  EqualOptions,
+  NotEqualOptions,
+  ShapeOptions,
+  PowOptions,
+  ArgMinOptions,
+  FakeQuantOptions,
+  PackOptions,
+  LogicalOrOptions,
+  OneHotOptions,
+  LogicalAndOptions,
+  LogicalNotOptions,
+  UnpackOptions,
+  FloorDivOptions,
+  SquareOptions,
+  ZerosLikeOptions,
+  FillOptions,
+  BidirectionalSequenceLSTMOptions,
+  BidirectionalSequenceRNNOptions,
+  UnidirectionalSequenceLSTMOptions,
+  FloorModOptions,
+  RangeOptions,
+  ResizeNearestNeighborOptions,
 }
 
 enum Padding : byte { SAME, VALID }
@@ -193,6 +278,8 @@ table Conv2DOptions {
   stride_w:int;
   stride_h:int;
   fused_activation_function:ActivationFunctionType;
+  dilation_w_factor:int = 1;
+  dilation_h_factor:int = 1;
 }
 
 table Pool2DOptions {
@@ -205,11 +292,15 @@ table Pool2DOptions {
 }
 
 table DepthwiseConv2DOptions {
+  // Parameters for DepthwiseConv version 1 or above.
   padding:Padding;
   stride_w:int;
   stride_h:int;
   depth_multiplier:int;
   fused_activation_function:ActivationFunctionType;
+  // Parameters for DepthwiseConv version 2 or above.
+  dilation_w_factor:int = 1;
+  dilation_h_factor:int = 1;
 }
 
 table ConcatEmbeddingsOptions {
@@ -248,11 +339,21 @@ table SequenceRNNOptions {
 table BidirectionalSequenceRNNOptions {
   time_major:bool;
   fused_activation_function:ActivationFunctionType;
+  merge_outputs: bool;
+}
+
+enum FullyConnectedOptionsWeightsFormat: byte {
+  DEFAULT = 0,
+  SHUFFLED4x16INT8 = 1,
 }
 
 // An implementation of TensorFlow fully_connected (a.k.a Dense) layer.
 table FullyConnectedOptions {
+  // Parameters for FullyConnected version 1 or above.
   fused_activation_function:ActivationFunctionType;
+
+  // Parameters for FullyConnected version 2 or above.
+  weights_format:FullyConnectedOptionsWeightsFormat = DEFAULT;
 }
 
 table SoftmaxOptions {
@@ -284,11 +385,42 @@ table LocalResponseNormalizationOptions {
   beta:float;
 }
 
+enum LSTMKernelType : byte {
+  // Full LSTM kernel which supports peephole and projection.
+  FULL = 0,
+  // Basic LSTM kernels. Equivalent to TensorFlow BasicLSTMCell.
+  BASIC = 1,
+}
+
 // An implementation of TensorFlow LSTMCell and CoupledInputForgetGateLSTMCell
 table LSTMOptions {
+  // Parameters for LSTM version 1 or above.
+  fused_activation_function:ActivationFunctionType;
+  cell_clip: float; // Optional, 0.0 means no clipping
+  proj_clip: float; // Optional, 0.0 means no clipping
+
+  // Parameters for LSTM version 2 or above.
+  // Basic kernel is only supported in version 2 or above.
+  kernel_type: LSTMKernelType = FULL;
+}
+
+// An implementation of TensorFlow dynamic_rnn with LSTMCell.
+table UnidirectionalSequenceLSTMOptions {
+  fused_activation_function:ActivationFunctionType;
+  cell_clip: float; // Optional, 0.0 means no clipping
+  proj_clip: float; // Optional, 0.0 means no clipping
+
+  // If true then first dimension is sequence, otherwise batch.
+  time_major:bool;
+}
+
+table BidirectionalSequenceLSTMOptions {
   fused_activation_function:ActivationFunctionType;
   cell_clip: float; // Optional, 0.0 means no clipping
   proj_clip: float; // Optional, 0.0 means no clipping
+
+  // If true, store the outputs of both directions into the first output.
+  merge_outputs: bool;
 }
 
 table ResizeBilinearOptions {
@@ -297,6 +429,10 @@ table ResizeBilinearOptions {
   align_corners: bool;
 }
 
+table ResizeNearestNeighborOptions {
+  align_corners: bool;
+}
+
 // A call operation options
 table CallOptions {
   // The subgraph index that needs to be called.
@@ -306,6 +442,9 @@ table CallOptions {
 table PadOptions {
 }
 
+table PadV2Options {
+}
+
 table ReshapeOptions {
   new_shape:[int];
 }
@@ -357,7 +496,7 @@ table TransposeOptions {
 table ExpOptions {
 }
 
-table MeanOptions {
+table ReducerOptions {
   keep_dims: bool;
 }
 
@@ -381,12 +520,124 @@ table LogSoftmaxOptions {
 }
 
 table CastOptions {
+  in_data_type: TensorType;
+  out_data_type: TensorType;
 }
 
 table DequantizeOptions {
 }
 
-table MaximumOptions {
+table MaximumMinimumOptions {
+}
+
+table TileOptions {
+}
+
+table ArgMaxOptions {
+  output_type : TensorType;
+}
+
+table ArgMinOptions {
+  output_type : TensorType;
+}
+
+table GreaterOptions {
+}
+
+table GreaterEqualOptions {
+}
+
+table LessOptions {
+}
+
+table LessEqualOptions {
+}
+
+table NegOptions {
+}
+
+table SelectOptions {
+}
+
+table SliceOptions {
+}
+
+table TransposeConvOptions {
+  padding:Padding;
+  stride_w:int;
+  stride_h:int;
+}
+
+table ExpandDimsOptions {
+}
+
+table SparseToDenseOptions {
+  validate_indices:bool;
+}
+
+table EqualOptions {
+}
+
+table NotEqualOptions {
+}
+
+table ShapeOptions {
+  // Optional output type of the operation (int32 or int64). Defaults to int32.
+  out_type : TensorType;
+}
+
+table PowOptions {
+}
+
+table FakeQuantOptions {
+  // Parameters supported by version 1:
+  min:float;
+  max:float;
+  num_bits:int;
+
+  // Parameters supported by version 2:
+  narrow_range:bool;
+}
+
+table PackOptions {
+  values_count:int;
+  axis:int;
+}
+
+table LogicalOrOptions {
+}
+
+table OneHotOptions {
+  axis:int;
+}
+
+table LogicalAndOptions {
+}
+
+table LogicalNotOptions {
+}
+
+table UnpackOptions {
+  num:int;
+  axis:int;
+}
+
+table FloorDivOptions {
+}
+
+table SquareOptions {
+}
+
+table ZerosLikeOptions {
+}
+
+table FillOptions {
+}
+
+table FloorModOptions {
+}
+
+table RangeOptions {
 }
 
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
@@ -394,6 +645,10 @@ table MaximumOptions {
 table OperatorCode {
   builtin_code:BuiltinOperator;
   custom_code:string;
+
+  // The version of the operator. The version need to be bumped whenever new
+  // parameters are introduced into an op.
+  version:int = 1;
 }
 
 enum CustomOptionsFormat : byte {
@@ -416,30 +671,44 @@ table Operator {
   builtin_options:BuiltinOptions;
   custom_options:[ubyte];
   custom_options_format:CustomOptionsFormat;
+
+  // A list of booleans indicating the input tensors which are being mutated by
+  // this operator.(e.g. used by RNN and LSTM).
+  // For example, if the "inputs" array refers to 5 tensors and the second and
+  // fifth are mutable variables, then this list will contain
+  // [false, true, false, false, true].
+  //
+  // If the list is empty, no variable is mutated in this operator.
+  // The list either has the same length as `inputs`, or is empty.
+  mutating_variable_inputs:[bool];
 }
 
-// The root type, defining a model.
+// The root type, defining a subgraph, which typically represents an entire
+// model.
 table SubGraph {
-  // A list of all tensors used in this model.
+  // A list of all tensors used in this subgraph.
   tensors:[Tensor];
 
-  // Indices of the input tensors.
+  // Indices of the tensors that are inputs into this subgraph. Note this is
+  // the list of non-static tensors that feed into the subgraph for inference.
   inputs:[int];
 
-  // Indices of the output tensors.
+  // Indices of the tensors that are outputs out of this subgraph. Note this is
+  // the list of output tensors that are considered the product of the
+  // subgraph's inference.
   outputs:[int];
 
   // All operators, in execution order.
   operators:[Operator];
 
-  // Name of subgraph (used for debugging).
+  // Name of this subgraph (used for debugging).
   name:string;
 }
 
 // Table of raw data buffers (used for constant tensors). Referenced by tensors
-// by index.
+// by index. The generous alignment accommodates mmap-friendly data structures.
 table Buffer {
-  data:[ubyte];
+  data:[ubyte] (force_align: 16);
 }
 
 table Model {
@@ -458,9 +727,14 @@ table Model {
   // A description of the model.
   description:string;
 
-  // Buffers of the model
+  // Buffers of the model.
+  // Note the 0th entry of this array must be an empty buffer (sentinel).
+  // This is a convention so that tensors without a buffer can provide 0 as
+  // their buffer.
   buffers:[Buffer];
 
+  // Metadata about the model.  Indirects into the existings buffers list.
+  metadata_buffer:[int];
 }
 
 root_type Model;
index 698237c..318658a 100644 (file)
@@ -93,6 +93,10 @@ void TfliteImporter::processUnsupportedOp(const Operator* op) {
     case BuiltinOperator_SOFTMAX:
     case BuiltinOperator_RESHAPE:
     case BuiltinOperator_SQUEEZE:
+    case BuiltinOperator_ADD:
+    case BuiltinOperator_MUL:
+    case BuiltinOperator_MAXIMUM:
+    case BuiltinOperator_TRANSPOSE_CONV:
       // No checks
       break;
     default:
@@ -184,6 +188,18 @@ void TfliteImporter::walkOperator(const Operator* op) {
       break;
     case BuiltinOperator_SQUEEZE:
       outputs = _opCreator->createSqueeze(inputs, params, op->builtin_options_as<SqueezeOptions>());
+    case BuiltinOperator_ADD:
+      outputs = _opCreator->createAdd(inputs, params, op->builtin_options_as<AddOptions>());
+      break;
+    case BuiltinOperator_MUL:
+      outputs = _opCreator->createMul(inputs, params, op->builtin_options_as<MulOptions>());
+      break;
+    case BuiltinOperator_MAXIMUM:
+      outputs = _opCreator->createMax(inputs, params, op->builtin_options_as<MaximumMinimumOptions>());
+      break;
+    case BuiltinOperator_TRANSPOSE_CONV:
+      outputs = _opCreator->createTransposeConv(
+        inputs, params,op->builtin_options_as<TransposeConvOptions>());
       break;
     default:
       assert(false && "All unsupported types should have been found before this pass.");
index 0845ab2..901ce09 100644 (file)
@@ -23,6 +23,9 @@
 #include "core/modelIR/operations/FullyConnectedOp.h"
 #include "core/modelIR/operations/ReluOp.h"
 #include "core/modelIR/operations/CappedReluOp.h"
+#include "core/modelIR/operations/TanhOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
+#include "core/modelIR/operations/Deconv2DOp.h"
 #include "core/modelIR/operations/SoftmaxOp.h"
 #include "core/modelIR/operations/PoolOp.h"
 #include "core/modelIR/operations/BiasAddOp.h"
@@ -126,6 +129,38 @@ std::vector<mir::Operation*> TFLiteOpCreator::convertReshape(InputOps inputs, In
   return outputs;
 }
 
+std::vector<mir::Operation*> TFLiteOpCreator::createTransposeConv(
+  InputOps& inputs,
+  InputParams& params,
+  const ::tflite::TransposeConvOptions* opts) {// first param is output shape
+  return createOp<ops::DeConv2DOp>(inputs, ActivationFunctionType_NONE, std::move(*params[1]),
+                                         Shape{static_cast<int32_t>(opts->stride_h()),
+                                               static_cast<int32_t>(opts->stride_w()), 1},
+                                         paddingMap[opts->padding()]);
+}
+
+std::vector<mir::Operation*>
+TFLiteOpCreator::createAdd(InputOps& inputs,
+                           InputParams&, const ::tflite::AddOptions* opts) {
+  return createOp<ops::ElementwiseOp>(
+    inputs, opts->fused_activation_function(), ops::ElementwiseOp::OpType::sum, inputs.size());
+}
+
+std::vector<mir::Operation*>
+TFLiteOpCreator::createMul(InputOps& inputs,
+                           InputParams&, const ::tflite::MulOptions* opts) {
+  return createOp<ops::ElementwiseOp>(
+    inputs, opts->fused_activation_function(), ops::ElementwiseOp::OpType::prod, inputs.size());
+}
+
+
+std::vector<mir::Operation*>
+TFLiteOpCreator::createMax(InputOps& inputs,
+                           InputParams&, const ::tflite::MaximumMinimumOptions* opts) {
+  return createOp<ops::ElementwiseOp>(
+    inputs, ActivationFunctionType_NONE, ops::ElementwiseOp::OpType::max, inputs.size());
+}
+
 void TFLiteOpCreator::checkFullyConnected(const FullyConnectedOptions* opts,
                                           std::set<std::string>& problems_op_set) {
   checkActivationType(opts->fused_activation_function(), problems_op_set);
@@ -150,7 +185,8 @@ void TFLiteOpCreator::checkActivationType(ActivationFunctionType activation_type
                                           std::set<std::string>& problems_op_set) {
   if (activation_type != ActivationFunctionType_NONE
       && activation_type != ActivationFunctionType_RELU
-      && activation_type != ActivationFunctionType_RELU6)
+      && activation_type != ActivationFunctionType_RELU6
+      && activation_type != ActivationFunctionType_TANH)
     problems_op_set.insert(std::string("Unsupported activation type: ")
                            + EnumNamesActivationFunctionType()[activation_type]);
 }
@@ -168,6 +204,9 @@ mir::Operation* TFLiteOpCreator::addFusedActivation(mir::Operation* input,
       case ActivationFunctionType_RELU6:
         activation = graph->create<ops::CappedReluOp>("", 6);
         break;
+      case ActivationFunctionType_TANH:
+        activation = graph->create<ops::TanhOp>("");
+        break;
       default:
         assert(false && "Unsupported activation types must be detected before this pass");
     }
index fa9e0ad..bceed15 100644 (file)
@@ -71,6 +71,21 @@ public:
   std::vector<mir::Operation*> createSqueeze(InputOps& inputs, InputParams& params,
                                              const ::tflite::SqueezeOptions* opts);
 
+  /** @brief Elementwise Add  */
+  std::vector<mir::Operation*> createAdd(InputOps&, InputParams&, const ::tflite::AddOptions*);
+  /** @brief Elementwise product */
+  std::vector<mir::Operation*> createMul(InputOps&, InputParams&, const ::tflite::MulOptions*);
+  /** @brief Elementwise maximum  */
+  std::vector<mir::Operation*> createMax(InputOps&, InputParams&, const ::tflite::MaximumMinimumOptions*);
+
+  /**
+ * @brief Creates a Transposed convolution
+ * @param params 0 - output shape (unused), 1 - kernel, 2- input
+ */
+  std::vector<mir::Operation*> createTransposeConv(
+    InputOps&, InputParams&,
+    const ::tflite::TransposeConvOptions*);
+
   void checkPool2D(const ::tflite::Pool2DOptions*, std::set<std::string>&);
 
   void checkConcatenation(const ::tflite::ConcatenationOptions*, std::set<std::string>&);
index bb4cb76..8e47a7b 100644 (file)
@@ -1,6 +1,13 @@
 file(GLOB_RECURSE TESTS "*.cpp")
 
-if (NNC_FRONTEND_TFLITE_ENABLED)
+#Feature detect:
+execute_process(
+        COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/test_data/gen_test.py
+        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/test_data/
+        RESULT_VARIABLE test_create_failed
+)
+
+if (NNC_FRONTEND_TFLITE_ENABLED AND NOT ${test_create_failed})
   add_definitions(-DTFLITE_TEST_DIR="${CMAKE_CURRENT_SOURCE_DIR}/test_data/")
   add_nnc_unit_test(nnc_tflite_frontend_test ${TESTS} ${OPTIONS_SRC})
   if (TARGET nnc_tflite_frontend_test)
diff --git a/contrib/nnc/unittests/tflite_frontend/test_data/gen_test.py b/contrib/nnc/unittests/tflite_frontend/test_data/gen_test.py
new file mode 100755 (executable)
index 0000000..30032a2
--- /dev/null
@@ -0,0 +1,25 @@
+#!/usr/bin/python3
+import tensorflow as tf
+import numpy as np
+
+output_shape = [1, 28, 28, 1]
+strides = [1,1,1,1]
+W = tf.constant(np.ones([7, 7, 1, 1]).astype(np.float32), name = "ker_d")
+
+# Create the graph.
+X = tf.placeholder(shape=[1, 28, 28, 1], name='input', dtype=tf.float32)
+Y = tf.sin(X)
+
+out0 = tf.identity(Y, name="out")
+# Filter the input image.
+with tf.Session() as sess:
+    print('Evaluating...')
+    out = sess.run(out0, feed_dict = {"input:0": np.ones((1, 28, 28, 1)).astype(np.float32)})
+    # print(sess.graph_def)
+
+    frozen_graphdef = tf.graph_util.convert_variables_to_constants(
+        sess, sess.graph_def, ["out"])
+    tflite_model = tf.contrib.lite.TocoConverter(
+        frozen_graphdef, [X], [out0]).convert()
+
+    open("unsupported.tflite", "wb").write(tflite_model)
diff --git a/contrib/nnc/unittests/tflite_frontend/test_data/unsupported.tflite b/contrib/nnc/unittests/tflite_frontend/test_data/unsupported.tflite
deleted file mode 100644 (file)
index 496c0e2..0000000
Binary files a/contrib/nnc/unittests/tflite_frontend/test_data/unsupported.tflite and /dev/null differ
index d1f0b85..d6df6f2 100644 (file)
@@ -5,8 +5,7 @@
 #include <iostream>
 
 const char *ErrorMsg = "Detected problems:\n"
-                       "ADD: unsupported operator\n"
-                       "TANH: unsupported operator\n";
+                       "SIN: unsupported operator\n";
 
 // When adding support for new layers, change the model, not the test
 TEST(TFLITE_IMPORT_UNSUPPORTED, ImportModelWithUnsupportedLayers) {