Update Validation of Arguments in SquaredDifference CL kernel (#3433)
authorPrasanna R/SNAP /SRI-Bangalore/Engineer/삼성전자 <prasanna.r@samsung.com>
Wed, 28 Nov 2018 00:54:24 +0000 (06:24 +0530)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 28 Nov 2018 00:54:24 +0000 (09:54 +0900)
This patch will update validation of arguments in CL Kernel of SquaredDifference op.

Signed-off-by: prasannar <prasanna.r@samsung.com>
libs/ARMComputeEx/src/core/CL/kernels/CLSquaredDifferenceKernel.cpp

index 3e3a17c..71af0f4 100644 (file)
  */
 #include "arm_compute/core/CL/kernels/CLSquaredDifferenceKernel.h"
 
+#include "arm_compute/core/AccessWindowStatic.h"
 #include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLKernelLibrary.h"
 #include "arm_compute/core/CL/CLKernelLibraryEx.h"
 #include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
 
 using namespace arm_compute;
 
@@ -27,6 +34,30 @@ namespace
 constexpr unsigned int num_elems_processed_per_iteration = 16;
 }
 
+Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+{
+  const TensorShape &out_shape =
+      TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
+
+  ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16,
+                                                       DataType::F16, DataType::F32);
+  ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16,
+                                                       DataType::F16, DataType::F32);
+
+  ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0,
+                                  "Inputs are not broadcast compatible");
+  // Validate in case of configured output
+  if (output->total_size() > 0)
+  {
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16,
+                                                         DataType::F16, DataType::F32);
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+        detail::have_different_dimensions(out_shape, output->tensor_shape(), 0),
+        "Wrong shape for output");
+  }
+  return Status{};
+}
+
 CLSquaredDifferenceKernel::CLSquaredDifferenceKernel()
     : _input1(nullptr), _input2(nullptr), _output(nullptr)
 {
@@ -35,12 +66,9 @@ CLSquaredDifferenceKernel::CLSquaredDifferenceKernel()
 void CLSquaredDifferenceKernel::configure(const ICLTensor *input1, const ICLTensor *input2,
                                           ICLTensor *output)
 {
-  ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(input1->info()->tensor_shape(),
-                                              input2->info()->tensor_shape());
-  ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(input1->info()->tensor_shape(),
-                                              output->info()->tensor_shape());
   ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
   ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output);
+  ARM_COMPUTE_ERROR_THROW_ON(validate(input1->info(), input2->info(), output->info()));
 
   _input1 = input1;
   _input2 = input2;