[nnc] Sub and Squared Diff (#2679)
author Андрей Шедько/AI Tools Lab /SRR/Engineer/삼성전자 <a.shedko@samsung.com>
Thu, 20 Dec 2018 17:51:03 +0000 (20:51 +0300)
committer Efimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Thu, 20 Dec 2018 17:51:03 +0000 (20:51 +0300)
- Refactored tflite import of various elementwise ops
- Added support for sub and SquaredDiff to C++ SB and Interpreter
- Fixed assert in C++ CPU artifact

Signed-off-by: Andrei Shedko <a.shedko@samsung.com>
14 files changed:
contrib/nnc/include/core/modelIR/operations/ElementwiseOp.h
contrib/nnc/passes/interpreter/Interpreter.cpp
contrib/nnc/passes/interpreter/ops/DeConv2D.cpp
contrib/nnc/passes/soft_backend/ModelAnalyzer.cpp
contrib/nnc/passes/soft_backend/code_snippets/cpp_elementwise.def
contrib/nnc/passes/soft_backend/code_snippets/cpp_operations.def
contrib/nnc/passes/tflite_frontend/schema/schema.fbs
contrib/nnc/passes/tflite_frontend/schema/schema.meta
contrib/nnc/passes/tflite_frontend/tflite_importer.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.cpp
contrib/nnc/passes/tflite_frontend/tflite_op_creator.h
contrib/nnc/unittests/soft_backend/CPPOperations.cpp
contrib/nnc/unittests/tflite_frontend/test_data/gen_test.py
contrib/nnc/unittests/tflite_frontend/unsupported_tflite_model.cpp

index 0b61535..3d0c6a6 100644 (file)
@@ -30,7 +30,8 @@ public:
     mul,
     add,
     max,
-    div
+    div,
+    sub,
   };
 
   /**
index 60731e9..6f5da6e 100644 (file)
@@ -253,6 +253,9 @@ void NNInterpreter::visit(ops::ElementwiseOp& op) {
     case ops::ElementwiseOp::OpType::div:
       func = [](float a, float b) { return a / b; };
       break;
+    case ops::ElementwiseOp::OpType::sub:
+      func = [](float a, float b) { return a - b; };
+      break;
     default:
       assert(false && "Unsupported Optype");
   }
index af2b5d5..750f967 100644 (file)
@@ -106,7 +106,6 @@ DeConv2D::DeConv2D(const TensorVariant& input, const DeConv2DOp& op)
   assert(input_shape.rank() == 4);
   assert(kernel_shape.rank() == 4);
   assert(kernel_shape.dim(3) == input_shape.dim(3));
-  assert(_strides.dim(2) == 1);
   assert(_op.getPaddingBefore().size() == 2);
   assert(_op.getPaddingAfter().size() == 2);
 }
index aedb9f4..db124f3 100644 (file)
@@ -295,6 +295,9 @@ void ModelAnalyzer::visit(mir::ops::ElementwiseOp& op) {
     case ops::ElementwiseOp::OpType::div:
       func_name = "ElementWise<Div>";
       break;
+    case ops::ElementwiseOp::OpType::sub:
+      func_name = "ElementWise<Sub>";
+      break;
     default:
       assert(false && "unsupported elementwise operation type");
   }
index 765ab99..9fc95ec 100644 (file)
@@ -68,6 +68,60 @@ inline void BinaryFunction(const RuntimeShape& input1_shape,
   }
 }
 
+struct Sub {
+  static inline void Sub_(const float* input1_data, const float* input2_data,
+                          float* output_data, const int size) {
+    int i = 0;
+#ifdef USE_NEON
+    for (; i <= size - 16; i += 16) {
+      auto a10 = vld1q_f32(input1_data + i);
+      auto a11 = vld1q_f32(input1_data + i + 4);
+      auto a12 = vld1q_f32(input1_data + i + 8);
+      auto a13 = vld1q_f32(input1_data + i + 12);
+      auto a20 = vld1q_f32(input2_data + i);
+      auto a21 = vld1q_f32(input2_data + i + 4);
+      auto a22 = vld1q_f32(input2_data + i + 8);
+      auto a23 = vld1q_f32(input2_data + i + 12);
+      auto x0 = vsubq_f32(a10, a20);
+      auto x1 = vsubq_f32(a11, a21);
+      auto x2 = vsubq_f32(a12, a22);
+      auto x3 = vsubq_f32(a13, a23);
+      vst1q_f32(output_data + i, x0);
+      vst1q_f32(output_data + i + 4, x1);
+      vst1q_f32(output_data + i + 8, x2);
+      vst1q_f32(output_data + i + 12, x3);
+    }
+    for (; i <= size - 4; i += 4) {
+      auto a1 = vld1q_f32(input1_data + i);
+      auto a2 = vld1q_f32(input2_data + i);
+      auto x = vsubq_f32(a1, a2);
+      vst1q_f32(output_data + i, x);
+    }
+#endif  // NEON
+
+    for (; i < size; i++) {
+      output_data[i] = input1_data[i] - input2_data[i];
+    }
+  }
+
+  static inline void Call(
+    const float* input1_data, RuntimeShape in1_shape,
+    const float* input2_data, RuntimeShape in2_shape,
+    float* output_data, RuntimeShape out_shape,
+    bool needsBroadcast) {
+    if (needsBroadcast) {
+      BroadcastBinaryFunction4DSlow<float, float, float>(
+        in1_shape, input1_data,
+        in2_shape, input2_data,
+        out_shape, output_data,
+        [](float a, float b) { return a - b; }
+      );
+    } else {
+      Sub_(input1_data, input2_data, output_data, out_shape.FlatSize());
+    }
+  }
+};
+
 struct Add {
   static inline void Add_(const float* input1_data, const float* input2_data,
                           float* output_data, const int size) {
@@ -217,4 +271,4 @@ struct Div {
       output = output.cwiseQuotient(MapAsVector(input2_data, out_shape.FlatSize()));
     }
   }
-};
\ No newline at end of file
+};
index b3a265c..02b5fae 100644 (file)
@@ -295,9 +295,10 @@ void convTransposed2d(Tensor &out, const char *params, const Tensor &in) {
 
   RuntimeShape out_shape = shapeToRuntimeShape(out_s);
 
+  assert(strides.getDims() == 2);
   const short stride_w = strides[1];
   const short stride_h = strides[0];
-  assert(strides[2] == 1);
+  assert(pads.getDims() == 2);
   const short pad_w = pads[1];
   const short pad_h = pads[0];
 
index 07b16f2..3ece383 100644 (file)
@@ -39,16 +39,34 @@ enum TensorType : byte {
   BOOL = 6,
   INT16 = 7,
   COMPLEX64 = 8,
+  INT8 = 9,
 }
 
-// Parameters for converting a quantized tensor back to float. Given a
-// quantized value q, the corresponding float value f should be:
-//   f = scale * (q - zero_point)
+// Custom quantization parameters for experimenting with new quantization
+// techniques.
+table CustomQuantization {
+  custom:[ubyte] (force_align: 16);
+}
+
+// Represents a specific quantization technique's parameters.
+union QuantizationDetails {
+  CustomQuantization,
+}
+
+// Parameters for converting a quantized tensor back to float.
 table QuantizationParameters {
+  // These four parameters are the asymmetric linear quantization parameters.
+  // Given a quantized value q, the corresponding float value f should be:
+  //   f = scale * (q - zero_point)
+  // For other quantization types, the QuantizationDetails below is used.
   min:[float];  // For importing back into tensorflow.
   max:[float];  // For importing back into tensorflow.
   scale:[float];  // For dequantizing the tensor's values.
   zero_point:[long];
+
+  // If this is not none, the quantization parameters above are ignored and the
+  // value of the QuantizationDetails union below should be used.
+  details:QuantizationDetails;
 }
 
 table Tensor {
@@ -182,6 +200,11 @@ enum BuiltinOperator : byte {
   FLOOR_MOD = 95,
   RANGE = 96,
   RESIZE_NEAREST_NEIGHBOR = 97,
+  LEAKY_RELU = 98,
+  SQUARED_DIFFERENCE = 99,
+  MIRROR_PAD = 100,
+  ABS = 101,
+  SPLIT_V = 102,
 }
 
 // Options for the builtin operators.
@@ -260,6 +283,11 @@ union BuiltinOptions {
   FloorModOptions,
   RangeOptions,
   ResizeNearestNeighborOptions,
+  LeakyReluOptions,
+  SquaredDifferenceOptions,
+  MirrorPadOptions,
+  AbsOptions,
+  SplitVOptions,
 }
 
 enum Padding : byte { SAME, VALID }
@@ -508,6 +536,10 @@ table SplitOptions {
   num_splits: int;
 }
 
+table SplitVOptions {
+  num_splits: int;
+}
+
 table StridedSliceOptions {
   begin_mask: int;
   end_mask: int;
@@ -611,6 +643,10 @@ table OneHotOptions {
   axis:int;
 }
 
+table AbsOptions {
+}
+
+
 table LogicalAndOptions {
 }
 
@@ -640,6 +676,24 @@ table FloorModOptions {
 table RangeOptions {
 }
 
+table LeakyReluOptions {
+  alpha:float;
+}
+
+table SquaredDifferenceOptions {
+}
+
+enum MirrorPadMode : byte {
+  // Doesn't include borders.
+  REFLECT = 0,
+  // Includes borders.
+  SYMMETRIC = 1,
+}
+
+table MirrorPadOptions {
+  mode:MirrorPadMode;
+}
+
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
 // builtin, or a string if the operator is custom.
 table OperatorCode {
@@ -737,4 +791,4 @@ table Model {
   metadata_buffer:[int];
 }
 
-root_type Model;
+root_type Model;
\ No newline at end of file
index 4b8296b..e55e068 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "schema_generated.h"
 #include "tflite_importer.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
 #include "tflite_op_creator.h"
 
 using namespace ::tflite;
@@ -99,6 +100,8 @@ void TfliteImporter::processUnsupportedOp(const Operator* op) {
     case BuiltinOperator_SQRT:
     case BuiltinOperator_PAD:
     case BuiltinOperator_ADD:
+    case BuiltinOperator_SUB:
+    case BuiltinOperator_SQUARED_DIFFERENCE:
     case BuiltinOperator_MUL:
     case BuiltinOperator_MEAN:
     case BuiltinOperator_MAXIMUM:
@@ -172,8 +175,9 @@ void TfliteImporter::walkOperator(const Operator* op) {
       outputs = _opCreator->convertConv2D(inputs, params, op->builtin_options_as<Conv2DOptions>());
       break;
     case BuiltinOperator_DEPTHWISE_CONV_2D:
-      outputs = _opCreator->convertDepthwiseConv2D(inputs, params,
-                                                   op->builtin_options_as<DepthwiseConv2DOptions>());
+      outputs = _opCreator->convertDepthwiseConv2D(
+        inputs, params,
+        op->builtin_options_as<DepthwiseConv2DOptions>());
       break;
     case BuiltinOperator_MAX_POOL_2D:
       outputs = _opCreator->convertMaxPool2D(inputs, params,
@@ -221,17 +225,32 @@ void TfliteImporter::walkOperator(const Operator* op) {
       outputs = _opCreator->createSqrt(inputs, params);
       break;
     case BuiltinOperator_ADD:
-      outputs = _opCreator->createAdd(inputs, params, op->builtin_options_as<AddOptions>());
+      outputs = _opCreator->createElementwise(
+        inputs, params, ops::ElementwiseOp::OpType::add,
+        op->builtin_options_as_AddOptions()->fused_activation_function());
       break;
     case BuiltinOperator_MUL:
-      outputs = _opCreator->createMul(inputs, params, op->builtin_options_as<MulOptions>());
+      outputs = _opCreator->createElementwise(
+        inputs, params, ops::ElementwiseOp::OpType::mul,
+        op->builtin_options_as_MulOptions()->fused_activation_function());
       break;
     case BuiltinOperator_DIV:
-      outputs = _opCreator->createDiv(inputs, params, op->builtin_options_as<DivOptions>());
+      outputs = _opCreator->createElementwise(
+        inputs, params, ops::ElementwiseOp::OpType::div,
+        op->builtin_options_as_DivOptions()->fused_activation_function());
       break;
     case BuiltinOperator_MAXIMUM:
-      outputs = _opCreator->createMax(
-        inputs, params, op->builtin_options_as<MaximumMinimumOptions>());
+      outputs = _opCreator->createElementwise(
+        inputs, params, ops::ElementwiseOp::OpType::max,
+        ActivationFunctionType_NONE); // no activation
+      break;
+    case BuiltinOperator_SUB:
+      outputs = _opCreator->createElementwise(
+        inputs, params, ops::ElementwiseOp::OpType::sub,
+        op->builtin_options_as_SubOptions()->fused_activation_function());
+      break;
+    case BuiltinOperator_SQUARED_DIFFERENCE:
+      outputs = _opCreator->createSquaredDifference(inputs, params); // no activation
       break;
     case BuiltinOperator_TRANSPOSE_CONV:
       outputs = _opCreator->createTransposeConv(
index 02e2e48..be8a0ae 100644 (file)
@@ -225,7 +225,7 @@ TFLiteOpCreator::convertReshape(InputOps& inputs, const InputParams& params,
 std::vector<mir::Operation*>
 TFLiteOpCreator::createTransposeConv(InputOps& inputs, const InputParams& params,
                                      const ::tflite::TransposeConvOptions* opts) {
-  Shape strides{opts->stride_h(), opts->stride_w(), 1};
+  Shape strides{opts->stride_h(), opts->stride_w()};
   return createOp<ops::DeConv2DOp>(ActivationFunctionType_NONE, inputs[0]->getOutput(0), params[1],
                                    strides, paddingMap[opts->padding()]);
 }
@@ -249,8 +249,9 @@ TFLiteOpCreator::convertResizeNN(InputOps& inputs, const InputParams& params,
 }
 
 std::vector<mir::Operation*>
-TFLiteOpCreator::createAdd(InputOps& inputs, const InputParams& params,
-                           const ::tflite::AddOptions* opts) {
+TFLiteOpCreator::createElementwise(const InputOps& inputs, const InputParams& params,
+                                   ops::ElementwiseOp::OpType opType,
+                                   const ::tflite::ActivationFunctionType activation) {
   std::vector<IODescriptor> descriptors;
 
   for (auto i : inputs)
@@ -260,14 +261,11 @@ TFLiteOpCreator::createAdd(InputOps& inputs, const InputParams& params,
     auto weights_tensor = createOp<ops::ConstantOp>(ActivationFunctionType_NONE, param);
     descriptors.push_back(weights_tensor[0]->getOutput(0));
   }
-
-  return createOp<ops::ElementwiseOp>(opts->fused_activation_function(), descriptors,
-                                      ops::ElementwiseOp::OpType::add);
+  return createOp<ops::ElementwiseOp>(activation, descriptors, opType);
 }
 
 std::vector<mir::Operation*>
-TFLiteOpCreator::createMul(InputOps& inputs, const InputParams& params,
-                           const ::tflite::MulOptions* opts) {
+TFLiteOpCreator::createSquaredDifference(const InputOps& inputs, const InputParams& params) {
   std::vector<IODescriptor> descriptors;
 
   for (auto i : inputs)
@@ -278,28 +276,13 @@ TFLiteOpCreator::createMul(InputOps& inputs, const InputParams& params,
     descriptors.push_back(weights_tensor[0]->getOutput(0));
   }
 
-  return createOp<ops::ElementwiseOp>(opts->fused_activation_function(), descriptors,
-                                      ops::ElementwiseOp::OpType::mul);
-}
+  auto sub_result = createOp<ops::ElementwiseOp>(ActivationFunctionType_NONE, descriptors,
+                                             ops::ElementwiseOp::OpType::sub);
 
-std::vector<mir::Operation*>
-TFLiteOpCreator::createDiv(InputOps& inputs, const InputParams&,
-                           const ::tflite::DivOptions* opts) {
-  std::vector<IODescriptor> descriptors;
-  for (auto i : inputs)
-    descriptors.push_back(i->getOutput(0));
-  return createOp<ops::ElementwiseOp>(opts->fused_activation_function(), descriptors,
-                                      ops::ElementwiseOp::OpType::div);
-}
-
-std::vector<mir::Operation*>
-TFLiteOpCreator::createMax(InputOps& inputs, const InputParams&,
-                           const ::tflite::MaximumMinimumOptions* opts) {
-  std::vector<IODescriptor> descriptors;
-  for (auto i : inputs)
-    descriptors.push_back(i->getOutput(0));
-  return createOp<ops::ElementwiseOp>(ActivationFunctionType_NONE, descriptors,
-                                      ops::ElementwiseOp::OpType::max);
+  return createOp<ops::ElementwiseOp>(ActivationFunctionType_NONE,
+                                      std::vector<IODescriptor>{sub_result[0]->getOutput(0),
+                                                                sub_result[0]->getOutput(0)},
+                                      ops::ElementwiseOp::OpType::mul);
 }
 
 std::vector<mir::Operation*>
index 82ca736..482e2e3 100644 (file)
@@ -30,6 +30,7 @@
 
 #include "core/modelIR/operations/CommonProps.h"
 #include "core/modelIR/operations/ReduceFOp.h"
+#include "core/modelIR/operations/ElementwiseOp.h"
 
 #include "schema_generated.h"
 #include "passes/common_frontend/shape_helper.h"
@@ -88,18 +89,12 @@ public:
   std::vector<mir::Operation*> createSqueeze(InputOps& inputs, const InputParams& params,
                                              const ::tflite::SqueezeOptions* opts);
 
-  /** @brief Elementwise Add  */
-  std::vector<mir::Operation*> createAdd(InputOps&, const InputParams&,
-                                         const ::tflite::AddOptions*);
-  /** @brief Elementwise product */
-  std::vector<mir::Operation*> createMul(InputOps&, const InputParams&,
-                                         const ::tflite::MulOptions*);
-  /** @brief Elementwise maximum  */
-  std::vector<mir::Operation*> createMax(InputOps&, const InputParams&,
-                                         const ::tflite::MaximumMinimumOptions*);
-  /** @brief Elementwise division  */
-  std::vector<mir::Operation*> createDiv(InputOps&, const InputParams&,
-                                         const ::tflite::DivOptions*);
+  /** @brief Elementwise Operation */
+  std::vector<mir::Operation*> createElementwise(
+    const InputOps&, const InputParams&, ops::ElementwiseOp::OpType opType,
+    const ::tflite::ActivationFunctionType);
+
+  std::vector<mir::Operation*> createSquaredDifference(const InputOps&, const InputParams&);
 
   /// @brief Free-standing ( non-fused ) activation function based on tflite activation
   std::vector<mir::Operation*> createActivation(InputOps&, const InputParams&,
index 35938bc..1007db6 100644 (file)
@@ -505,6 +505,28 @@ TEST(cpp_operations_test, add2) {
   }
 }
 
