public:
void publish(std::shared_ptr<const internal::arm_compute::Plan> &plan) { plan = _plan; }
+ // Publishes the broadcast shape map to the caller, mirroring the
+ // publish(plan) overload above. The out-parameter stays nullptr when
+ // setShapeForBroadcast() was never called on this compilation.
+ void publish(std::shared_ptr<const std::map<int, ::internal::tflite::operand::Shape>>
+ &broadcasting_tensor_shape)
+ {
+ broadcasting_tensor_shape = _broadcasting_tensor_shape;
+ }
+ // Records the broadcast-extended shapes (keyed by operand index) so they
+ // can later be published to executions created from this compilation.
+ // Takes the map by const reference: the method only copies it into an
+ // immutable shared map, so a mutable reference is unnecessary and would
+ // reject temporaries.
+ void setShapeForBroadcast(
+ const std::map<int, ::internal::tflite::operand::Shape> &broadcasting_tensor_shape)
+ {
+ _broadcasting_tensor_shape =
+ std::make_shared<const std::map<int, ::internal::tflite::operand::Shape>>(
+ broadcasting_tensor_shape);
+ }
private:
std::shared_ptr<internal::arm_compute::Plan> _plan;
+ // Operand index -> broadcast-extended shape; shared with executions via
+ // the publish() overload above. Default-constructed (nullptr) until
+ // setShapeForBroadcast() is called.
+ std::shared_ptr<const std::map<int, ::internal::tflite::operand::Shape>>
+ _broadcasting_tensor_shape;
};
#endif
ANeuralNetworksExecution **execution)
{
std::shared_ptr<const ::internal::arm_compute::Plan> plan;
-
compilation->publish(plan);
-
*execution = new ANeuralNetworksExecution{plan};
+ // Forward the broadcast shape map published by the compilation into the
+ // freshly created execution. Note the ordering: *execution is assigned by
+ // the 'new' above before it is dereferenced here.
+ std::shared_ptr<const std::map<int, ::internal::tflite::operand::Shape>>
+ broadcasting_tensor_shape;
+ compilation->publish(broadcasting_tensor_shape);
+ (*execution)->setShapeForBroadcast(broadcasting_tensor_shape);
+
return ANEURALNETWORKS_NO_ERROR;
}
input_type = type->type;
}
- if (operands.at(operand_index).shape().rank() == 1)
+ // Work on local copies of the operand's shape/rank so the broadcast map
+ // can override them without mutating the stored operand.
+ auto shape = operands.at(operand_index).shape();
+ auto rank = shape.rank();
+
+ // If the compilation recorded a broadcast shape for this operand, replace
+ // the shape and force rank to 4 so the asFeature() branch below is taken.
+ // NOTE(review): shapeForBroadcast() is called three times here; consider
+ // caching the shared_ptr in a local to avoid repeated accessor calls.
+ if (execution->shapeForBroadcast() != nullptr)
+ {
+ auto it = execution->shapeForBroadcast()->find(operand_index.asInt());
+ if (it != execution->shapeForBroadcast()->end())
+ {
+ rank = 4;
+ shape = it->second;
+ }
+ }
+
+ // Dispatch on (possibly overridden) rank: 1 -> vector, 2 -> matrix,
+ // 3 -> tensor, 4 -> feature source.
+ if (rank == 1)
{
- const auto len = operands.at(operand_index).shape().dim(0);
+ const auto len = shape.dim(0);
asVectorSource(execution, input_type, index, len, buffer, length);
}
- else if (operands.at(operand_index).shape().rank() == 2)
+ else if (rank == 2)
{
- const auto &operand_shape = operands.at(operand_index).shape().asMatrix();
+ const auto &operand_shape = shape.asMatrix();
asMatrixSource(execution, input_type, index, operand_shape, buffer, length);
}
- else if (operands.at(operand_index).shape().rank() == 3)
+ else if (rank == 3)
{
- const auto &operand_shape = operands.at(operand_index).shape().asTensor();
+ const auto &operand_shape = shape.asTensor();
asTensorSource(execution, input_type, index, operand_shape, buffer, length);
}
- else if (operands.at(operand_index).shape().rank() == 4)
+ else if (rank == 4)
{
- const auto &operand_shape = operands.at(operand_index).shape().asFeature();
+ const auto &operand_shape = shape.asFeature();
asFeatureSource(execution, input_type, index, operand_shape, buffer, length);
}
public:
const internal::arm_compute::Plan &plan(void) const { return *_plan; }
+ // Returns the broadcast shape map set via setShapeForBroadcast(), or
+ // nullptr when none was set. Marked const: it is a pure accessor, like
+ // plan() above.
+ std::shared_ptr<const std::map<int, ::internal::tflite::operand::Shape>>
+ shapeForBroadcast(void) const
+ {
+ return _broadcasting_tensor_shape;
+ }
+ // Stores the broadcast shape map handed over by the compilation. No deep
+ // copy is made; ownership of the immutable map is shared.
+ void setShapeForBroadcast(
+ const std::shared_ptr<const std::map<int, ::internal::tflite::operand::Shape>>
+ &broadcasting_tensor_shape)
+ {
+ _broadcasting_tensor_shape = broadcasting_tensor_shape;
+ }
private:
std::shared_ptr<const internal::arm_compute::Plan> _plan;
+ // Broadcast shape map received from the compilation; nullptr when no
+ // operand needs broadcasting. NOTE(review): the '= nullptr' initializer
+ // is redundant — a default-constructed shared_ptr is already null.
+ std::shared_ptr<const std::map<int, ::internal::tflite::operand::Shape>>
+ _broadcasting_tensor_shape = nullptr;
public:
// TODO Use InputIndex instead of int