[layer] Verify ln, bn layers with fp16
authorskykongkong8 <ss.kong@samsung.com>
Tue, 22 Aug 2023 04:33:23 +0000 (13:33 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 22 Aug 2023 08:03:04 +0000 (17:03 +0900)
    - issue : adding cosine similarity check in fp32/fp16 revealed that there was unmatched cosine similarity Tensors in case of near-zero Tensors. Nevertheless, absolute value difference and mse pass our epsilon value. We would better to come back here for sanity check.
    - Same result for multi-headed attention layer as well. (Only for near-zero Tensors)
    - Added skip_cosine_similarity_check param to avoid this issue
    - Macro for enable-fp16 option

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/layers/bn_layer.cpp
nntrainer/layers/layer_normalization_layer.cpp
test/unittest/layers/layers_common_tests.h
test/unittest/layers/layers_golden_tests.cpp
test/unittest/layers/unittest_layers_batch_normalization.cpp
test/unittest/layers/unittest_layers_layer_normalization.cpp
test/unittest/layers/unittest_layers_multi_head_attention.cpp
test/unittest/layers/unittest_layers_positional_encoding.cpp

index 17d38c8..1723ac6 100644 (file)
@@ -67,13 +67,11 @@ void BatchNormalizationLayer::finalize(InitLayerContext &context) {
   auto &weight_decay = std::get<props::WeightDecay>(bn_props);
   auto &bias_decay = std::get<props::BiasDecay>(bn_props);
 
-  std::vector<TensorDim> output_dims(1);
-
   /** set output dimensions */
   auto const &in_dim = context.getInputDimensions()[0];
   context.setOutputDimensions(context.getInputDimensions());
 
-  TensorDim dim;
+  TensorDim dim(context.getFormat(), context.getWeightDataType());
 
   /// @note this logic cannot tell channel is actually 1 or it is just not used.
   auto &axis_prop = std::get<props::Axis>(bn_props);
index bd9127a..2046662 100644 (file)
@@ -73,7 +73,7 @@ void LayerNormalizationLayer::finalize(InitLayerContext &context) {
     std::unique(normalize_axes.begin(), normalize_axes.end()),
     normalize_axes.end());
 
-  TensorDim normalize_dim;
+  TensorDim normalize_dim(context.getFormat(), context.getWeightDataType());
   for (unsigned int axis : normalize_axes) {
     normalize_dim.setTensorDim(axis, input_dim.getTensorDim(axis));
   }
@@ -85,7 +85,7 @@ void LayerNormalizationLayer::finalize(InitLayerContext &context) {
     normalize_dim, beta_initializer, WeightRegularizer::NONE, 1.0f, bias_decay,
     "beta", true);
 
-  TensorDim remain_dim;
+  TensorDim remain_dim(context.getFormat(), context.getWeightDataType());
   std::vector<unsigned int> total_axes;
   total_axes.resize(ml::train::TensorDim::MAXDIM);
   std::iota(total_axes.begin(), total_axes.end(), 0u);
@@ -148,11 +148,11 @@ void LayerNormalizationLayer::forwarding(RunLayerContext &context,
   input.average(normalize_axes, temp_norm_size);
   input.subtract(temp_norm_size, deviation);
 
-  deviation.pow(2.0f, temp_full_size);
+  deviation.pow(2.0, temp_full_size);
   temp_full_size.average(normalize_axes, variance);
 
   variance.add_i(epsilon);
-  variance.pow(-0.5f, inv_std_dev);
+  variance.pow(-0.5, inv_std_dev);
 
   deviation.multiply(inv_std_dev, output);
   output.multiply_i(gamma);
@@ -161,7 +161,12 @@ void LayerNormalizationLayer::forwarding(RunLayerContext &context,
 
 void LayerNormalizationLayer::calcDerivative(RunLayerContext &context) {
   const bool trainable = context.getTrainable();
+
+  TensorDim::TensorType weight_tensor_type =
+    context.getWeight(wt_idx[LNParams::gamma]).getTensorType();
+
   Tensor empty;
+  empty.setTensorType(weight_tensor_type);
 
   Tensor &outgoing_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
   const Tensor &incoming_derivative =
@@ -200,7 +205,6 @@ void LayerNormalizationLayer::calcDerivative(RunLayerContext &context) {
 
 void LayerNormalizationLayer::calcGradient(RunLayerContext &context) {
   /** d_gamma is calculated in calcDerivative. d_beta is calculated here */
-
   const Tensor &incoming_derivative =
     context.getIncomingDerivative(SINGLE_INOUT_IDX);
   Tensor &d_beta = context.getWeightGrad(wt_idx[LNParams::beta]);
index abaf6a8..19e62cd 100644 (file)
@@ -99,6 +99,10 @@ typedef enum {
 
   DROPOUT_MATCH_60_PERCENT = 1 << 3, /**< set if only 60 percentage output
                                match is sufficient for dropout */
+
+  SKIP_COSINE_SIMILARITY =
+    1 << 4, /**< skip for zero error but large cos similarity case for now*/
+
   DEFAULT =
     0, /**< default set up, compare forward, backward in training mode */
 } LayerGoldenTestParamOptions;
@@ -167,6 +171,13 @@ public:
    * @return bool true if should skip calculating Gradient
    */
   bool shouldSkipCalcGrad();
+
+  /**
+   * @brief check if given test suite should skip cosine similarity check
+   *
+   * @return bool true if should skip cosine similarity check
+   */
+  bool shouldSkipCosineSimilarity();
 };
 
 #endif // __LAYERS_COMMON_TESTS_H__
index 139d933..4980cde 100644 (file)
@@ -55,6 +55,8 @@ createInitContext(Layer *layer, const std::string &input_shape_str,
   std::vector<shape_parser_> parsed;
   from_string(input_shape_str, parsed);
 
+  /// @todo tensor_type should not affect input layer data type since
+  /// technically a layer should not have information about its previous layer
   for (auto &par : parsed) {
     par.get().setFormat(
       str_converter<enum_class_prop_tag,
@@ -109,7 +111,7 @@ static TensorPacks prepareTensors(const InitLayerContext &context,
     vg.reserve(specs.size());
 
     for (auto &spec : specs) {
-      /// todo initializer should be depending is as well
+      /// @todo initializer should be depending is as well
       vg.emplace_back(spec.variable_spec.dim, Tensor::Initializer::NONE, true,
                       true, "golden");
     }
@@ -169,11 +171,12 @@ static RunLayerContext prepareRunContext(const TensorPacks &packs) {
 
 static void compareRunContext(RunLayerContext &rc, std::ifstream &file,
                               bool skip_grad, bool skip_deriv,
-                              bool dropout_match) {
+                              bool dropout_match, bool skip_cos_sim) {
   file.seekg(0, std::ios::beg);
 
   auto compare_percentage_tensors = [](const Tensor &t1, const Tensor &t2,
-                                       unsigned int match_percentage) -> bool {
+                                       unsigned int match_percentage,
+                                       bool skip_cos_sim) -> bool {
     if (t1.getDim() != t2.getDim())
       return false;
 
@@ -184,6 +187,17 @@ static void compareRunContext(RunLayerContext &rc, std::ifstream &file,
         t2.getDim().getDataType() == ml::train::TensorDim::DataType::FP32) {
 
       if (match_percentage == 100) {
+
+        if (!skip_cos_sim) {
+          auto tensor = t1.clone();
+          auto answer = t2.clone();
+          const float epsilon = 1e-6;
+
+          auto cos_sim = cosine_similarity<float>(
+            answer.getData<float>(), tensor.getData<float>(), tensor.size());
+          EXPECT_IN_RANGE(cos_sim, 1 - epsilon, 1 + epsilon);
+        }
+
         EXPECT_EQ(t1, t2);
         return t1 == t2;
       }
@@ -200,6 +214,7 @@ static void compareRunContext(RunLayerContext &rc, std::ifstream &file,
                                  (d1 != 0 && float_eq(d2, 0)),
                                1);
       }
+
       return (weak_match == total);
     } else if (t1.getDim().getDataType() ==
                  ml::train::TensorDim::DataType::FP16 &&
@@ -209,11 +224,18 @@ static void compareRunContext(RunLayerContext &rc, std::ifstream &file,
       for (unsigned int idx = 0; idx < total; idx++) {
         auto d1 = t1.getValue<_FP16>(idx);
         auto d2 = t2.getValue<_FP16>(idx);
-        auto float_eq = [](_FP16 a, _FP16 b) {
-          constexpr auto eps = 1e-2;
-          if (a < b)
-            std::swap(a, b);
-          return (a - b) < eps;
+        auto float_eq = [skip_cos_sim](_FP16 a, _FP16 b) {
+          if (skip_cos_sim) {
+            constexpr auto eps = 1e-1;
+            if (a < b)
+              std::swap(a, b);
+            return (a - b) < eps;
+          } else {
+            constexpr auto eps = 1e-2;
+            if (a < b)
+              std::swap(a, b);
+            return (a - b) < eps;
+          }
         };
         /** either both the values must be equal or 1 must be zero */
         weak_match += std::min(float_eq(d1, d2) + (float_eq(d1, 0) && d2 != 0) +
@@ -225,12 +247,14 @@ static void compareRunContext(RunLayerContext &rc, std::ifstream &file,
       auto tensor = t1.clone();
       auto answer = t2.clone();
 
-      auto cos_sim = cosine_similarity<_FP16>(
-        answer.getData<_FP16>(), tensor.getData<_FP16>(), tensor.size());
+      if (!skip_cos_sim) {
+        auto cos_sim = cosine_similarity<_FP16>(
+          answer.getData<_FP16>(), tensor.getData<_FP16>(), tensor.size());
+        EXPECT_IN_RANGE(cos_sim, 1 - epsilon, 1 + epsilon);
+      }
+
       auto mean_squared_error = mse<_FP16>(
         answer.getData<_FP16>(), answer.getData<_FP16>(), tensor.size());
-
-      EXPECT_IN_RANGE(cos_sim, 1 - epsilon, 1 + epsilon);
       EXPECT_IN_RANGE(mean_squared_error, 0, epsilon);
 
       return (weak_match == total);
@@ -243,7 +267,8 @@ static void compareRunContext(RunLayerContext &rc, std::ifstream &file,
 
   auto compare_tensors = [&file, compare_percentage_tensors](
                            unsigned length, auto tensor_getter, auto pred,
-                           bool skip_compare, const std::string &name,
+                           bool skip_compare, bool skip_cos_sim,
+                           const std::string &name,
                            unsigned int match_percentage = 100) {
     for (unsigned i = 0; i < length; ++i) {
       if (!pred(i)) {
@@ -256,7 +281,8 @@ static void compareRunContext(RunLayerContext &rc, std::ifstream &file,
       if (skip_compare) {
         continue;
       }
-      EXPECT_TRUE(compare_percentage_tensors(tensor, answer, match_percentage))
+      EXPECT_TRUE(compare_percentage_tensors(tensor, answer, match_percentage,
+                                             skip_cos_sim))
         << name << " at " << std::to_string(i);
     }
   };
@@ -274,22 +300,23 @@ static void compareRunContext(RunLayerContext &rc, std::ifstream &file,
 
   compare_tensors(rc.getNumWeights(),
                   [&rc](unsigned idx) { return rc.getWeight(idx); },
-                  always_read, skip_compare, "initial_weights");
+                  always_read, skip_compare, skip_cos_sim, "initial_weights");
   compare_tensors(rc.getNumInputs(),
                   [&rc](unsigned idx) { return rc.getInput(idx); }, always_read,
-                  !skip_compare, "inputs");
-  compare_tensors(rc.getNumOutputs(),
-                  [&rc](unsigned idx) { return rc.getOutput(idx); },
-                  always_read, !skip_compare, "outputs", match_percentage);
+                  !skip_compare, skip_cos_sim, "inputs");
+  compare_tensors(
+    rc.getNumOutputs(), [&rc](unsigned idx) { return rc.getOutput(idx); },
+    always_read, !skip_compare, skip_cos_sim, "outputs", match_percentage);
   compare_tensors(rc.getNumWeights(),
                   [&rc](unsigned idx) { return rc.getWeightGrad(idx); },
-                  only_read_trainable, skip_grad, "gradients");
+                  only_read_trainable, skip_grad, skip_cos_sim, "gradients");
   compare_tensors(rc.getNumWeights(),
                   [&rc](unsigned idx) { return rc.getWeight(idx); },
-                  always_read, !skip_compare, "weights");
+                  always_read, !skip_compare, skip_cos_sim, "weights");
   compare_tensors(rc.getNumInputs(),
                   [&rc](unsigned idx) { return rc.getOutgoingDerivative(idx); },
-                  always_read, skip_deriv, "derivatives", match_percentage);
+                  always_read, skip_deriv, skip_cos_sim, "derivatives",
+                  match_percentage);
 }
 
 LayerGoldenTest::~LayerGoldenTest() {}
@@ -318,6 +345,11 @@ bool LayerGoldenTest::shouldSkipCalcGrad() {
          LayerGoldenTestParamOptions::SKIP_CALC_GRAD;
 }
 
+bool LayerGoldenTest::shouldSkipCosineSimilarity() {
+  return std::get<int>(GetParam()) &
+         LayerGoldenTestParamOptions::SKIP_COSINE_SIMILARITY;
+}
+
 TEST_P(LayerGoldenTest, run) {
   auto f = std::get<0>(GetParam());
   auto layer = f(std::get<1>(GetParam()));
@@ -337,6 +369,7 @@ TEST_P(LayerGoldenTest, run) {
   bool skip_calc_grad = shouldSkipCalcGrad();
   bool skip_calc_deriv = shouldSkipCalcDeriv();
   bool dropout_compare_60_percent = shouldMatchDropout60Percent();
+  bool skip_cos_sim = shouldSkipCosineSimilarity();
 
   for (int i = 0; i < 4; ++i) {
     /// warm layer multiple times
@@ -352,7 +385,7 @@ TEST_P(LayerGoldenTest, run) {
   }
 
   compareRunContext(rc, golden_file, skip_calc_grad, skip_calc_deriv,
-                    dropout_compare_60_percent);
+                    dropout_compare_60_percent, skip_cos_sim);
 
   EXPECT_TRUE(true); // stub test for tcm
 }
index 4922bec..6920fea 100644 (file)
@@ -28,10 +28,11 @@ auto bn_inference_option = LayerGoldenTestParamOptions::SKIP_CALC_GRAD |
                            LayerGoldenTestParamOptions::SKIP_CALC_DERIV |
                            LayerGoldenTestParamOptions::FORWARD_MODE_INFERENCE;
 
+auto bn_option = LayerGoldenTestParamOptions::SKIP_COSINE_SIMILARITY;
+
 auto bn_basic_channels_training = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:4:2:3",
-  "bn_channels_training.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32", "fp32");
+  "bn_channels_training.nnlayergolden", bn_option, "nchw", "fp32", "fp32");
 
 auto bn_basic_channels_inference = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:4:2:3",
@@ -53,3 +54,31 @@ GTEST_PARAMETER_TEST(BatchNormalization, LayerGoldenTest,
                                        bn_basic_channels_inference,
                                        bn_basic_width_training,
                                        bn_basic_width_inference));
+
+#ifdef ENABLE_FP16
+auto bn_basic_channels_training_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:4:2:3",
+  "bn_channels_training_fp16fp16.nnlayergolden", bn_option, "nchw", "fp16",
+  "fp16");
+
+auto bn_basic_channels_inference_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:4:2:3",
+  "bn_channels_inference_fp16fp16.nnlayergolden", bn_inference_option, "nchw",
+  "fp16", "fp16");
+
+auto bn_basic_width_training_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:1:1:10",
+  "bn_width_training_fp16fp16.nnlayergolden", bn_option, "nchw", "fp16",
+  "fp16");
+
+auto bn_basic_width_inference_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:1:1:10",
+  "bn_width_inference_fp16fp16.nnlayergolden", bn_inference_option, "nchw",
+  "fp16", "fp16");
+
+GTEST_PARAMETER_TEST(BatchNormalization16, LayerGoldenTest,
+                     ::testing::Values(bn_basic_channels_training_fp16fp16,
+                                       bn_basic_channels_inference_fp16fp16,
+                                       bn_basic_width_training_fp16fp16,
+                                       bn_basic_width_inference_fp16fp16));
+#endif
index c2f537d..a653239 100644 (file)
@@ -21,45 +21,84 @@ auto semantic_layer_normalization = LayerSemanticsParamType(
   nntrainer::LayerNormalizationLayer::type, {"axis=1"},
   LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1);
 
+auto ln_option = LayerGoldenTestParamOptions::SKIP_COSINE_SIMILARITY;
+
 GTEST_PARAMETER_TEST(LayerNormalization, LayerSemantics,
                      ::testing::Values(semantic_layer_normalization));
 
 auto ln_axis_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1"},
-  "2:4:2:3", "ln_axis_1.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32", "fp32");
+  "2:4:2:3", "ln_axis_1.nnlayergolden", ln_option, "nchw", "fp32", "fp32");
 
 auto ln_axis_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=2"},
-  "2:4:2:3", "ln_axis_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32", "fp32");
+  "2:4:2:3", "ln_axis_2.nnlayergolden", ln_option, "nchw", "fp32", "fp32");
 
 auto ln_axis_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=3"},
-  "2:4:2:3", "ln_axis_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32", "fp32");
+  "2:4:2:3", "ln_axis_3.nnlayergolden", ln_option, "nchw", "fp32", "fp32");
 
 auto ln_axis_1_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 2"},
-  "2:4:2:3", "ln_axis_1_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32", "fp32");
+  "2:4:2:3", "ln_axis_1_2.nnlayergolden", ln_option, "nchw", "fp32", "fp32");
 
 auto ln_axis_2_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=2, 3"},
-  "2:4:2:3", "ln_axis_2_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32", "fp32");
+  "2:4:2:3", "ln_axis_2_3.nnlayergolden", ln_option, "nchw", "fp32", "fp32");
 
 auto ln_axis_1_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 3"},
