Use NE/CLSynthetizeFunction instead of NE/CLQLSTMLayerNormalizationValidationFixture
Signed-off-by: Sheri Zhang <sheri.zhang@arm.com>
Change-Id: I62ace213a5261f2d307da6953d0521492aa05292
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3019
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
k->configure(std::forward<Args>(args)...);
_kernel = std::move(k);
}
+ /** Validate input arguments
+ *
+ * @param[in] args Configuration arguments.
+ */
+ template <typename... Args>
+ static Status validate(Args &&... args)
+ {
+ return K::validate(std::forward<Args>(args)...);
+ }
};
/** As above but this also setups a Zero border on the input tensor of the specified bordersize */
*/
#include "arm_compute/core/CL/kernels/CLQLSTMLayerNormalizationKernel.h"
#include "tests/CL/CLAccessor.h"
+#include "tests/CL/Helper.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Asserts.h"
constexpr uint32_t vector_size_byte = 16;
using test::datasets::ShapeDataset;
+using CLQLSTMLayerNormalization = CLSynthetizeFunction<CLQLSTMLayerNormalizationKernel>;
template <uint32_t num_elements_per_iter, uint32_t num_batches, uint32_t num_iteration>
class QLSTMLayerNormShapeDataSet : public ShapeDataset
{
), input_info, weight_info, bias_info)
{
TensorInfo dummy_output{};
- const Status s = CLQLSTMLayerNormalizationKernel::validate(&input_info, &dummy_output, &weight_info, &bias_info);
+ const Status s = CLQLSTMLayerNormalization::validate(&input_info, &dummy_output, &weight_info, &bias_info);
ARM_COMPUTE_EXPECT(!bool(s), framework::LogLevel::ERRORS);
}
// *INDENT-ON*
template <typename T>
-using CLQLSTMLayerNormalizationFixture = CLQLSTMLayerNormalizationValidationFixture<CLTensor, CLAccessor, CLQLSTMLayerNormalizationKernel, T>;
+using CLQLSTMLayerNormalizationFixture = QLSTMLayerNormalizationValidationFixture<CLTensor, CLAccessor, CLQLSTMLayerNormalization, T>;
TEST_SUITE(Quantized)
TEST_SUITE(QSYMM16)
#include "arm_compute/runtime/Tensor.h"
#include "arm_compute/runtime/TensorAllocator.h"
#include "tests/NEON/Accessor.h"
+#include "tests/NEON/Helper.h"
#include "tests/PaddingCalculator.h"
#include "tests/datasets/ShapeDatasets.h"
#include "tests/framework/Asserts.h"
constexpr uint32_t vector_size_byte = 16;
using test::datasets::ShapeDataset;
+using NEQLSTMLayerNormalization = NESynthetizeFunction<NEQLSTMLayerNormalizationKernel>;
+
template <uint32_t num_elements_per_iter, uint32_t num_batches, uint32_t num_iteration>
class QLSTMLayerNormShapeDataSet : public ShapeDataset
{
),
input_info, weight_info, bias_info, output_info)
{
- const Status s = NEQLSTMLayerNormalizationKernel::validate(&input_info, &output_info, &weight_info, &bias_info);
+ const Status s = NEQLSTMLayerNormalization::validate(&input_info, &output_info, &weight_info, &bias_info);
ARM_COMPUTE_EXPECT(!bool(s), framework::LogLevel::ERRORS);
}
// *INDENT-ON*
template <typename T>
-using NEQLSTMLayerNormalizationFixture = NEQLSTMLayerNormalizationValidationFixture<Tensor, Accessor, NEQLSTMLayerNormalizationKernel, T>;
+using NEQLSTMLayerNormalizationFixture = QLSTMLayerNormalizationValidationFixture<Tensor, Accessor, NEQLSTMLayerNormalization, T>;
TEST_SUITE(Quantized)
TEST_SUITE(QSYMM16)
#include "arm_compute/core/TensorShape.h"
#include "arm_compute/core/Types.h"
-#ifdef ARM_COMPUTE_CL
-#include "arm_compute/runtime/CL/CLScheduler.h"
-#endif /* ARM_COMPUTE_CL */
-#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
#include "tests/IAccessor.h"
}
}
- virtual void run_target(FunctionType &fn) = 0;
-
TensorType compute_target(const TensorShape &input_shape, const TensorShape &weight_shape, const TensorShape &bias_shape)
{
TensorType input = create_tensor<TensorType>(input_shape, _data_type, 1);
fn.configure(&input, &output, &weight, &bias);
allocate_tensors({ &input, &weight, &bias, &output });
fill(AccessorType(input), AccessorType(weight), AccessorType(bias));
-
- run_target(fn);
+ fn.run();
return output;
}
DataType _data_type{};
QuantizationInfo _qinfo{};
};
-
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class NEQLSTMLayerNormalizationValidationFixture : public QLSTMLayerNormalizationValidationFixture<TensorType, AccessorType, FunctionType, T>
-{
-protected:
- void run_target(FunctionType &fn) override
- {
- ThreadInfo tinfo;
- tinfo.cpu_info = &NEScheduler::get().cpu_info();
- fn.run(fn.window(), tinfo);
- }
-};
-
-#ifdef ARM_COMPUTE_CL
-template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class CLQLSTMLayerNormalizationValidationFixture : public QLSTMLayerNormalizationValidationFixture<TensorType, AccessorType, FunctionType, T>
-{
-protected:
- void run_target(FunctionType &fn) override
- {
- CLScheduler::get().default_init();
- fn.run(fn.window(), CLScheduler::get().queue());
- }
-};
-#endif /* ARM_COMPUTE_CL */
-
} // namespace validation
} // namespace test
} // namespace arm_compute