arm_compute v18.05
[platform/upstream/armcl.git] / tests / benchmark / fixtures / DepthwiseConvolutionLayerFixture.h
index a156f4b..9276431 100644 (file)
@@ -26,6 +26,7 @@
 
 #include "arm_compute/core/TensorShape.h"
 #include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "tests/Globals.h"
 #include "tests/Utils.h"
 #include "tests/framework/Fixture.h"
@@ -34,14 +35,27 @@ namespace arm_compute
 {
 namespace test
 {
+namespace benchmark
+{
+using namespace arm_compute::misc::shape_calculator;
+
 /** Fixture that can be used for NEON and CL */
 template <typename TensorType, typename Function, typename Accessor>
 class DepthwiseConvolutionLayerFixture : public framework::Fixture
 {
 public:
     template <typename...>
-    void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape dst_shape, PadStrideInfo info, DataType data_type, int batches)
+    void setup(TensorShape src_shape, Size2D kernel_size, PadStrideInfo info, DataType data_type, int batches)
     {
+        // Get shapes
+        TensorShape weights_shape(kernel_size.width, kernel_size.height);
+
+        const TensorInfo in_info(src_shape, 1, data_type);
+        const TensorInfo we_info(weights_shape, 1, data_type);
+        TensorShape      dst_shape = compute_depthwise_convolution_shape(in_info, we_info, info, 1);
+
+        weights_shape.set(2, dst_shape.z());
+
         // Set batched in source and destination shapes
         const unsigned int fixed_point_position = 4;
         src_shape.set(3 /* batch */, batches);
@@ -89,6 +103,7 @@ private:
     TensorType dst{};
     Function   depth_conv{};
 };
+} // namespace benchmark
 } // namespace test
 } // namespace arm_compute
 #endif /* ARM_COMPUTE_TEST_DEPTHWISECONVOLUTIONFIXTURE */