-  "2:4:2:3", "ln_axis_1_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32", "fp32");
+  "2:4:2:3", "ln_axis_1_3.nnlayergolden", ln_option, "nchw", "fp32", "fp32");
 
 auto ln_axis_1_2_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 2, 3"},
-  "2:4:2:3", "ln_axis_1_2_3.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
+  "2:4:2:3", "ln_axis_1_2_3.nnlayergolden", ln_option, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(LayerNormalization, LayerGoldenTest,
                      ::testing::Values(ln_axis_1, ln_axis_2, ln_axis_3,
                                        ln_axis_1_2, ln_axis_2_3, ln_axis_1_3,
                                        ln_axis_1_2_3));
+
+#ifdef ENABLE_FP16
+auto ln_axis_1_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1"},
+  "2:4:2:3", "ln_axis_1_fp16fp16.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16");
+
+auto ln_axis_2_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=2"},
+  "2:4:2:3", "ln_axis_2_fp16fp16.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16");
+
+auto ln_axis_3_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=3"},
+  "2:4:2:3", "ln_axis_3_fp16fp16.nnlayergolden", ln_option, "nchw", "fp16",
+  "fp16");
+
+auto ln_axis_1_2_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 2"},
+  "2:4:2:3", "ln_axis_1_2_fp16fp16.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16");
+
+auto ln_axis_2_3_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=2, 3"},
+  "2:4:2:3", "ln_axis_2_3_fp16fp16.nnlayergolden", ln_option, "nchw", "fp16",
+  "fp16");
+
+auto ln_axis_1_3_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 3"},
+  "2:4:2:3", "ln_axis_1_3_fp16fp16.nnlayergolden", ln_option, "nchw", "fp16",
+  "fp16");
+
+auto ln_axis_1_2_3_fp16fp16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 2, 3"},
+  "2:4:2:3", "ln_axis_1_2_3_fp16fp16.nnlayergolden", ln_option, "nchw", "fp16",
+  "fp16");
+
+GTEST_PARAMETER_TEST(LayerNormalization16, LayerGoldenTest,
+                     ::testing::Values(ln_axis_1_fp16fp16, ln_axis_2_fp16fp16,
+                                       ln_axis_3_fp16fp16, ln_axis_1_2_fp16fp16,
+                                       ln_axis_2_3_fp16fp16,
+                                       ln_axis_1_3_fp16fp16,
+                                       ln_axis_1_2_3_fp16fp16));
+#endif
index 1dfb335..acaf31d 100644 (file)
@@ -33,17 +33,19 @@ GTEST_PARAMETER_TEST(
   ::testing::Values(semantic_multi_head_attention,
                     semantic_multi_head_attention_with_mask));
 
