void Planner::visit(const ::internal::tflite::op::Concat::Node &node)
{
const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
+ const ::internal::tflite::operand::Index axis_index{node.param().axis_index};
// NOTE This implementation assumes that inputs and output are a feature
- // TODO Remove this assumption
- uint32_t input_rank = _ctx.at(ofm_index).shape().rank();
- assert(input_rank == 4 || input_rank == 2);
+ const auto ofm_shape = _ctx.at(ofm_index).shape();
+ uint32_t input_rank = ofm_shape.rank();
+ int32_t axis = _ctx.at(axis_index).asScalar<int32_t>();
- // TODO Remove this assumption
- if (input_rank == 4)
+ // Handle negative axis
+ if (axis < 0)
{
- // NOTE This implementation assumes concat over feature depth
- assert(
- _ctx.at(::internal::tflite::operand::Index{node.param().axis_index}).asScalar<int32_t>() ==
- 3);
-
- const auto ofm_shape = _ctx.at(ofm_index).shape().asFeature();
-
- // TODO Should move to the place where the operand is handled, if it is possible.
- // Set Shape Constraints and TensorInfo (for output)
- _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
+ axis += input_rank;
+ }
- // Set Shape Constraints (for input)
- uint32_t depth = 0;
+ // Set Shape Constraints and TensorInfo (for output)
+ _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
- for (const auto &index : node.param().ifm_indexes)
- {
- const ::internal::tflite::operand::Index ifm_index{index};
- const auto ifm_shape = _ctx.at(ifm_index).shape().asFeature();
+ // Set Shape Constraints and TensorInfo (for input)
+ const uint32_t coord_index = ToARMComputeAxis(input_rank, axis).value();
+ uint32_t depth = 0;
- _builder.addSubsumptionConstr(ifm_index, ofm_index,
- ::arm_compute::Coordinates{0, 0, depth, 0},
- asTensorShape(ifm_shape));
+ ::arm_compute::Coordinates coordinates;
+ coordinates.set_num_dimensions(input_rank);
- depth += ifm_shape.C;
- }
- }
- else if (input_rank == 2)
+ for (const auto &index : node.param().ifm_indexes)
{
- // NOTE This implementation assumes concat over matrix row
- assert(
- _ctx.at(::internal::tflite::operand::Index{node.param().axis_index}).asScalar<int32_t>() ==
- 0);
-
- const auto ofm_shape = _ctx.at(ofm_index).shape();
- const auto ofm_rows = ofm_shape.dim(0);
- const auto ofm_cols = ofm_shape.dim(1);
-
- // TODO Should move to the place where the operand is handled, if it is possible.
- // Set Shape Constraints and TensorInfo (for output)
- _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_rows, ofm_cols, _ctx.at(ofm_index).type()));
+ const ::internal::tflite::operand::Index ifm_index{index};
+ const auto ifm_shape = _ctx.at(ifm_index).shape();
- // Set Shape Constraints (for input)
- uint32_t row_offset = 0;
-
- for (const auto &index : node.param().ifm_indexes)
- {
- const ::internal::tflite::operand::Index ifm_index{index};
- const auto ifm_shape = _ctx.at(ifm_index).shape();
- const auto ifm_rows = ifm_shape.dim(0);
- const auto ifm_cols = ifm_shape.dim(1);
+ coordinates[coord_index] = depth;
- _builder.addSubsumptionConstr(ifm_index, ofm_index,
- ::arm_compute::Coordinates{0, row_offset, 0, 0},
- asTensorShape(ifm_rows, ifm_cols));
+ _builder.addSubsumptionConstr(ifm_index, ofm_index, coordinates, asTensorShape(ifm_shape),
+ true);
- row_offset += ifm_rows;
- }
- }
- else
- {
- // Not implemented yet
- throw std::runtime_error("Not supported, yet");
+ depth += ifm_shape.dim(axis);
}
// NOTE Concat has no actual operation!