From ebef502c6749edb6f796a0b96f3ec2c0f318f3fb Mon Sep 17 00:00:00 2001 From: =?utf8?q?Prasanna=20R/SNAP=20/SRI-Bangalore/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 28 Nov 2018 06:24:24 +0530 Subject: [PATCH] Update Validation of Arguments in SquaredDifference CL kernel (#3433) This patch will update validation of arguments in CL Kernel of SquaredDifference op. Signed-off-by: prasannar --- .../core/CL/kernels/CLSquaredDifferenceKernel.cpp | 36 +++++++++++++++++++--- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/libs/ARMComputeEx/src/core/CL/kernels/CLSquaredDifferenceKernel.cpp b/libs/ARMComputeEx/src/core/CL/kernels/CLSquaredDifferenceKernel.cpp index 3e3a17c..71af0f4 100644 --- a/libs/ARMComputeEx/src/core/CL/kernels/CLSquaredDifferenceKernel.cpp +++ b/libs/ARMComputeEx/src/core/CL/kernels/CLSquaredDifferenceKernel.cpp @@ -16,9 +16,16 @@ */ #include "arm_compute/core/CL/kernels/CLSquaredDifferenceKernel.h" +#include "arm_compute/core/AccessWindowStatic.h" #include "arm_compute/core/CL/CLHelpers.h" +#include "arm_compute/core/CL/CLKernelLibrary.h" #include "arm_compute/core/CL/CLKernelLibraryEx.h" #include "arm_compute/core/CL/ICLTensor.h" +#include "arm_compute/core/Error.h" +#include "arm_compute/core/Helpers.h" +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/core/Window.h" using namespace arm_compute; @@ -27,6 +34,30 @@ namespace constexpr unsigned int num_elems_processed_per_iteration = 16; } +Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output) +{ + const TensorShape &out_shape = + TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); + + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, + DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, + DataType::F16, DataType::F32); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, + "Inputs are not broadcast compatible"); + // Validate in case of configured output + if (output->total_size() > 0) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, + DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), + "Wrong shape for output"); + } + return Status{}; +} + CLSquaredDifferenceKernel::CLSquaredDifferenceKernel() : _input1(nullptr), _input2(nullptr), _output(nullptr) { @@ -35,12 +66,9 @@ CLSquaredDifferenceKernel::CLSquaredDifferenceKernel() void CLSquaredDifferenceKernel::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output) { - ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(input1->info()->tensor_shape(), - input2->info()->tensor_shape()); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(input1->info()->tensor_shape(), - output->info()->tensor_shape()); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2); ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output); + ARM_COMPUTE_ERROR_THROW_ON(validate(input1->info(), input2->info(), output->info())); _input1 = input1; _input2 = input2; -- 2.7.4