assert(asDataType(info.type().type()) == parent_tensor->info()->data_type());
// TODO Change to set data_layout for each front-end
auto shape = asTensorShape(info.shape(), _layout, _apply_dim_correction_map[current]);
-
- // Only support axis: 3 (channel)
- ::arm_compute::Coordinates coordinates;
- coordinates.set_num_dimensions(4);
- assert(info.offset()[0] == 0);
- assert(info.offset()[1] == 0);
- assert(info.offset()[2] == 0);
- // TODO Change to set data_layout for each front-end
- auto channel_index = get_data_layout_dimension_index(
- parent_tensor->info()->data_layout(), ::arm_compute::DataLayoutDimension::CHANNEL);
- coordinates[channel_index] = info.offset()[3];
+ ::arm_compute::Coordinates coordinates = asTensorCoordinate(info.offset(), _layout);
auto tensor = std::make_shared<T_SubTensor>(parent_tensor.get(), shape, coordinates, true);
_subtensors[current] = tensor;
stack.pop();
#include "cpp14/memory.h"
#include "model/OperandIndexSequence.h"
#include "util/logging.h"
+#include "util/Coordinates.h"
namespace neurun
{
return;
}
- // NOTE This implementation assumes concat over feature depth
- // TODO Remove this assumption
- int32_t axis = _ctx.at(axis_index).asScalar<int32_t>();
- if (axis != 3)
- {
- VERBOSE(SUBTENSOR) << "Cannot handle axis is not channel" << std::endl;
- return;
- }
+ int32_t axis_raw = _ctx.at(axis_index).asScalar<int32_t>();
auto &output_index = node.getOutputs().at(0);
auto &inputs = node.getInputs();
int32_t axis_point = 0;
+ const auto rank = _ctx.at(output_index).shape().rank();
+ uint32_t axis = axis_raw < 0 ? (axis_raw + rank) : axis_raw;
+ assert(rank > axis);
+
for (auto &input_index : inputs)
{
auto input_shape = _ctx.at(input_index).shape();
- std::vector<int32_t> offset = {0, 0, 0, 0};
- offset[axis] = axis_point;
- neurun::util::Coordinates coordinate_info({offset[0], offset[1], offset[2], offset[3]});
+ assert(rank == input_shape.rank());
+
+ neurun::util::Coordinates coordinate_info{};
+ for (uint32_t i = 0; i < rank; i++)
+ {
+ coordinate_info.set(i, 0);
+ }
+ coordinate_info.set(axis, axis_point);
+
std::unique_ptr<graph::operand::ParentInfo> parentInfo =
nnfw::cpp14::make_unique<graph::operand::ParentInfo>(output_index, coordinate_info);