This commit supports TensorConcat loco node for loco exporter.
Signed-off-by: Cheongyo Bahk <ch.bahk@samsung.com>
}
}
+/// @brief Export CONCATENATION of **TWO** tensors only
+void exportConcat(loco::TensorConcat *node, FlatBufferBuilder &builder, SerializedModelData &gd)
+{
+ uint32_t op_idx = gd.registerBuiltinOpcode(tflite::BuiltinOperator_CONCATENATION);
+ std::vector<int32_t> inputs_vec{gd._node_to_tensor_id[node->lhs()],
+ gd._node_to_tensor_id[node->rhs()]};
+ std::vector<int32_t> outputs_vec{gd._node_to_tensor_id[static_cast<loco::Node *>(node)]};
+ auto inputs = builder.CreateVector(inputs_vec);
+ auto outputs = builder.CreateVector(outputs_vec);
+ auto options = CreateConcatenationOptions(builder, node->axis());
+ auto op_offset = CreateOperator(builder, op_idx, inputs, outputs,
+ tflite::BuiltinOptions_ConcatenationOptions, options.Union());
+
+ gd._operators.push_back(op_offset);
+}
+
void exportNode(loco::Node *node, flatbuffers::FlatBufferBuilder &builder,
SerializedModelData &data)
{
{
exportConv2D(conv2d, builder, data);
}
+ else if (auto *tconcat = dynamic_cast<loco::TensorConcat *>(node))
+ {
+ exportConcat(tconcat, builder, data);
+ }
else
{
assert(false && "unsupported node found");
{
exportOpDefinedTensor(relu, builder, gd);
}
+ else if (auto *tconcat = dynamic_cast<loco::TensorConcat *>(node))
+ {
+ exportOpDefinedTensor(tconcat, builder, gd);
+ }
else
{
assert(false && "unsupported node type");
return gd._node_to_type[node->input()];
}
+tflite::TensorType getOpResultType(loco::TensorConcat *node, SerializedModelData &gd)
+{
+ tflite::TensorType lhs_type = gd._node_to_type[node->lhs()];
+ tflite::TensorType rhs_type = gd._node_to_type[node->rhs()];
+
+ // TODO support heterogenous type combination
+ assert(lhs_type == rhs_type);
+
+ return lhs_type;
+}
+
int32_t decodeShapeDimension(const loco::Dimension &dim)
{
if (!dim.known())
return shape;
}
+ShapeDescription getOpResultShape(loco::TensorConcat *node, SerializedModelData &gd)
+{
+ const ShapeDescription &lhs_shape = gd._node_to_shape[node->lhs()];
+ if (!lhs_shape._rank_known)
+ {
+ // return unknown shape
+ return {};
+ }
+
+ const ShapeDescription &rhs_shape = gd._node_to_shape[node->rhs()];
+ if (!rhs_shape._rank_known)
+ {
+ // return unknown shape
+ return {};
+ }
+
+ ShapeDescription ret;
+
+ assert(lhs_shape._dims.size() == rhs_shape._dims.size());
+ ret._dims.resize(lhs_shape._dims.size());
+
+ uint32_t axis = node->axis();
+
+ for (uint32_t i = 0; i < lhs_shape._dims.size(); ++i)
+ {
+ if (i == axis)
+ {
+ ret._dims[i] = lhs_shape._dims[i] + rhs_shape._dims[i];
+ }
+ else
+ {
+ assert(lhs_shape._dims[i] == rhs_shape._dims[i]);
+ ret._dims[i] = lhs_shape._dims[i];
+ }
+ }
+ ret._rank_known = true;
+
+ return ret;
+}
+
} // namespace loco_exporter
tflite::TensorType getOpResultType(loco::FilterEncode *node, SerializedModelData &gd);
+tflite::TensorType getOpResultType(loco::TensorConcat *node, SerializedModelData &gd);
+
// Shape inference functions
ShapeDescription getOpResultShape(loco::Pull *node, SerializedModelData &);
ShapeDescription getOpResultShape(loco::FeatureDecode *node, SerializedModelData &gd);
ShapeDescription getOpResultShape(loco::FilterEncode *node, SerializedModelData &gd);
+
+ShapeDescription getOpResultShape(loco::TensorConcat *node, SerializedModelData &gd);
}
#endif //__LOCO_EXPORTER_TYPEINFERENCE_H__