arm_compute::PadStrideInfo padStrideInfo = BuildArmComputePadStrideInfo(m_Data.m_Parameters);
- // Check for optimisation opportunities.
- const bool use3x3Optimisation = (weightInfo.GetShape()[2] == 3) && (weightInfo.GetShape()[3] == 3);
- const bool use5x5Optimisation = (weightInfo.GetShape()[2] == 5) && (weightInfo.GetShape()[3] == 5);
-
- if (use3x3Optimisation||use5x5Optimisation)
+ const arm_compute::ITensorInfo* inputInfo = input.info();
+ const arm_compute::ITensorInfo* kernelInfo = m_KernelTensor->info();
+ const arm_compute::ITensorInfo* biasInfo = m_BiasTensor ? m_BiasTensor->info() : nullptr;
+ const arm_compute::ITensorInfo* outputInfo = output.info();
+
+ // Check for optimisation opportunities
+ arm_compute::Status optimizationStatus =
+ arm_compute::NEDepthwiseConvolutionLayerOptimized::validate(inputInfo,
+ kernelInfo,
+ biasInfo,
+ outputInfo,
+ padStrideInfo,
+ depthMultiplier,
+ arm_compute::ActivationLayerInfo(),
+ aclDilationInfo);
+
+ if (optimizationStatus.error_code() == arm_compute::ErrorCode::OK)
{
m_pDepthwiseConvolutionLayer = std::make_unique<arm_compute::NEDepthwiseConvolutionLayerOptimized>();
static_cast<arm_compute::NEDepthwiseConvolutionLayerOptimized*>(