+TEST(cpp_operations_test, sub3) {
+  for (int num_dims = 2; num_dims <= 4; ++num_dims) {
+    // test prerequisites
+    vector<int> shape_data{2, 3, 5, 7};
+    shape_data.resize(num_dims);
+    vector<Tensor> input_atensors(3);
+    vector<unique_ptr<mir::TensorVariant>> input_n_tensors(3);
+    fillTensors(input_n_tensors[0], input_atensors[0], shape_data, 1.0f);
+    fillTensors(input_n_tensors[1], input_atensors[1], shape_data, 2.0f);
+    fillTensors(input_n_tensors[2], input_atensors[2], shape_data, 3.0f);
+    auto opGenerator = [](mir::Graph& g, const std::vector<mir::IODescriptor>& inputs) {
+      return g.create<mir::ops::ElementwiseOp>("y", inputs,
+                                               mir::ops::ElementwiseOp::OpType::sub);
+    };
+
+    createAndRunTestGraph(opGenerator, ElementWise<Sub, Tensor, Tensor, Tensor>, input_n_tensors,
+                          input_atensors[0],
+                          input_atensors[1],
+                          input_atensors[2]);
+  }
+}
+
 TEST(cpp_operations_test, mul3) {
   for (int num_dims = 2; num_dims <= 4; ++num_dims) {
     // test prerequisites
@@ -567,7 +589,7 @@ TEST(cpp_operations_test, convTransposed2d) {
             for (iT stride_w = 1; stride_w <= 3; ++stride_w) {
               vector<int> input_shape_data{3, 9, 3, static_cast<int>(input_c)};  // NHWC
               mir::Shape kernel_shape{kernel_h, kernel_w, output_c, input_c};
-              mir::Shape strides{stride_h, stride_w, 1};
+              mir::Shape strides{stride_h, stride_w};
               vector<unique_ptr<mir::TensorVariant>> input_ntensors(1);
               Tensor input_atensor;
               fillTensors(input_ntensors[0], input_atensor, input_shape_data, 1.0f);
@@ -792,27 +814,27 @@ TEST(cpp_operations_test, sigmoid) {
 TEST(cpp_operations_test, elu) {
   // test prerequisites
   vector<int> shape_data{2, 3, 4, 5};
-  Tensor a_input_tensor;
+  Tensor input_atensor;
   vector<unique_ptr<mir::TensorVariant>> input_ntensors(1);
-  fillTensors(input_ntensors[0], a_input_tensor, shape_data, 1.0f);
+  fillTensors(input_ntensors[0], input_atensor, shape_data, 1.0f);
   auto op_generator = [](mir::Graph& g, const std::vector<mir::IODescriptor>& inputs) {
     return g.create<mir::ops::EluOp>("y", inputs[0], 1);
   };
 
-  createAndRunTestGraph(op_generator, elu, input_ntensors, a_input_tensor);
+  createAndRunTestGraph(op_generator, elu, input_ntensors, input_atensor);
 }
 
 TEST(cpp_operations_test, tanh) {
   // test prerequisites
   vector<int> shape_data{2, 3, 4, 5};
-  Tensor a_input_tensor;
+  Tensor input_atensor;
   vector<unique_ptr<mir::TensorVariant>> input_ntensors(1);
-  fillTensors(input_ntensors[0], a_input_tensor, shape_data, 1.0f);
+  fillTensors(input_ntensors[0], input_atensor, shape_data, 1.0f);
   auto op_generator = [](mir::Graph& g, const std::vector<mir::IODescriptor>& inputs) {
     return g.create<mir::ops::TanhOp>("y", inputs[0]);
   };
 
-  createAndRunTestGraph(op_generator, tanhActivation, input_ntensors, a_input_tensor);
+  createAndRunTestGraph(op_generator, tanhActivation, input_ntensors, input_atensor);
 }
 
 TEST(cpp_operations_test, reduceMeanTst) {
@@ -830,9 +852,9 @@ TEST(cpp_operations_test, reduceMeanTst) {
   for (const vector<int>& axis_list: test_axis_list) {
     for (const bool keep_dims: {true, false}) {
       vector<int> input_shape_data{2, 3, 4, 5};
-      Tensor a_input_tensor;
+      Tensor input_atensor;
       vector<unique_ptr<mir::TensorVariant>> input_ntensors(1);
-      fillTensors(input_ntensors[0], a_input_tensor, input_shape_data, 1.0f);
+      fillTensors(input_ntensors[0], input_atensor, input_shape_data, 1.0f);
       auto op_generator = [axis_list, keep_dims](mir::Graph& g,
                                                 const std::vector<mir::IODescriptor>& inputs) {
         auto op = g.create<mir::ops::ReduceFOp>(
@@ -841,7 +863,7 @@ TEST(cpp_operations_test, reduceMeanTst) {
         return op;
       };
 
-      createAndRunTestGraph(op_generator, reduceMean, input_ntensors, a_input_tensor);
+      createAndRunTestGraph(op_generator, reduceMean, input_ntensors, input_atensor);
     }
   }
 }
@@ -876,14 +898,14 @@ TEST(cpp_operations_test, slice4d) {
   };
   for (auto st : starts) {
     for (auto sz : sizes) {
-      Tensor a_input_tensor;
+      Tensor input_atensor;
       vector<unique_ptr<mir::TensorVariant>> input_n_tensor(1);
-      fillTensors(input_n_tensor[0], a_input_tensor, shape_data, 1.0f);
+      fillTensors(input_n_tensor[0], input_atensor, shape_data, 1.0f);
       auto op_gen = [&st, &sz](mir::Graph& g, const std::vector<mir::IODescriptor>& inputs) {
         return g.create<mir::ops::SliceOp>("y", inputs[0], mir::Shape(st),
                                            mir::Shape(sz));
       };
-      createAndRunTestGraph(op_gen, slice, input_n_tensor, a_input_tensor);
+      createAndRunTestGraph(op_gen, slice, input_n_tensor, input_atensor);
     }
   }
 }
index c475b50..bafbf80 100755 (executable)
@@ -31,7 +31,7 @@ with tf.Session() as sess:
     out = sess.run(out0, feed_dict = {"input:0": np.ones((1, 28, 28, 1)).astype(np.float32)})
     frozen_graphdef = tf.graph_util.convert_variables_to_constants(
         sess, sess.graph_def, ["out"])
-    tflite_model = tf.contrib.lite.TocoConverter(
-        frozen_graphdef, [X], [out0]).convert()
+    converter = tf.contrib.lite.TocoConverter.from_session(sess, [X], [out0])
+    tflite_model = converter.convert()
 
     open(resDir+"unsupported.tflite", "wb").write(tflite_model)
index d6df6f2..4105035 100644 (file)
@@ -1,4 +1,4 @@
-#include "passes/tflite_frontend/tflite_importer.h"
+#include "tflite_importer.h"
 #include "gtest/gtest.h"
 #include "pass/PassException.h"
 #include <string>