[loco exporter] Support TensorConcat (#3861)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Tue, 18 Jun 2019 11:49:16 +0000 (20:49 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 18 Jun 2019 11:49:16 +0000 (20:49 +0900)
This commit supports TensorConcat loco node for loco exporter.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
contrib/loco-exporter/src/OperationExporter.cpp
contrib/loco-exporter/src/TensorExporter.cpp
contrib/loco-exporter/src/TypeInference.cpp
contrib/loco-exporter/src/TypeInference.h

index 3c962dd..0711f3b 100644 (file)
@@ -249,6 +249,22 @@ void exportFilterEncode(loco::FilterEncode *node, FlatBufferBuilder &builder,
   }
 }
 
+/// @brief Export CONCATENATION of **TWO** tensors only
+void exportConcat(loco::TensorConcat *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+{
+  uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION);
+  std::vector<int32_t> inputs_vec{gd._node_to_tensor_id[node->lhs()],
+                                  gd._node_to_tensor_id[node->rhs()]};
+  std::vector<int32_t> outputs_vec{gd._node_to_tensor_id[static_cast<loco::Node *>(node)]};
+  auto inputs = builder.CreateVector(inputs_vec);
+  auto outputs = builder.CreateVector(outputs_vec);
+  auto options = CreateConcatenationOptions(builder, node->axis());
+  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+                                  tflite::BuiltinOptions_ConcatenationOptions, options.Union());
+
+  gd._operators.push_back(op_offset);
+}
+
 void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
                 SerializedModelData &data)
 {
@@ -292,6 +308,10 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
   {
     exportConv2D(conv2d, builder, data);
   }
+  else if (auto *tconcat = dynamic_cast<loco::TensorConcat *>(node))
+  {
+    exportConcat(tconcat, builder, data);
+  }
   else
   {
     assert(false && "unsupported node found");
index 3703878..17392fd 100644 (file)
@@ -163,6 +163,10 @@ void exportOpDefinedTensors(loco::Graph::NodeContext *nodes, FlatBufferBuilder &
     {
       exportOpDefinedTensor(relu, builder, gd);
     }
+    else if (auto *tconcat = dynamic_cast<loco::TensorConcat *>(node))
+    {
+      exportOpDefinedTensor(tconcat, builder, gd);
+    }
     else
     {
       assert(false && "unsupported node type");
index ba5c1ad..efcd922 100644 (file)
@@ -107,6 +107,17 @@ tflite::TensorType getOpResultType(loco::FilterEncode *node, SerializedModelData
   return gd._node_to_type[node->input()];
 }
 
+tflite::TensorType getOpResultType(loco::TensorConcat *node, SerializedModelData &gd)
+{
+  tflite::TensorType lhs_type = gd._node_to_type[node->lhs()];
+  tflite::TensorType rhs_type = gd._node_to_type[node->rhs()];
+
+  // TODO support heterogenous type combination
+  assert(lhs_type == rhs_type);
+
+  return lhs_type;
+}
+
 int32_t decodeShapeDimension(const loco::Dimension &dim)
 {
   if (!dim.known())
@@ -374,4 +385,44 @@ ShapeDescription getOpResultShape(loco::FilterEncode *node, SerializedModelData
   return shape;
 }
 
+ShapeDescription getOpResultShape(loco::TensorConcat *node, SerializedModelData &gd)
+{
+  const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
+  if (!lhs_shape._rank_known)
+  {
+    // return unknown shape
+    return {};
+  }
+
+  const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
+  if (!rhs_shape._rank_known)
+  {
+    // return unknown shape
+    return {};
+  }
+
+  ShapeDescription ret;
+
+  assert(lhs_shape._dims.size() == rhs_shape._dims.size());
+  ret._dims.resize(lhs_shape._dims.size());
+
+  uint32_t axis = node->axis();
+
+  for (uint32_t i = 0; i < lhs_shape._dims.size(); ++i)
+  {
+    if (i == axis)
+    {
+      ret._dims[i] = lhs_shape._dims[i] + rhs_shape._dims[i];
+    }
+    else
+    {
+      assert(lhs_shape._dims[i] == rhs_shape._dims[i]);
+      ret._dims[i] = lhs_shape._dims[i];
+    }
+  }
+  ret._rank_known = true;
+
+  return ret;
+}
+
 } // namespace loco_exporter
index 3372d44..6a7595a 100644 (file)
@@ -43,6 +43,8 @@ tflite::TensorType getOpResultType(loco::FeatureDecode *node, SerializedModelDat
 
 tflite::TensorType getOpResultType(loco::FilterEncode *node, SerializedModelData &gd);
 
+tflite::TensorType getOpResultType(loco::TensorConcat *node, SerializedModelData &gd);
+
 // Shape inference functions
 
 ShapeDescription getOpResultShape(loco::Pull *node, SerializedModelData &);
@@ -62,6 +64,8 @@ ShapeDescription getOpResultShape(loco::FeatureEncode *node, SerializedModelData
 ShapeDescription getOpResultShape(loco::FeatureDecode *node, SerializedModelData &gd);
 
 ShapeDescription getOpResultShape(loco::FilterEncode *node, SerializedModelData &gd);
+
+ShapeDescription getOpResultShape(loco::TensorConcat *node, SerializedModelData &gd);
 }
 
 #endif //__LOCO_EXPORTER_TYPEINFERENCE_H__