Create Model IR operators in TFLite model visitor (#398)
authorDmitry Mozolev/AI Tools Lab /SRR/Engineer/삼성전자 <d.mozolev@samsung.com>
Fri, 29 Jun 2018 10:21:52 +0000 (13:21 +0300)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Fri, 29 Jun 2018 10:21:52 +0000 (19:21 +0900)
This commit adds the code that makes use of the TFLite-to-Model-IR
operator creator to construct Model IR graph.

Signed-off-by: Dmitry Mozolev <d.mozolev@samsung.com>
contrib/nnc/libs/frontend/tflite/include/tflite_ir_visitor.h
contrib/nnc/libs/frontend/tflite/src/tflite_ir_visitor.cpp

index 46ef5e1..c3dbdf6 100644 (file)
@@ -13,6 +13,7 @@
 
 #include "schema_v3.h"
 #include "tflite_visitor.h"
+#include "tflite_op_creator.h"
 
 namespace nncc
 {
@@ -47,6 +48,7 @@ public:
 
 private:
   Graph *graph = nullptr;
+  std::unique_ptr<OpCreator> opCreator;
 
   const flatbuffers::Vector<flatbuffers::Offset<OperatorCode>> *opcodes = nullptr;
   const flatbuffers::Vector<flatbuffers::Offset<Tensor>> *tensors = nullptr;
index 52f25db..1264260 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "shape_helper.h"
 #include "tflite_ir_visitor.h"
+#include "tflite_op_creator.h"
 
 namespace nncc
 {
@@ -31,6 +32,7 @@ IrVisitor::IrVisitor()
   // TODO: make this a smart pointer. Note that it requires changing the NNImporter interface,
   //       because currently it returns a void*.
   graph = new Graph();
+  opCreator.reset(new OpCreator(graph));
 }
 
 void IrVisitor::visit(const Model *m)
@@ -62,7 +64,47 @@ void IrVisitor::visit(const SubGraph *s)
 
 void IrVisitor::visit(const Operator *op)
 {
-  throw std::runtime_error{"Not yet implemented"};
+  auto inputs = createOpInputs(op);
+  auto params = createOpParams(op);
+
+  std::vector<INode::Ref> outputs;
+
+  unsigned int opcode = (*opcodes)[op->opcode_index()]->builtin_code();
+  // TODO: support other NN operator types
+  switch (opcode)
+  {
+  case BuiltinOperator_CONV_2D:
+    outputs = opCreator->createConv2D(inputs, params, op->builtin_options_as<Conv2DOptions>());
+    break;
+  case BuiltinOperator_DEPTHWISE_CONV_2D:
+    outputs = opCreator->createDepthConv2D(inputs, params,
+                                          op->builtin_options_as<DepthwiseConv2DOptions>());
+    break;
+  case BuiltinOperator_MAX_POOL_2D:
+    outputs = opCreator->createMaxPool(inputs, params, op->builtin_options_as<Pool2DOptions>());
+    break;
+  case BuiltinOperator_AVERAGE_POOL_2D:
+    outputs = opCreator->createAvgPool(inputs, params, op->builtin_options_as<Pool2DOptions>());
+    break;
+  case BuiltinOperator_CONCATENATION:
+    outputs = opCreator->createConcat(inputs, params, op->builtin_options_as<ConcatenationOptions>());
+    break;
+  case BuiltinOperator_RESHAPE:
+    outputs = opCreator->createReshape(inputs, params, op->builtin_options_as<ReshapeOptions>());
+    break;
+  case BuiltinOperator_SOFTMAX:
+    outputs = opCreator->createSoftmax(inputs, params, op->builtin_options_as<SoftmaxOptions>());
+    break;
+  default:
+    throw PluginException(
+            std::string("Encountered unsupported TFLite operator: ") +
+            EnumNamesBuiltinOperator()[opcode]);
+  }
+
+  for (int i = 0; i < op->outputs()->size(); ++i)
+  {
+    opsForTensorsTheyOutput[(*(op->outputs()))[i]] = outputs[i];
+  }
 }
 
 void IrVisitor::visit(const Tensor *) {}