IVGCVSW-3640 Add multi-channel TransposeConvolution2d unit tests to CL backend
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Wed, 14 Aug 2019 13:37:42 +0000 (14:37 +0100)
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>
Tue, 20 Aug 2019 14:25:56 +0000 (14:25 +0000)
* Fixed bug in multi-channel test and reference workload implementation
* Enabled multi-channel tests on CL backend

Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: I08bb523bc68d9c93a1012b4f487a5bce36a085b1

src/backends/backendsCommon/test/TransposeConvolution2dTestImpl.hpp
src/backends/cl/test/ClLayerTests.cpp
src/backends/reference/workloads/TransposeConvolution2d.cpp

index 9140c19..64caa3f 100644 (file)
@@ -493,7 +493,8 @@ LayerTestResult<T, 4> MultiChannelTransposeConvolution2dTest(
     TensorShape inputShape   = MakeTensorShape(1, 1, 2, 2, layout);
     TensorShape outputShape  = MakeTensorShape(1, 2, 5, 5, layout);
 
-    TensorShape weightsShape = MakeTensorShape(1, 2, 3, 3, layout);
+    // OIHW for NCHW; OHWI for NHWC
+    TensorShape weightsShape = MakeTensorShape(2, 1, 3, 3, layout);
     TensorShape biasesShape  = { 2 };
 
     TensorInfo inputInfo(inputShape, ArmnnType);
index d3f3921..8a5435b 100644 (file)
@@ -760,6 +760,19 @@ ARMNN_AUTO_TEST_CASE(UnbiasedStridedTransposeConvolution2dUint8Nhwc,
                      true,
                      DataLayout::NHWC)
 
+ARMNN_AUTO_TEST_CASE(MultiChannelTransposeConvolution2dFloatNchw,
+                     MultiChannelTransposeConvolution2dTest<DataType::Float32, DataType::Float32>,
+                     DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(MultiChannelTransposeConvolution2dFloatNhwc,
+                     MultiChannelTransposeConvolution2dTest<DataType::Float32, DataType::Float32>,
+                     DataLayout::NHWC)
+ARMNN_AUTO_TEST_CASE(MultiChannelTransposeConvolution2dUint8Nchw,
+                     MultiChannelTransposeConvolution2dTest<DataType::QuantisedAsymm8, DataType::Signed32>,
+                     DataLayout::NCHW)
+ARMNN_AUTO_TEST_CASE(MultiChannelTransposeConvolution2dUint8Nhwc,
+                     MultiChannelTransposeConvolution2dTest<DataType::QuantisedAsymm8, DataType::Signed32>,
+                     DataLayout::NHWC)
+
 // ============================================================================
 // COMPARE tests
 
index acbfe0c..52cc18c 100644 (file)
@@ -83,7 +83,7 @@ void TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor& descript
                                     inputDecoder[inputIndex];
 
                                     const unsigned int weightsIndex =
-                                        dataLayoutIndexed.GetIndex(weightsShape, batch, dOutput, yWeights, xWeights);
+                                        dataLayoutIndexed.GetIndex(weightsShape, dOutput, dInput, yWeights, xWeights);
                                     weightsDecoder[weightsIndex];
 
                                     const unsigned int outputIndex =