#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "tests/Globals.h"
#include "tests/Utils.h"
#include "tests/framework/Fixture.h"
{
namespace test
{
+namespace benchmark
+{
+using namespace arm_compute::misc::shape_calculator;
+
/** Fixture that can be used for NEON and CL */
template <typename TensorType, typename Function, typename Accessor>
class DepthwiseConvolutionLayerFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape dst_shape, PadStrideInfo info, DataType data_type, int batches)
+ void setup(TensorShape src_shape, Size2D kernel_size, PadStrideInfo info, DataType data_type, int batches)
{
+ // Get shapes
+ TensorShape weights_shape(kernel_size.width, kernel_size.height);
+
+ const TensorInfo in_info(src_shape, 1, data_type);
+ const TensorInfo we_info(weights_shape, 1, data_type);
+ TensorShape dst_shape = compute_depthwise_convolution_shape(in_info, we_info, info, 1);
+
+ weights_shape.set(2, dst_shape.z());
+
// Set batched in source and destination shapes
const unsigned int fixed_point_position = 4;
src_shape.set(3 /* batch */, batches);
TensorType dst{};
Function depth_conv{};
};
+} // namespace benchmark
} // namespace test
} // namespace arm_compute
#endif /* ARM_COMPUTE_TEST_DEPTHWISECONVOLUTIONFIXTURE */