[exo] Export TFLTransposeConv (#8260)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Thu, 17 Oct 2019 06:26:13 +0000 (15:26 +0900)
committer박세희/On-Device Lab(SR)/Principal Engineer/삼성전자 <saehie.park@samsung.com>
Thu, 17 Oct 2019 06:26:13 +0000 (15:26 +0900)
This commit introduces export stage of TFLTransposeConv operator

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
compiler/exo/src/TFLite/TFLOperationExporter.cpp

index 27a1dfd..23b2695 100644 (file)
@@ -68,6 +68,7 @@ public:
   void visit(locoex::TFLSub *) final;
   // TODO TFLTanh
   void visit(locoex::TFLTranspose *) final;
+  void visit(locoex::TFLTransposeConv *) final;
 
   // FOR canonical nodes. These will be removed later
   void visit(loco::ReLU *) final;
@@ -247,6 +248,27 @@ void OperationExporter::visit(locoex::TFLTranspose *node)
   gd._operators.push_back(op_offset);
 }
 
+void OperationExporter::visit(locoex::TFLTransposeConv *node)
+{
+  uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_TRANSPOSE_CONV);
+
+  // Make input, output and options for operator
+  std::vector<int32_t> inputs_vec{get_tensor_index(node->inputSizes()),
+                                  get_tensor_index(node->filter()),
+                                  get_tensor_index(node->outBackprop())};
+  std::vector<int32_t> outputs_vec{get_tensor_index(static_cast<loco::Node *>(node))};
+  auto inputs = builder.CreateVector(inputs_vec);
+  auto outputs = builder.CreateVector(outputs_vec);
+  tflite::Padding padding = getOpPadding(node->padding());
+  auto options =
+      CreateTransposeConvOptions(builder, padding, node->stride()->w(), node->stride()->h());
+
+  // Make TRANSPOSE_CONV operator
+  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+                                  tflite::BuiltinOptions_TransposeConvOptions, options.Union());
+  gd._operators.push_back(op_offset);
+}
+
 template <class TFLPool2D>
 void OperationExporter::export_pool_2d(TFLPool2D *node, tflite::BuiltinOperator builtin_op)
 {