Fix some minor issues around SpaceToDepth
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Thu, 19 Sep 2019 13:39:37 +0000 (14:39 +0100)
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>
Fri, 20 Sep 2019 14:49:41 +0000 (14:49 +0000)
* Removed unnecessary code from SpaceToDepthLayer::InferOutputShapes()
* Refactored SpaceToDepthQueueDescriptor::Validate() and added extra
  checks for block size and output depth

Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Ieeed3144e2589b2e8695ef65ce17752bc595332f

src/armnn/layers/SpaceToDepthLayer.cpp
src/backends/backendsCommon/WorkloadData.cpp

index b24490f..8a9f1c2 100644 (file)
@@ -47,8 +47,6 @@ std::vector<TensorShape> SpaceToDepthLayer::InferOutputShapes(const std::vector<
     TensorShape inputShape = inputShapes[0];
     TensorShape outputShape(inputShape);
 
-    outputShape[0] = inputShape[0];
-
     DataLayoutIndexed dimensionIndices{m_Param.m_DataLayout};
     unsigned int hIndex = dimensionIndices.GetHeightIndex();
     unsigned int wIndex = dimensionIndices.GetWidthIndex();
@@ -82,4 +80,4 @@ void SpaceToDepthLayer::Accept(ILayerVisitor& visitor) const
     visitor.VisitSpaceToDepthLayer(this, GetParameters(), GetName());
 }
 
-} // namespace armnn
\ No newline at end of file
+} // namespace armnn
index 52d1409..3fbdec7 100644 (file)
@@ -1408,29 +1408,31 @@ void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
     ValidateDataTypes(inputTensorInfo,  supportedTypes, descriptorName);
     ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
 
+    ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+    if (m_Parameters.m_BlockSize == 0)
+    {
+        throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
+    }
+
     DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
     const unsigned int wIndex = dimensionIndices.GetWidthIndex();
     const unsigned int hIndex = dimensionIndices.GetHeightIndex();
     const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
 
     const TensorShape& inputShape = inputTensorInfo.GetShape();
-
-    const unsigned int numInputElements  =
-        inputShape[0] * inputShape[wIndex] * inputShape[hIndex] * inputShape[cIndex];
-    const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
-
-    if (numOutputElements != numInputElements)
-    {
-        throw InvalidArgumentException(descriptorName + ": Input tensor has " +
-            std::to_string(numInputElements) + " but output tensor has " +
-            std::to_string(numOutputElements) + " elements.");
-    }
-
     if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex]  % m_Parameters.m_BlockSize != 0)
     {
         throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
                                        "by block size in all spatial dimensions");
     }
+
+    const TensorShape& outputShape = outputTensorInfo.GetShape();
+    if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
+    {
+        throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
+                                       "must be divisible by the square of block size." );
+    }
 }
 
 void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const