arm_compute v18.05
[platform/upstream/armcl.git] / tests / validation / GLES_COMPUTE / DirectConvolutionLayer.cpp
index eb3d307..2ff6678 100644 (file)
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -68,6 +68,13 @@ const auto data = combine(datasets::SmallDirectConvolutionShapes(),
                                                                  combine(framework::dataset::make("PadY", 0, 2),
                                                                          framework::dataset::make("KernelSize", { 3, 5 })))),
                                                   framework::dataset::make("NumKernels", { 1, 4, 8, 16 })))));
+/** Activation function Dataset*/
+const auto ActivationFunctionsDataset = framework::dataset::make("ActivationInfo",
+{
+    ActivationLayerInfo(),
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC),
+    ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+});
 } // namespace
 
 TEST_SUITE(GC)
@@ -78,7 +85,9 @@ using GCDirectConvolutionLayerFixture = DirectConvolutionValidationFixture<GCTen
 
 TEST_SUITE(Float)
 TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(Run, GCDirectConvolutionLayerFixture<half_float::half>, framework::DatasetMode::ALL, combine(data, framework::dataset::make("DataType", DataType::F16)))
+FIXTURE_DATA_TEST_CASE(Run, GCDirectConvolutionLayerFixture<half_float::half>, framework::DatasetMode::ALL, combine(combine(combine(data, framework::dataset::make("DataType", DataType::F16)),
+                                                                                                                    ActivationFunctionsDataset),
+                                                                                                                    framework::dataset::make("DataLayout", DataLayout::NCHW)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_fp16, tolerance_num);
@@ -86,7 +95,9 @@ FIXTURE_DATA_TEST_CASE(Run, GCDirectConvolutionLayerFixture<half_float::half>, f
 TEST_SUITE_END()
 
 TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(Run, GCDirectConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(data, framework::dataset::make("DataType", DataType::F32)))
+FIXTURE_DATA_TEST_CASE(Run, GCDirectConvolutionLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(data, framework::dataset::make("DataType", DataType::F32)),
+                                                                                                                 ActivationFunctionsDataset),
+                                                                                                         framework::dataset::make("DataLayout", DataLayout::NCHW)))
 {
     // Validate output
     validate(GCAccessor(_target), _reference, tolerance_fp32);