*/
#include "arm_compute/core/CL/kernels/CLSquaredDifferenceKernel.h"
+#include "arm_compute/core/AccessWindowStatic.h"
#include "arm_compute/core/CL/CLHelpers.h"
+#include "arm_compute/core/CL/CLKernelLibrary.h"
#include "arm_compute/core/CL/CLKernelLibraryEx.h"
#include "arm_compute/core/CL/ICLTensor.h"
+#include "arm_compute/core/Error.h"
+#include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/Window.h"
using namespace arm_compute;
constexpr unsigned int num_elems_processed_per_iteration = 16;
}
+Status validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+{
+ const TensorShape &out_shape =
+ TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
+
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16,
+ DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16,
+ DataType::F16, DataType::F32);
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0,
+ "Inputs are not broadcast compatible");
+ // Validate in case of configured output
+ if (output->total_size() > 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16,
+ DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ detail::have_different_dimensions(out_shape, output->tensor_shape(), 0),
+ "Wrong shape for output");
+ }
+ return Status{};
+}
+
CLSquaredDifferenceKernel::CLSquaredDifferenceKernel()
: _input1(nullptr), _input2(nullptr), _output(nullptr)
{
void CLSquaredDifferenceKernel::configure(const ICLTensor *input1, const ICLTensor *input2,
ICLTensor *output)
{
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(input1->info()->tensor_shape(),
- input2->info()->tensor_shape());
- ARM_COMPUTE_ERROR_ON_MISMATCHING_DIMENSIONS(input1->info()->tensor_shape(),
- output->info()->tensor_shape());
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, input2);
ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input1, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate(input1->info(), input2->info(), output->info()));
_input1 = input1;
_input2 = input2;