const TensorInfo& inputTensorInfo = workloadInfo.m_InputTensorInfos[0];
const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
- ValidateTensorNumDimensions(inputTensorInfo, descriptorName, 4, "input");
-
std::vector<DataType> supportedTypes =
{
DataType::BFloat16,
// Initialise temp output
std::vector<float> tempOut(numOutputs);
- if (reduceOperation == ReduceOperation::Max || reduceOperation == ReduceOperation::Min)
+ switch(reduceOperation)
{
- for (unsigned int idx = 0; idx < numOutputs; ++idx)
- {
- input[idx];
- tempOut[idx] = input.Get();
- }
- }
- else
- {
- std::fill(tempOut.begin(), tempOut.end(), 0.0);
+ case ReduceOperation::Mean:
+ case ReduceOperation::Sum:
+ std::fill(tempOut.begin(), tempOut.end(), 0.0);
+ break;
+ case ReduceOperation::Max:
+ std::fill(tempOut.begin(), tempOut.end(), -1 * std::numeric_limits<float>::max());
+ break;
+ case ReduceOperation::Min:
+ std::fill(tempOut.begin(), tempOut.end(), std::numeric_limits<float>::max());
+ break;
+ default:
+ throw armnn::InvalidArgumentException("Unknown reduce method: " +
+ std::to_string(static_cast<int>(reduceOperation)));
}
// Initialise temp index