#include <nncc/core/ADT/tensor/Shape.h>
#include <schema_generated.h>
+#include <array>
+
using namespace nncc::core::ADT;
using namespace morph::tflite;
+namespace
+{
+
+/**
+ * @brief Convert a numeric tensor axis as a ConcatF FeatureAxis value
+ */
+coco::ConcatF::Axis as_ConcatF_axis(uint32_t axis)
+{
+ // NOTE The feature map (in TensorFlow) is a rank-4 (NHWC) tensor
+ assert(axis < 4);
+
+ coco::ConcatF::Axis res = coco::ConcatF::Axis::Unknown;
+
+ switch (axis)
+ {
+ case 0:
+ res = coco::ConcatF::Axis::Batch;
+ break;
+ case 1:
+ res = coco::ConcatF::Axis::Height;
+ break;
+ case 2:
+ res = coco::ConcatF::Axis::Width;
+ break;
+ case 3:
+ res = coco::ConcatF::Axis::Depth;
+ break;
+ default:
+ break;
+ }
+
+ return res;
+}
+
+/**
+ * @brief Convert a feature shape as an array of 'unit32_t' values
+ */
+std::array<uint32_t, 4> as_dims(const feature::Shape &shape)
+{
+ std::array<uint32_t, 4> res;
+
+ res[0] = 1; /* BATCH */
+ res[1] = shape.height();
+ res[2] = shape.width();
+ res[3] = shape.depth();
+
+ return res;
+}
+
+} // namespace
+
namespace tflimport
{
assert(concat_axis >= 0);
assert(concat_axis < rank);
}
- // TODO handle other axis
- assert(concat_axis == 3);
+ assert(as_ConcatF_axis(concat_axis) != coco::ConcatF::Axis::Unknown);
assert(activation == tflite::ActivationFunctionType_NONE);
// Construct a vector of input objects
assert(left_feature->layout()->batch() == 1);
assert(right_feature->layout()->batch() == 1);
- // Height and Width SHOULD BE IDENTICAL for depth concat
- assert(left_shape.height() == right_shape.height());
- assert(left_shape.width() == right_shape.width());
+ // Compute output dimensionalities
+ auto compute_out_dims = [&left_shape, &right_shape, concat_axis](void) {
+ std::array<uint32_t, 4> out_dims;
+
+ const auto left_dims = as_dims(left_shape);
+ const auto right_dims = as_dims(right_shape);
+
+ for (uint32_t axis = 0; axis < 4 /* FEATURE MAP RANK */; ++axis)
+ {
+ // The dimensionality of all the axises except 'concat' axis SHOULD BE INDETICAL
+ assert((concat_axis == axis) || (left_dims[axis] == right_dims[axis]));
+
+ out_dims[axis] = left_dims[axis];
+ if (axis == concat_axis)
+ {
+ out_dims[axis] += right_dims[axis];
+ }
+ }
+
+ return out_dims;
+ };
+
+ const auto out_dims = compute_out_dims();
- const uint32_t C = left_shape.depth() + right_shape.depth();
- const uint32_t H = left_shape.height();
- const uint32_t W = left_shape.width();
+ const uint32_t C = out_dims[3 /* DEPTH */];
+ const uint32_t H = out_dims[1 /* HEIGHT */];
+ const uint32_t W = out_dims[2 /* WIDTH */];
const nncc::core::ADT::feature::Shape out_shape{C, H, W};
auto concat_f = m->entity()->op()->create<coco::ConcatF>();
- concat_f->axis(coco::ConcatF::Axis::Depth);
+ concat_f->axis(as_ConcatF_axis(concat_axis));
concat_f->left(left_load);
concat_f->right(right_load);