TensorShape inputShape = inputShapes[0];
TensorShape outputShape(inputShape);
- outputShape[0] = inputShape[0];
-
DataLayoutIndexed dimensionIndices{m_Param.m_DataLayout};
unsigned int hIndex = dimensionIndices.GetHeightIndex();
unsigned int wIndex = dimensionIndices.GetWidthIndex();
visitor.VisitSpaceToDepthLayer(this, GetParameters(), GetName());
}
-} // namespace armnn
\ No newline at end of file
+} // namespace armnn
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+ ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ if (m_Parameters.m_BlockSize == 0)
+ {
+ throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
+ }
+
DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
const unsigned int wIndex = dimensionIndices.GetWidthIndex();
const unsigned int hIndex = dimensionIndices.GetHeightIndex();
const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
const TensorShape& inputShape = inputTensorInfo.GetShape();
-
- const unsigned int numInputElements =
- inputShape[0] * inputShape[wIndex] * inputShape[hIndex] * inputShape[cIndex];
- const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
-
- if (numOutputElements != numInputElements)
- {
- throw InvalidArgumentException(descriptorName + ": Input tensor has " +
- std::to_string(numInputElements) + " but output tensor has " +
- std::to_string(numOutputElements) + " elements.");
- }
-
if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
{
throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
"by block size in all spatial dimensions");
}
+
+ const TensorShape& outputShape = outputTensorInfo.GetShape();
+ if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
+ {
+ throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
+ "must be divisible by the square of block size." );
+ }
}
void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const