[loco exporter] Support BiasAdd<Tensor> (#3875)
author박천교/On-Device Lab(SR)/Engineer/삼성전자 <ch.bahk@samsung.com>
Wed, 19 Jun 2019 05:24:42 +0000 (14:24 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 19 Jun 2019 05:24:42 +0000 (14:24 +0900)
This commit supports BiasAdd loco node of Tensor domain for loco
exporter.

Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
contrib/loco-exporter/src/OperationExporter.cpp
contrib/loco-exporter/src/TensorExporter.cpp
contrib/loco-exporter/src/TypeInference.cpp
contrib/loco-exporter/src/TypeInference.h

index c7d6a07..64d9f9d 100644 (file)
@@ -249,6 +249,21 @@ void exportFilterEncode(loco::FilterEncode *node, FlatBufferBuilder &builder,
   }
 }
 
+void exportBiasAdd(loco::BiasAdd<loco::Domain::Tensor> *node, FlatBufferBuilder &builder,
+                   SerializedModelData &gd)
+{
+  uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_ADD);
+  std::vector<int32_t> inputs_vec{gd._node_to_tensor_id[node->value()],
+                                  gd._node_to_tensor_id[node->bias()]};
+  std::vector<int32_t> outputs_vec{gd._node_to_tensor_id[static_cast<loco::Node *>(node)]};
+  auto inputs = builder.CreateVector(inputs_vec);
+  auto outputs = builder.CreateVector(outputs_vec);
+  auto options = CreateAddOptions(builder); // dummy option
+  auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+                                  tflite::BuiltinOptions_AddOptions, options.Union());
+  gd._operators.push_back(op_offset);
+}
+
 /// @brief Export CONCATENATION of **TWO** tensors only
 void exportConcat(loco::TensorConcat *node, FlatBufferBuilder &builder, SerializedModelData &gd)
 {
@@ -316,6 +331,10 @@ void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
   {
     exportIdentity(encode, builder, data);
   }
+  else if (auto *biasadd = dynamic_cast<loco::BiasAdd<loco::Domain::Tensor> *>(node))
+  {
+    exportBiasAdd(biasadd, builder, data);
+  }
   else
   {
     assert(false && "unsupported node found");
index 67c2b9f..0b8b5de 100644 (file)
@@ -171,6 +171,10 @@ void exportOpDefinedTensors(loco::Graph::NodeContext *nodes, FlatBufferBuilder &
     {
       exportOpDefinedTensor(encode, builder, gd);
     }
+    else if (auto *biasadd = dynamic_cast<loco::BiasAdd<loco::Domain::Tensor> *>(node))
+    {
+      exportOpDefinedTensor(biasadd, builder, gd);
+    }
     else
     {
       assert(false && "unsupported node type");
index b26e7d3..4d1b0eb 100644 (file)
@@ -123,6 +123,18 @@ tflite::TensorType getOpResultType(loco::BiasEncode *node, SerializedModelData &
   return gd._node_to_type[node->input()];
 }
 
+tflite::TensorType getOpResultType(loco::BiasAdd<loco::Domain::Tensor> *node,
+                                   SerializedModelData &gd)
+{
+  tflite::TensorType value_type = gd._node_to_type[node->value()];
+  tflite::TensorType bias_type = gd._node_to_type[node->bias()];
+
+  // TODO support heterogenous type combination
+  assert(value_type == bias_type);
+
+  return value_type;
+}
+
 int32_t decodeShapeDimension(const loco::Dimension &dim)
 {
   if (!dim.known())
@@ -440,4 +452,23 @@ ShapeDescription getOpResultShape(loco::BiasEncode *node, SerializedModelData &g
   return input_shape;
 }
 
+ShapeDescription getOpResultShape(loco::BiasAdd<loco::Domain::Tensor> *node,
+                                  SerializedModelData &gd)
+{
+  const ShapeDescription &value_shape = gd._node_to_shape[node->value()];
+  const ShapeDescription &bias_shape = gd._node_to_shape[node->bias()];
+
+  // For TFlite, only supports last bias add axis. Unless, broadcasting is not performed as
+  // expected.
+  assert(node->axis() == value_shape._dims.size() - 1);
+
+  // Bias should be rank 1
+  assert(bias_shape._dims.size() == 1);
+
+  // Channel count coherency for proper broadcast
+  assert(bias_shape._dims[0] == value_shape._dims[node->axis()]);
+
+  return value_shape;
+}
+
 } // namespace loco_exporter
index 356a1e9..b87aeb0 100644 (file)
@@ -47,6 +47,9 @@ tflite::TensorType getOpResultType(loco::TensorConcat *node, SerializedModelData
 
 tflite::TensorType getOpResultType(loco::BiasEncode *node, SerializedModelData &gd);
 
+tflite::TensorType getOpResultType(loco::BiasAdd<loco::Domain::Tensor> *node,
+                                   SerializedModelData &gd);
+
 // Shape inference functions
 
 ShapeDescription getOpResultShape(loco::Pull *node, SerializedModelData &);
@@ -70,6 +73,9 @@ ShapeDescription getOpResultShape(loco::FilterEncode *node, SerializedModelData
 ShapeDescription getOpResultShape(loco::TensorConcat *node, SerializedModelData &gd);
 
 ShapeDescription getOpResultShape(loco::BiasEncode *node, SerializedModelData &gd);
+
+ShapeDescription getOpResultShape(loco::BiasAdd<loco::Domain::Tensor> *node,
+                                  SerializedModelData &gd);
 }
 
 #endif //__LOCO_EXPORTER_TYPEINFERENCE_H__