+auto no_cos_sim_option = LayerGoldenTestParamOptions::SKIP_COSINE_SIMILARITY;
+
 auto multi_head_attention_single_batch = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3"}, "1:1:5:7,1:1:3:7,1:1:3:7",
-  "multi_head_attention_single_batch.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
+  "multi_head_attention_single_batch.nnlayergolden", no_cos_sim_option, "nchw",
+  "fp32", "fp32");
 
 auto multi_head_attention = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3"}, "2:1:5:7,2:1:3:7,2:1:3:7",
-  "multi_head_attention.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32", "fp32");
+  "multi_head_attention.nnlayergolden", no_cos_sim_option, "nchw", "fp32",
+  "fp32");
 
 auto multi_head_attention_return_attention_scores = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
@@ -51,19 +53,19 @@ auto multi_head_attention_return_attention_scores = LayerGoldenTestParamType(
    "average_attention_weight=false"},
   "2:1:5:7,2:1:3:7,2:1:3:7",
   "multi_head_attention_return_attention_scores.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
+  no_cos_sim_option, "nchw", "fp32", "fp32");
 
 auto multi_head_attention_value_dim = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3", "projected_value_dim=5"},
   "2:1:5:7,2:1:3:7,2:1:3:7", "multi_head_attention_value_dim.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
