}
}
+void ValidateTensorMaxNumElements(const TensorInfo& tensor,
+ std::string const& descName,
+ unsigned int maxNumElements,
+ std::string const& tensorName)
+{
+ if (tensor.GetNumElements() > maxNumElements)
+ {
+ throw InvalidArgumentException(descName + ": Expected maximum of " + to_string(maxNumElements) + " but got " +
+ to_string(tensor.GetNumElements()) + " elements for " + tensorName + " tensor.");
+ }
+}
+
//---------------------------------------------------------------
void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
const std::string& descName, std::string const& tensorName)
{
ValidateSingleInput(workloadInfo, "MeanQueueDescriptor");
ValidateSingleOutput(workloadInfo, "MeanQueueDescriptor");
+
+ const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
+
+ if (m_Keepdims)
+ {
+ ValidateTensorNumDimensions(output, "MeanQueueDescriptor", input.GetNumDimensions(), "output");
+ }
+ else if (m_Axis == nullptr)
+ {
+ ValidateTensorNumDimensions(output, "MeanQueueDescriptor", 1, "output");
+ }
+ else
+ {
+ const TensorInfo& axis = m_Axis->GetTensorInfo();
+ ValidateTensorNumDimensions(axis, "MeanQueueDescriptor", 1, "axis");
+ ValidateTensorMaxNumElements(axis, "MeanQueueDescriptor", input.GetNumDimensions(), "axis");
+ unsigned int outputDim = input.GetNumDimensions() - axis.GetNumElements();
+ ValidateTensorNumDimensions(output,
+ "MeanQueueDescriptor",
+ outputDim > 0 ? outputDim : 1,
+ "output");
+ }
}
} //namespace armnn