template<typename Iterable>
mir::Shape ShapeHelper::createShape(const Iterable &iter, std::size_t size)
{
+ //Zero-dim tensor is just a tensor with 1 element
+ if (size == 0) {
+ return mir::Shape{1};
+ }
+
mir::Shape sh;
sh.resize(static_cast<int32_t>(size));
Shape tensor_shape = ShapeHelper::createShape(*t->shape(), t->shape()->size());
+ assert(tensor_shape.numElements() * elementSize == b->data()->size());
+
return std::make_shared<IrTensor>(tensor_shape, tensor_buffer_copy, type, elementSize);
}
auto tensor = mir::Tensor<int>(*params.at(0));
std::vector<int32_t> axes;
- if (params.at(0)->getShape().rank() == 0) {
- // TODO: Dangerous black magic (Default construced Index is 0 dim, as is 0 dim Tensor)
- axes.push_back(tensor.at(Index()));
- } else {
- for (const auto& i: mir::ShapeRange(tensor.getShape())) {
- axes.emplace_back(tensor.at(i));
- }
+ for (const auto& i: mir::ShapeRange(tensor.getShape())) {
+ axes.emplace_back(tensor.at(i));
}
+
+ std::sort(axes.begin(), axes.end());
+
return createOp<ops::ReduceFOp>(
ActivationFunctionType_NONE, inputs[0]->getOutput(0),
axes, opts->keep_dims(), ft);