+  no_cos_sim_option, "nchw", "fp32", "fp32");
 
 auto multi_head_attention_output_shape = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3", "output_shape=5"},
   "2:1:5:7,2:1:3:7,2:1:3:7", "multi_head_attention_output_shape.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
+  no_cos_sim_option, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(
   MultiHeadAttention, LayerGoldenTest,
index 039b0c3..fa14897 100644 (file)
@@ -39,6 +39,7 @@ INSTANTIATE_TEST_CASE_P(PositionalEncoding, LayerGoldenTest,
                         ::testing::Values(positional_encoding_partial,
                                           positional_encoding));
 
+#ifdef ENABLE_FP16
 auto positional_encoding_partial_fp16fp16 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::PositionalEncodingLayer>,
   {"max_timestep=10"}, "3:1:7:6",
@@ -50,6 +51,7 @@ auto positional_encoding_fp16fp16 = LayerGoldenTestParamType(
   {"max_timestep=10"}, "3:1:10:6", "positional_encoding_fp16fp16.nnlayergolden",
   LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp16", "fp16");
 
-INSTANTIATE_TEST_CASE_P(PositionalEncoding16, LayerGoldenTest,
-                        ::testing::Values(positional_encoding_partial_fp16fp16,
-                                          positional_encoding_fp16fp16));
+GTEST_PARAMETER_TEST(PositionalEncoding16, LayerGoldenTest,
+                     ::testing::Values(positional_encoding_partial_fp16fp16,
+                                       positional_encoding_fp16fp16));
+#endif