[ Mixed Tensor ] add tensor type property in initContext
author    jijoong.moon <jijoong.moon@samsung.com>
Thu, 27 Jul 2023 00:14:57 +0000 (09:14 +0900)
committer Jijoong Moon <jijoong.moon@samsung.com>
Mon, 21 Aug 2023 06:29:23 +0000 (15:29 +0900)
This PR adds the tensor type (Format, Weight Tensor DataType,
Activation Tensor DataType) to initContext.
- Remove the tensor type variables and their setter/getter member
functions in layer, layer_devel, loss layer, etc.
- Add a tensor type setter in initContext.
- Set the var_grad (input & output) Tensor Type according to the model
Tensor Data Type.
- Add ModelTensorDataTypeInfo, e.g. FP16_FP16 (Weight FP16, Activation
FP16); see the usage sketch below.
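For reference, a minimal usage sketch of the new property (hypothetical
snippet assuming the usual ccapi model creation flow; not part of this
commit):

  #include <model.h>

  // Select FP16 weights and FP16 activations for the whole model via the
  // new "model_tensor_type" key (added in model_common_properties.h below).
  auto model = ml::train::createModel(ml::train::ModelType::NEURAL_NET);
  model->setProperty({"model_tensor_type=FP16_FP16", "tensor_format=NCHW"});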

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
33 files changed:
api/ccapi/include/layer.h
nntrainer/graph/network_graph.cpp
nntrainer/graph/network_graph.h
nntrainer/layers/fc_layer.cpp
nntrainer/layers/layer_context.cpp
nntrainer/layers/layer_context.h
nntrainer/layers/layer_devel.h
nntrainer/layers/layer_node.cpp
nntrainer/layers/layer_node.h
nntrainer/layers/loss/loss_layer.h
nntrainer/models/model_common_properties.h
nntrainer/models/neuralnet.cpp
nntrainer/models/neuralnet.h
nntrainer/tensor/manager.h
nntrainer/utils/base_properties.h
test/unittest/layers/layers_common_tests.h
test/unittest/layers/layers_golden_tests.cpp
test/unittest/layers/unittest_layers_attention.cpp
test/unittest/layers/unittest_layers_batch_normalization.cpp
test/unittest/layers/unittest_layers_concat.cpp
test/unittest/layers/unittest_layers_convolution1d.cpp
test/unittest/layers/unittest_layers_convolution2d.cpp
test/unittest/layers/unittest_layers_dropout.cpp
test/unittest/layers/unittest_layers_fully_connected.cpp
test/unittest/layers/unittest_layers_gru.cpp
test/unittest/layers/unittest_layers_grucell.cpp
test/unittest/layers/unittest_layers_layer_normalization.cpp
test/unittest/layers/unittest_layers_lstm.cpp
test/unittest/layers/unittest_layers_lstmcell.cpp
test/unittest/layers/unittest_layers_multi_head_attention.cpp
test/unittest/layers/unittest_layers_positional_encoding.cpp
test/unittest/layers/unittest_layers_rnn.cpp
test/unittest/layers/unittest_layers_rnncell.cpp

index 9a6952d..9690b74 100644 (file)
@@ -168,13 +168,6 @@ public:
   virtual void setProperty(const std::vector<std::string> &values) = 0;
 
   /**
-   * @brief     Set Tensor format & data type
-   * @note      This is used mainly for the unittest case which does not have
-   * model.
-   */
-  virtual void setTensorType(std::array<const std::string, 2> type){};
-
-  /**
    * @brief     Get name of the layer
    * @retval    name of the layer
    * @note      This name is unique to this layer in a model
index b316525..717e382 100644 (file)
@@ -81,13 +81,6 @@ int NetworkGraph::compile(const std::string &loss_type) {
 
   inPlaceOptimize();
 
-  for (auto iter = cbegin(); iter != cend(); iter++) {
-    auto lnode = (*iter);
-    /// @todo  later, we can set layer tensor type differenctly with model
-    /// tensor type
-    lnode->setTensorType(getTensorType());
-  }
-
   status = checkCompiledGraph();
   NN_RETURN_STATUS();
 
@@ -719,7 +712,7 @@ NetworkGraph::finalizeContext(const std::shared_ptr<LayerNode> &lnode,
                  [](const Var_Grad *vg) { return vg->getDim(); });
 
   /** finalize the layer and get the final context */
-  auto init_context = lnode->finalize(input_dims);
+  auto init_context = lnode->finalize(input_dims, getTensorType());
 
   /**
    * Request manager for either a pre-allocated output as input or a newly
index 34292e5..39c7963 100644 (file)
@@ -50,7 +50,7 @@ public:
     optimize_memory(true),
     exec_mode(ExecutionMode::TRAIN),
     tensor_format("NCHW"),
-    tensor_dtype("FP32") {}
+    tensor_dtype(split("FP32_FP32", std::regex("\\_"))) {}
 
   /**
    * @brief     Constructor of NeuralNetwork Graph Class
@@ -60,7 +60,7 @@ public:
   NetworkGraph(bool enable_swap, const std::string &swap_path = "",
                unsigned int lookahead = 0,
                const std::string &tensor_format_ = "NCHW",
-               const std::string &tensor_dtype_ = "FP32") :
+               const std::string &tensor_dtype_ = "FP32_FP32") :
     tensor_manager(std::make_shared<Manager>(enable_swap, swap_path, lookahead,
                                              tensor_format_, tensor_dtype_)),
     graph(),
@@ -72,7 +72,7 @@ public:
     optimize_memory(true),
     exec_mode(ExecutionMode::TRAIN),
     tensor_format(tensor_format_),
-    tensor_dtype(tensor_dtype_) {}
+    tensor_dtype(split(tensor_dtype_, std::regex("\\_"))) {}
 
   /**
    * @brief   Destructor of the NeuralNetwork Graph class
@@ -376,8 +376,9 @@ public:
    *
    * @return TensorDim::Format NCHW or NHWC
    */
-  std::array<const std::string, 2> getTensorType() {
-    return {tensor_format, tensor_dtype};
+  std::array<const std::string, 3> getTensorType() {
+
+    return {tensor_format, tensor_dtype[0], tensor_dtype[1]};
   };
 
   /**
@@ -435,7 +436,7 @@ private:
 
   std::string tensor_format; /**< Model Tensor Format: NCHW or NHWC */
 
-  std::string tensor_dtype; /**< Model Tensor Type: FP32, FP16 */
+  std::vector<std::string> tensor_dtype; /**< Model Tensor Type: FP32, FP16 */
 
   std::unordered_map<std::string, int>
     profile_keys; /**< profile keys based on the layer type */
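Note: tensor_dtype is now split on "_" into a weight entry and an
activation entry. A short reading aid (illustrative only; "graph" is a
placeholder NetworkGraph instance):

  // tensor_dtype = split("FP16_FP32", std::regex("\\_"))
  //   -> tensor_dtype[0] == "FP16"  (Weight Tensor DataType)
  //   -> tensor_dtype[1] == "FP32"  (Activation Tensor DataType)
  // getTensorType() then packs {format, weight dtype, activation dtype}:
  std::array<const std::string, 3> type = graph.getTensorType();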
index 81c9cdc..61ba0e4 100644 (file)
@@ -78,12 +78,12 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
   // global configuration
   TensorDim bias_dim(
     1, 1, 1, unit,
-    TensorDim::TensorType(getTensorFormat(), TensorDim::DataType::FP32),
+    TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
     0b0001);
 
   TensorDim weight_dim(
     1, 1, in_dim.width(), unit,
-    TensorDim::TensorType(getTensorFormat(), TensorDim::DataType::FP32),
+    TensorDim::TensorType(context.getFormat(), context.getWeightDataType()),
     0b0011);
 
   weight_idx[FCParams::weight] = context.requestWeight(
index 8ec36b8..6d55d66 100644 (file)
@@ -37,18 +37,18 @@ static void suffixSpec(VarGradSpecV2 &spec, unsigned int idx) {
   }
 }
 
-InitLayerContext::InitLayerContext(const std::vector<TensorDim> &dim,
-                                   const std::vector<bool> &req_out_connected,
-                                   bool in_place_, const std::string &n,
-                                   const std::string &prefix_,
-                                   const float max_norm) :
+InitLayerContext::InitLayerContext(
+  const std::vector<TensorDim> &dim, const std::vector<bool> &req_out_connected,
+  bool in_place_, const std::string &n, const std::string &prefix_,
+  const float max_norm, std::array<const std::string, 3> tensor_type_) :
   input_dim(dim),
   in_place(in_place_),
   clip_by_global_norm(max_norm),
   output_specs(),
   req_out_is_connected(req_out_connected),
   name(n),
-  prefix(prefix_) {
+  prefix(prefix_),
+  tensor_type(tensor_type_) {
   NNTR_THROW_IF(!validate(), std::invalid_argument)
     << "Invalid init context name: " << name
     << " num inputs: " << getNumInputs();
@@ -67,6 +67,21 @@ void InitLayerContext::setOutputDimensions(
 
   for (unsigned i = 0u, sz = out_dim.size(); i < sz; ++i) {
     auto spec = outSpec(out_dim.at(i));
+
+    spec.variable_spec.dim.setFormat(
+      str_converter<enum_class_prop_tag,
+                    nntrainer::TensorFormatInfo>::from_string(tensor_type[0]));
+    spec.variable_spec.dim.setDataType(
+      str_converter<enum_class_prop_tag, nntrainer::TensorDataTypeInfo>::
+        from_string(tensor_type[2]));
+
+    spec.gradient_spec->dim.setFormat(
+      str_converter<enum_class_prop_tag,
+                    nntrainer::TensorFormatInfo>::from_string(tensor_type[0]));
+    spec.gradient_spec->dim.setDataType(
+      str_converter<enum_class_prop_tag, nntrainer::TensorDataTypeInfo>::
+        from_string(tensor_type[2]));
+
     specs.push_back(std::move(spec));
   }
 
@@ -499,8 +514,18 @@ bool RunLayerContext::validate(bool skip_input, bool skip_label) {
   if (tensor_map.empty() || !tensor_map[inputs[0]->getName()]) {
     auto filler = [this](const auto &vec) {
       for (auto const &val : vec) {
-        tensor_map[val->getName()] = val->getVariableRef().getData<float>();
-        tensor_map[val->getGradientName()] = val->getGradientRef().getData<float>();
+        if (val->getVariableRef().getTensorType().data_type ==
+            TensorDim::DataType::FP32) {
+          tensor_map[val->getName()] = val->getVariableRef().getData<float>();
+          tensor_map[val->getGradientName()] =
+            val->getGradientRef().getData<float>();
+        } else if (val->getVariableRef().getTensorType().data_type ==
+                   TensorDim::DataType::FP16) {
+          tensor_map[val->getName()] =
+            val->getVariableRef().getData<_Float16>();
+          tensor_map[val->getGradientName()] =
+            val->getGradientRef().getData<_Float16>();
+        }
       }
     };
 
index 3d8f617..738b385 100644 (file)
@@ -51,7 +51,38 @@ public:
   InitLayerContext(const std::vector<TensorDim> &dim,
                    const std::vector<bool> &req_out_connected, bool in_place_,
                    const std::string &n = "", const std::string &prefix_ = "",
-                   const float max_norm = 0.0);
+                   const float max_norm = 0.0,
+                   std::array<const std::string, 3> tensor_type_ = {
+                     "NCHW", "FP32", "FP32"});
+  /**
+   * @brief   get Tensor Format of Layer
+   *
+   * @return Tensor Format of the layer
+   */
+  TensorDim::Format getFormat() {
+    return str_converter<enum_class_prop_tag, nntrainer::TensorFormatInfo>::
+      from_string(tensor_type[0]);
+  };
+
+  /**
+   * @brief   get Tensor DataType of the Weight
+   *
+   * @return Tensor DataType of the Weight
+   */
+  TensorDim::DataType getWeightDataType() {
+    return str_converter<enum_class_prop_tag, nntrainer::TensorDataTypeInfo>::
+      from_string(tensor_type[1]);
+  };
+
+  /**
+   * @brief   get Tensor DataType of the Activation
+   *
+   * @return Tensor DataType of the Activation
+   */
+  TensorDim::DataType getActivationDataType() {
+    return str_converter<enum_class_prop_tag, nntrainer::TensorDataTypeInfo>::
+      from_string(tensor_type[2]);
+  };
 
   /**
    * @brief   get name by the layer
@@ -298,6 +329,7 @@ private:
   /**< a bool vector to tell if requested out is actually connected to others */
   std::string name;   /**< name of the layer */
   std::string prefix; /**< prefix of the layer */
+  std::array<const std::string, 3> tensor_type; /**< {format, weight dtype, activation dtype} */
 };
 
 /**
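A minimal sketch of the extended constructor and the new getters
(hypothetical values; the createInitContext helper in
layers_golden_tests.cpp below passes the type the same way):

  // Build an init context carrying {format, weight dtype, activation dtype}.
  InitLayerContext ctx({TensorDim(3, 1, 1, 10)}, {true}, /*in_place_=*/false,
                       "fc0", "", 0.0, {"NCHW", "FP16", "FP16"});
  TensorDim::Format fmt = ctx.getFormat();                 // Format::NCHW
  TensorDim::DataType w_ty = ctx.getWeightDataType();      // DataType::FP16
  TensorDim::DataType a_ty = ctx.getActivationDataType();  // DataType::FP16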
index 7dba0e7..7e3c6af 100644 (file)
@@ -241,65 +241,6 @@ public:
    * @return true if supports backwarding, else false
    */
   virtual bool supportBackwarding() const = 0;
-
-  /**
-   * @brief Set the Tensor format for the layer
-   * @param     Tensor format : TensorDim::Format::NCHW or
-   * TneosrDim::Format::NHWC
-   */
-  virtual void setTensorFormat(
-    ml::train::TensorDim::Format form = ml::train::TensorDim::Format::NCHW) {
-    tensor_format = form;
-  }
-
-  /**
-   * @brief Set the Tensor Type for the layer
-   * @param     Tensor Type : FP32, FP16
-   */
-
-  virtual void setTensorDataType(
-    ml::train::TensorDim::DataType ty = ml::train::TensorDim::DataType::FP32) {
-    tensor_dtype = ty;
-  }
-
-  /**
-   * @brief set the Tensor Type for the layer
-   * @param     Tensor Type : NCHW or NHWC
-   */
-  void setTensorType(std::array<const std::string, 2> t_type) {
-    if (t_type[0].compare("NCHW") == 0 || t_type[0].compare("nchw") == 0) {
-      tensor_format = ml::train::TensorDim::Format::NCHW;
-    } else {
-      tensor_format = ml::train::TensorDim::Format::NHWC;
-    }
-
-    nntrainer::props::TensorDataType type_;
-
-    from_string(t_type[1], type_);
-
-    tensor_dtype = type_;
-  }
-
-  /**
-   * @brief get the Tensor Format for the layer
-   * @return     Tensor Format : TensorDim::Format::NCHW or
-   * TneosrDim::Format::NHWC
-   */
-  virtual ml::train::TensorDim::Format getTensorFormat() {
-    return tensor_format;
-  }
-
-  /**
-   * @brief get the Tensor Type for the layer
-   * @return     Tensor Type : FP16, Fp32
-   */
-  virtual ml::train::TensorDim::DataType getTensorDataType() {
-    return tensor_dtype;
-  }
-
-private:
-  ml::train::TensorDim::Format tensor_format;
-  ml::train::TensorDim::DataType tensor_dtype;
 };
 
 /// @todo Decide where to put and how to implement(#986)
index 65e1586..776430b 100644 (file)
@@ -248,14 +248,6 @@ void LayerNode::setOutputConnection(unsigned nth, const std::string &name,
   con = std::make_unique<Connection>(name, index);
 }
 
-void LayerNode::setTensorType(const std::string form_, const std::string ty_) {
-  setTensorType({form_, ty_});
-}
-
-void LayerNode::setTensorType(std::array<const std::string, 2> t_type) {
-  getLayer()->setTensorType(t_type);
-}
-
 const std::string LayerNode::getName() const noexcept {
   auto &name = std::get<props::Name>(*layer_node_props);
   return name.empty() ? "" : name.get();
@@ -505,7 +497,9 @@
 /**
  * @brief     Finalize creating the layer node
  */
-InitLayerContext LayerNode::finalize(const std::vector<TensorDim> &input_dims) {
+InitLayerContext
+LayerNode::finalize(const std::vector<TensorDim> &input_dims,
+                    std::array<const std::string, 3> tensor_type) {
   if (run_context)
     throw std::runtime_error(
       "Trying to finalizing a layer which is already finalized in layer: " +
@@ -526,6 +524,14 @@ InitLayerContext LayerNode::finalize(const std::vector<TensorDim> &input_dims) {
       NNTR_THROW_IF(input_dims != actual_prop_dims, std::invalid_argument)
         << "calculated input dimension is different from given input_shape "
            "property";
+      for (auto &d : actual_prop_dims) {
+        d.setDataType(
+          str_converter<enum_class_prop_tag, nntrainer::TensorDataTypeInfo>::
+            from_string(tensor_type[2]));
+        d.setFormat(
+          str_converter<enum_class_prop_tag, nntrainer::TensorFormatInfo>::
+            from_string(tensor_type[0]));
+      }
     }
   } else {
     NNTR_THROW_IF(!hasInputShapeProperty(), std::invalid_argument)
@@ -540,6 +546,15 @@ InitLayerContext LayerNode::finalize(const std::vector<TensorDim> &input_dims) {
       << prop_dims.size();
     actual_input_dims =
       std::vector<TensorDim>(prop_dims.begin(), prop_dims.end());
+    for (auto &d : actual_input_dims) {
+      /// Input Tensor type of input layer needs to be float.
+      d.setDataType(
+        str_converter<enum_class_prop_tag,
+                      nntrainer::TensorDataTypeInfo>::from_string("FP32"));
+      d.setFormat(
+        str_converter<enum_class_prop_tag, nntrainer::TensorFormatInfo>::
+          from_string(tensor_type[0]));
+    }
   }
 
   NNTR_THROW_IF(actual_input_dims.size() < getNumInputConnections(),
@@ -576,7 +591,7 @@ InitLayerContext LayerNode::finalize(const std::vector<TensorDim> &input_dims) {
 
   auto context = InitLayerContext(actual_input_dims, out_info,
                                   executeInPlace() != InPlace::NONE, getName(),
-                                  scope, max_norm);
+                                  scope, max_norm, tensor_type);
 
   layer->finalize(context);
 
index 27c6b37..957f85e 100644 (file)
@@ -243,7 +243,9 @@ public:
    * will be made available during execution of the layer with the context.
    * @note configureRunContext() is expected to called right after this.
    */
-  InitLayerContext finalize(const std::vector<TensorDim> &input_dims = {});
+  InitLayerContext finalize(const std::vector<TensorDim> &input_dims = {},
+                            std::array<const std::string, 3> tensor_type = {
+                              "NCHW", "FP32", "FP32"});
 
   /**
    * @brief     Forward Propagation of a layer
@@ -801,23 +803,6 @@ public:
    */
   bool needsCalcGradient() { return needs_calc_gradient; }
 
-  /**
-   * @brief Set Tensor type for layer
-   *
-   * @param format NCHW : NHWC
-   * @param type FP16, FP32
-   */
-  using Layer::setTensorType;
-  void setTensorType(const std::string form_ = "NCHW",
-                     const std::string type_ = "FP32");
-  /**
-   * @brief Set Tensor type for layer
-   *
-   * @param format NCHW : NHWC
-   * @param type FP16, FP32
-   */
-  void setTensorType(std::array<const std::string, 2> t_type);
-
 private:
   /**
    * @brief     Get the Input Layers object
@@ -845,10 +830,6 @@ private:
   std::vector<std::unique_ptr<Connection>>
     output_connections; /**< output layer names */
 
-  TensorDim::Format tensor_format;
-
-  TensorDim::DataType tensor_dtype;
-
 #ifdef ENABLE_TEST
   /**
    * @brief   Init context which is stored for debugging issue
index 0307f5c..00b520f 100644 (file)
@@ -52,28 +52,6 @@ public:
    */
   bool requireLabel() const override { return true; }
 
-  /**
-   * @brief set the Tensor Type for the layer
-   * @param     Tensor Type : NCHW or NHWC
-   */
-  void setTensorType(std::array<const std::string, 2> t_type) {
-    if (t_type[0].compare("NCHW") == 0 || t_type[0].compare("nchw") == 0) {
-      tensor_format = ml::train::TensorDim::Format::NCHW;
-    } else {
-      tensor_format = ml::train::TensorDim::Format::NHWC;
-    }
-
-    nntrainer::props::TensorDataType type_;
-
-    from_string(t_type[1], type_);
-
-    tensor_dtype = type_;
-  }
-
-private:
-  ml::train::TensorDim::Format tensor_format;
-  ml::train::TensorDim::DataType tensor_dtype;
-
 protected:
   /**
    * @brief     update loss
index d3f7e27..aba254d 100644 (file)
@@ -179,6 +179,31 @@ public:
   MemorySwapLookahead(const unsigned int &value = 0);
 };
 
+/**
+ * @brief     Enumeration of Data Type for model & layer
+ */
+struct ModelTensorDataTypeInfo {
+  enum Enum { W16A16, W16A32, W32A16, W32A32 };
+  static constexpr std::initializer_list<Enum> EnumList = {
+    Enum::W16A16, Enum::W16A32, Enum::W32A16, Enum::W32A32};
+
+  static constexpr const char *EnumStr[] = {"FP16_FP16", "FP16_FP32",
+                                            "FP32_FP16", "FP32_FP32"};
+};
+
+/**
+ * @brief Model Tensor Data Type Enumeration Information
+ *
+ */
+class ModelTensorDataType final : public EnumProperty<ModelTensorDataTypeInfo> {
+public:
+  using prop_tag = enum_class_prop_tag;
+  static constexpr const char *key = "model_tensor_type";
+  ModelTensorDataType(ModelTensorDataTypeInfo::Enum value =
+                        ModelTensorDataTypeInfo::Enum::W32A32) {
+    set(value);
+  };
+};
 
 } // namespace nntrainer::props
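The string <-> enum mapping goes through the usual property conversion
helpers (illustrative sketch):

  nntrainer::props::ModelTensorDataType dtype;  // defaults to Enum::W32A32
  from_string("FP16_FP32", dtype);              // -> Enum::W16A32
  // to_string(dtype) == "FP16_FP32" (Weight FP16, Activation FP32)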
 
index 5cdafe2..e9f315c 100644 (file)
@@ -70,7 +70,7 @@ NeuralNetwork::NeuralNetwork() :
     props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
     props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
     props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead(),
-    props::TensorFormat(), props::TensorDataType()),
+    props::TensorFormat(), props::ModelTensorDataType()),
   load_path(std::string()),
   epoch_idx(0),
   iter(0),
@@ -88,7 +88,7 @@ NeuralNetwork::NeuralNetwork(AppContext app_context_) :
     props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
     props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
     props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead(),
-    props::TensorFormat(), props::TensorDataType()),
+    props::TensorFormat(), props::ModelTensorDataType()),
   load_path(std::string()),
   epoch_idx(0),
   iter(0),
@@ -174,10 +174,10 @@ int NeuralNetwork::compile() {
     std::get<props::MemorySwapLookahead>(model_flex_props);
 
   const std::string tensor_format =
-    std::get<props::TensorFormat>(model_flex_props);
+    to_string(std::get<props::TensorFormat>(model_flex_props));
 
   const std::string tensor_type =
-    to_string(std::get<props::TensorDataType>(model_flex_props));
+    to_string(std::get<props::ModelTensorDataType>(model_flex_props));
 
   model_graph = NetworkGraph(memory_swap, memory_swap_path, lookahead,
                              tensor_format, tensor_type);
index 5f509d1..035e675 100644 (file)
@@ -557,7 +557,7 @@ private:
                props::ContinueTrain, props::SaveBestPath,
                props::MemoryOptimization, props::MemorySwap,
                props::MemorySwapPath, props::MemorySwapLookahead,
-               props::TensorFormat, props::TensorDataType>;
+               props::TensorFormat, props::ModelTensorDataType>;
   using RigidPropTypes =
     std::tuple<props::LossType, std::vector<props::InputConnection>,
                std::vector<props::LabelLayer>, props::ClipGradByGlobalNorm>;
index 5c515f7..574a3ce 100644 (file)
@@ -133,21 +133,21 @@ public:
   Manager() :
     enable_optimizations(true),
     swap_lookahead(0),
-    tensor_format("nchw"),
-    tensor_dtype("fp32") {}
+    tensor_format("NCHW"),
+    tensor_dtype(split("FP32_FP32", std::regex("\\_"))) {}
 
   /**
    * @brief     Constructor of Manager
    */
   Manager(bool enable_swap, const std::string &swap_path = "",
-          unsigned int lookahead = 0, const std::string tensor_format_ = "nchw",
-          const std::string tensor_dtype_ = "fp32") :
+          unsigned int lookahead = 0, const std::string tensor_format_ = "NCHW",
+          const std::string tensor_dtype_ = "FP32_FP32") :
     weight_pool(enable_swap, swap_path, "weight_pool"),
     tensor_pool(enable_swap, swap_path, "tensor_pool"),
     enable_optimizations(true),
     swap_lookahead(lookahead),
     tensor_format(tensor_format_),
-    tensor_dtype(tensor_dtype_) {}
+    tensor_dtype(split(tensor_dtype_, std::regex("\\_"))) {}
 
   /**
    * @brief Construct a new Manager object (deleted)
@@ -510,7 +510,7 @@ private:
 
   std::string tensor_format;
 
-  std::string tensor_dtype;
+  std::vector<std::string> tensor_dtype;
 
   /**
    * @brief Finalize the given tensor pool
index 117b02b..25662ae 100644 (file)
@@ -653,6 +653,17 @@ struct TensorDataTypeInfo {
   static constexpr const char *EnumStr[] = {"FP16", "FP32"};
 };
 
+/**
+ * @brief     Enumeration of Tensor Format for model & layer
+ */
+struct TensorFormatInfo {
+  using Enum = nntrainer::TensorDim::Format;
+  static constexpr std::initializer_list<Enum> EnumList = {Enum::NCHW,
+                                                           Enum::NHWC};
+
+  static constexpr const char *EnumStr[] = {"NCHW", "NHWC"};
+};
+
 namespace props {
 
 /**
@@ -673,18 +681,20 @@ public:
  * @brief model tensor type : NCHW or NHWC
  *
  */
-class TensorFormat : public nntrainer::Property<std::string> {
+class TensorFormat final : public EnumProperty<TensorFormatInfo> {
 public:
   static constexpr const char *key =
     "tensor_format";             /**< unique key to access */
-  using prop_tag = str_prop_tag; /**< property type */
+  using prop_tag = enum_class_prop_tag; /**< property type */
 
   /**
    * @brief Constructor
    *
    * @param value value to set, defaults to false
    */
-  TensorFormat(const std::string &value = "NCHW") { set(value); };
+  TensorFormat(TensorFormatInfo::Enum value = TensorFormatInfo::Enum::NCHW) {
+    set(value);
+  };
 };
 } // namespace props
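TensorFormatInfo plugs into the same str_converter used throughout this
commit (sketch):

  auto fmt = str_converter<enum_class_prop_tag,
                           nntrainer::TensorFormatInfo>::from_string("NHWC");
  // fmt == nntrainer::TensorDim::Format::NHWC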
 
index 0b33b9e..abaf6a8 100644 (file)
@@ -110,7 +110,9 @@ using LayerGoldenTestParamType =
              const char * /**< Golden file name */,
              int /**< LayerGoldenTestParamOptions */,
              std::string /** < TensorFormat */,
-             std::string /** < TensorType */>;
+             std::string /** < Weight TensorType */,
+             std::string /** < Activation TensorType */
+             >;
 
 /**
  * @brief Golden Layer Test with designated format
index 99b58a8..40faf4f 100644 (file)
@@ -41,8 +41,9 @@ static const std::string getGoldenPath(const std::string &file_name) {
   return getResPath(file_name, {"test", "unittest_layers"});
 }
 
-static InitLayerContext createInitContext(Layer *layer,
-                                          const std::string &input_shape_str) {
+static InitLayerContext
+createInitContext(Layer *layer, const std::string &input_shape_str,
+                  std::array<const std::string, 3> tensor_type) {
   struct shape_parser_ : Property<TensorDim> {
     using prop_tag = dimension_prop_tag;
   };
@@ -55,7 +56,7 @@ static InitLayerContext createInitContext(Layer *layer,
   }
 
   InitLayerContext context({parsed.begin(), parsed.end()}, {true}, false,
-                           "golden_test");
+                           "golden_test", "", 0.0, tensor_type);
   layer->finalize(context);
 
   return context;
@@ -262,13 +263,15 @@ TEST_P(LayerGoldenTest, run) {
   auto f = std::get<0>(GetParam());
   auto layer = f(std::get<1>(GetParam()));
   std::string format = std::get<5>(GetParam());
-  std::string type = std::get<6>(GetParam());  
-  layer->setTensorType({format, type});
+  std::string type_w = std::get<6>(GetParam());
+  std::string type_a = std::get<7>(GetParam());
+
   auto golden_file = checkedOpenStream<std::ifstream>(
     getGoldenPath(std::get<3>(GetParam())), std::ios::in | std::ios::binary);
   auto &input_dims = std::get<2>(GetParam());
 
-  auto ic = createInitContext(layer.get(), input_dims);
+  auto ic =
+    createInitContext(layer.get(), input_dims, {format, type_w, type_a});
   auto tensors = prepareTensors(ic, golden_file);
   auto rc = prepareRunContext(tensors);
 
index 21ab07e..c2750d6 100644 (file)
@@ -27,17 +27,17 @@ GTEST_PARAMETER_TEST(Attention, LayerSemantics,
 auto attention_shared_kv = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::AttentionLayer>, {}, "1:1:5:7,1:1:3:7",
   "attention_shared_kv.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto attention_shared_kv_batched = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::AttentionLayer>, {}, "2:1:5:7,2:1:3:7",
   "attention_shared_kv_batched.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto attention_batched = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::AttentionLayer>, {},
   "2:1:5:7,2:1:3:7,2:1:3:7", "attention_batched.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(Attention, LayerGoldenTest,
                      ::testing::Values(attention_shared_kv,
index b759555..4922bec 100644 (file)
@@ -31,20 +31,22 @@ auto bn_inference_option = LayerGoldenTestParamOptions::SKIP_CALC_GRAD |
 auto bn_basic_channels_training = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:4:2:3",
   "bn_channels_training.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto bn_basic_channels_inference = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:4:2:3",
-  "bn_channels_inference.nnlayergolden", bn_inference_option, "nchw", "fp32");
+  "bn_channels_inference.nnlayergolden", bn_inference_option, "nchw", "fp32",
+  "fp32");
 
 auto bn_basic_width_training = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:1:1:10",
   "bn_width_training.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto bn_basic_width_inference = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:1:1:10",
-  "bn_width_inference.nnlayergolden", bn_inference_option, "nchw", "fp32");
+  "bn_width_inference.nnlayergolden", bn_inference_option, "nchw", "fp32",
+  "fp32");
 
 GTEST_PARAMETER_TEST(BatchNormalization, LayerGoldenTest,
                      ::testing::Values(bn_basic_channels_training,
index addbb9e..622a6a0 100644 (file)
@@ -26,17 +26,17 @@ GTEST_PARAMETER_TEST(Concat, LayerSemantics,
 auto concat_dim3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::ConcatLayer>, {"axis=3"},
   "2:3:3:2, 2:3:3:3", "concat_dim3.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto concat_dim2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::ConcatLayer>, {"axis=2"},
   "2:3:2:3, 2:3:3:3", "concat_dim2.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto concat_dim1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::ConcatLayer>, {"axis=1"},
   "2:2:3:3, 2:3:3:3", "concat_dim1.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(Concat, LayerGoldenTest,
                      ::testing::Values(concat_dim3, concat_dim2, concat_dim1));
index 9a608c6..29ec38b 100644 (file)
@@ -27,24 +27,24 @@ GTEST_PARAMETER_TEST(Convolution1D, LayerSemantics,
 auto conv1d_sb_minimum = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2"}, "1:1:1:4", "conv1d_sb_minimum.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_mb_minimum = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2"}, "3:1:1:4", "conv1d_mb_minimum.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_sb_same_remain = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=2", "kernel_size=3", "padding=same"}, "1:1:1:4",
   "conv1d_sb_same_remain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv1d_mb_same_remain = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=2", "kernel_size=3", "padding=same"}, "3:1:1:4",
   "conv1d_mb_same_remain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv1d_sb_same_uneven_remain_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -55,7 +55,7 @@ auto conv1d_sb_same_uneven_remain_1 = LayerGoldenTestParamType(
     "padding=same",
   },
   "1:3:1:4", "conv1d_sb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_sb_same_uneven_remain_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -66,7 +66,7 @@ auto conv1d_sb_same_uneven_remain_2 = LayerGoldenTestParamType(
     "padding=0,1",
   },
   "1:3:1:4", "conv1d_sb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_mb_same_uneven_remain_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -77,7 +77,7 @@ auto conv1d_mb_same_uneven_remain_1 = LayerGoldenTestParamType(
     "padding=same",
   },
   "3:3:1:4", "conv1d_mb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_mb_same_uneven_remain_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -88,7 +88,7 @@ auto conv1d_mb_same_uneven_remain_2 = LayerGoldenTestParamType(
     "padding=0,1",
   },
   "3:3:1:4", "conv1d_mb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_sb_valid_drop_last = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -99,7 +99,7 @@ auto conv1d_sb_valid_drop_last = LayerGoldenTestParamType(
     "padding=valid",
   },
   "1:3:1:7", "conv1d_sb_valid_drop_last.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_mb_valid_drop_last = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -110,13 +110,13 @@ auto conv1d_mb_valid_drop_last = LayerGoldenTestParamType(
     "padding=valid",
   },
   "3:3:1:7", "conv1d_mb_valid_drop_last.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_sb_no_overlap = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2", "stride=3"}, "1:2:1:5",
   "conv1d_sb_no_overlap.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv1d_mb_no_overlap = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -126,25 +126,25 @@ auto conv1d_mb_no_overlap = LayerGoldenTestParamType(
     "stride=3",
   },
   "3:2:1:5", "conv1d_mb_no_overlap.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_sb_causal = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2", "padding=causal"}, "1:1:1:4",
   "conv1d_sb_causal.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv1d_mb_causal = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2", "padding=causal"}, "3:1:1:4",
   "conv1d_mb_causal.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv1d_sb_1x1_kernel = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=1", "stride=2"}, "1:2:1:5",
   "conv1d_sb_1x1_kernel.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv1d_mb_1x1_kernel = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -154,7 +154,7 @@ auto conv1d_mb_1x1_kernel = LayerGoldenTestParamType(
     "stride=2",
   },
   "3:2:1:5", "conv1d_mb_1x1_kernel.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_sb_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -164,7 +164,7 @@ auto conv1d_sb_dilation = LayerGoldenTestParamType(
     "dilation=2",
   },
   "1:3:1:11", "conv1d_sb_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_mb_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -174,7 +174,7 @@ auto conv1d_mb_dilation = LayerGoldenTestParamType(
     "dilation=2",
   },
   "3:3:1:11", "conv1d_mb_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_sb_same_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -185,7 +185,7 @@ auto conv1d_sb_same_dilation = LayerGoldenTestParamType(
     "dilation=2",
   },
   "1:3:1:11", "conv1d_sb_same_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_mb_same_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -196,19 +196,19 @@ auto conv1d_mb_same_dilation = LayerGoldenTestParamType(
     "dilation=2",
   },
   "3:3:1:11", "conv1d_mb_same_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_sb_causal_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2", "padding=causal", "dilation=2"}, "1:1:1:4",
   "conv1d_sb_causal_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv1d_mb_causal_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2", "padding=causal", "dilation=2"}, "3:1:1:4",
   "conv1d_mb_causal_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(
   Convolution1D, LayerGoldenTest,
index d715a19..724c790 100644 (file)
@@ -28,25 +28,25 @@ auto conv2d_sb_minimum = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
   {"filters=3", "kernel_size=2,2"}, "1:1:4:4",
   "conv2d_sb_minimum.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv2d_mb_minimum = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
   {"filters=3", "kernel_size=2,2"}, "3:1:4:4",
   "conv2d_mb_minimum.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv2d_sb_same_remain = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
   {"filters=2", "kernel_size=3,3", "padding=same"}, "1:1:4:4",
   "conv2d_sb_same_remain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv2d_mb_same_remain = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
   {"filters=2", "kernel_size=3,3", "padding=same"}, "3:1:4:4",
   "conv2d_mb_same_remain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv2d_sb_same_uneven_remain_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -57,7 +57,7 @@ auto conv2d_sb_same_uneven_remain_1 = LayerGoldenTestParamType(
     "padding=same",
   },
   "1:3:4:4", "conv2d_sb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_sb_same_uneven_remain_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -68,7 +68,7 @@ auto conv2d_sb_same_uneven_remain_2 = LayerGoldenTestParamType(
     "padding=0,1,0,1",
   },
   "1:3:4:4", "conv2d_sb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_mb_same_uneven_remain_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -79,7 +79,7 @@ auto conv2d_mb_same_uneven_remain_1 = LayerGoldenTestParamType(
     "padding=same",
   },
   "3:3:4:4", "conv2d_mb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_mb_same_uneven_remain_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -90,7 +90,7 @@ auto conv2d_mb_same_uneven_remain_2 = LayerGoldenTestParamType(
     "padding=0,1,0,1",
   },
   "3:3:4:4", "conv2d_mb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_sb_valid_drop_last = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -101,7 +101,7 @@ auto conv2d_sb_valid_drop_last = LayerGoldenTestParamType(
     "padding=valid",
   },
   "1:3:7:7", "conv2d_sb_valid_drop_last.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_mb_valid_drop_last = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -112,13 +112,13 @@ auto conv2d_mb_valid_drop_last = LayerGoldenTestParamType(
     "padding=valid",
   },
   "3:3:7:7", "conv2d_mb_valid_drop_last.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_sb_no_overlap = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
   {"filters=3", "kernel_size=2,2", "stride=3,3"}, "1:2:5:5",
   "conv2d_sb_no_overlap.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv2d_mb_no_overlap = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -128,13 +128,13 @@ auto conv2d_mb_no_overlap = LayerGoldenTestParamType(
     "stride=3,3",
   },
   "3:2:5:5", "conv2d_mb_no_overlap.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_sb_1x1_kernel = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
   {"filters=3", "kernel_size=1,1", "stride=2,2"}, "1:2:5:5",
   "conv2d_sb_1x1_kernel.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto conv2d_mb_1x1_kernel = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -144,7 +144,7 @@ auto conv2d_mb_1x1_kernel = LayerGoldenTestParamType(
     "stride=2,2",
   },
   "3:2:5:5", "conv2d_mb_1x1_kernel.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_sb_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -154,7 +154,7 @@ auto conv2d_sb_dilation = LayerGoldenTestParamType(
     "dilation=2,2",
   },
   "1:3:11:11", "conv2d_sb_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_mb_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -164,7 +164,7 @@ auto conv2d_mb_dilation = LayerGoldenTestParamType(
     "dilation=2,2",
   },
   "3:3:11:11", "conv2d_mb_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_sb_same_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -175,7 +175,7 @@ auto conv2d_sb_same_dilation = LayerGoldenTestParamType(
     "dilation=2,2",
   },
   "1:3:11:11", "conv2d_sb_same_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto conv2d_mb_same_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -186,7 +186,7 @@ auto conv2d_mb_same_dilation = LayerGoldenTestParamType(
     "dilation=2,2",
   },
   "3:3:11:11", "conv2d_mb_same_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(
   Convolution2D, LayerGoldenTest,
index ecbfd64..aca658f 100644 (file)
@@ -34,22 +34,22 @@ auto dropout_20_training = LayerGoldenTestParamType(
   "2:3:2:3", "dropout_20_training.nnlayergolden",
   LayerGoldenTestParamOptions::DEFAULT |
     LayerGoldenTestParamOptions::DROPOUT_MATCH_60_PERCENT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto dropout_20_inference = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::DropOutLayer>, {"dropout_rate=0.2"},
   "2:3:2:3", "dropout_20_inference.nnlayergolden", dropout_inference_option,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto dropout_0_training = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::DropOutLayer>, {"dropout_rate=0.0"},
   "2:3:2:3", "dropout_0_training.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto dropout_100_training = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::DropOutLayer>, {"dropout_rate=1.0"},
   "2:3:2:3", "dropout_100_training.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(Dropout, LayerGoldenTest,
                      ::testing::Values(dropout_20_training, dropout_0_training,
index eead808..1f42959 100644 (file)
@@ -28,16 +28,16 @@ GTEST_PARAMETER_TEST(FullyConnected, LayerSemantics,
 auto fc_basic_plain = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::FullyConnectedLayer>, {"unit=5"},
   "3:1:1:10", "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 auto fc_basic_single_batch = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::FullyConnectedLayer>, {"unit=4"},
   "1:1:1:10", "fc_single_batch.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 auto fc_basic_no_decay = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::FullyConnectedLayer>,
   {"unit=5", "weight_decay=0.0", "bias_decay=0.0"}, "3:1:1:10",
   "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw",
-  "fp32");
+  "fp32", "fp32");
 
 auto fc_basic_plain_nhwc = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::FullyConnectedLayer>, {"unit=5"},
index 63a62a1..29b0422 100644 (file)
@@ -26,82 +26,83 @@ GTEST_PARAMETER_TEST(GRU, LayerSemantics, ::testing::Values(semantic_gru));
 auto gru_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "integrate_bias=true", "reset_after=false"}, "3:1:1:7",
-  "gru_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "gru_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw",
+  "fp32", "fp32");
 
 auto gru_multi_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "integrate_bias=true", "reset_after=false"}, "3:1:4:7",
-  "gru_multi_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
-
-auto gru_single_step_seq =
-  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::GRULayer>,
-                           {"unit=5", "return_sequences=true",
-                            "integrate_bias=true", "reset_after=false"},
-                           "3:1:1:7", "gru_single_step_seq.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
-
-auto gru_multi_step_seq =
-  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::GRULayer>,
-                           {"unit=5", "return_sequences=true",
-                            "integrate_bias=true", "reset_after=false"},
-                           "3:1:4:7", "gru_multi_step_seq.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  "gru_multi_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw",
+  "fp32", "fp32");
+
+auto gru_single_step_seq = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::GRULayer>,
+  {"unit=5", "return_sequences=true", "integrate_bias=true",
+   "reset_after=false"},
+  "3:1:1:7", "gru_single_step_seq.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
+
+auto gru_multi_step_seq = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::GRULayer>,
+  {"unit=5", "return_sequences=true", "integrate_bias=true",
+   "reset_after=false"},
+  "3:1:4:7", "gru_multi_step_seq.nnlayergolden",
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto gru_multi_step_seq_act_orig = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "hidden_state_activation=tanh",
    "recurrent_activation=sigmoid", "integrate_bias=true", "reset_after=false"},
   "3:1:4:7", "gru_multi_step_seq.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto gru_multi_step_seq_act = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "hidden_state_activation=sigmoid",
    "recurrent_activation=tanh", "integrate_bias=true", "reset_after=false"},
   "3:1:4:7", "gru_multi_step_seq_act.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 // Check reset_after
 auto gru_reset_after_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "integrate_bias=false", "reset_after=true"}, "3:1:1:7",
   "gru_reset_after_single_step.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto gru_reset_after_multi_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "integrate_bias=false", "reset_after=true"}, "3:1:4:7",
   "gru_reset_after_multi_step.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto gru_reset_after_single_step_seq = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "integrate_bias=false",
    "reset_after=true"},
   "3:1:1:7", "gru_reset_after_single_step_seq.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto gru_reset_after_multi_step_seq = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "integrate_bias=false",
    "reset_after=true"},
   "3:1:4:7", "gru_reset_after_multi_step_seq.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto gru_reset_after_multi_step_seq_act_orig = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "hidden_state_activation=tanh",
    "recurrent_activation=sigmoid", "integrate_bias=false", "reset_after=true"},
   "3:1:4:7", "gru_reset_after_multi_step_seq.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto gru_reset_after_multi_step_seq_act = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "hidden_state_activation=sigmoid",
    "recurrent_activation=tanh", "integrate_bias=false", "reset_after=true"},
   "3:1:4:7", "gru_reset_after_multi_step_seq_act.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(
   GRU, LayerGoldenTest,
index d6575db..253a6e7 100644 (file)
@@ -29,20 +29,20 @@ auto grucell_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRUCellLayer>,
   {"unit=5", "integrate_bias=true", "reset_after=false"}, "3:1:1:7,3:1:1:5",
   "grucell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto grucell_reset_after_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRUCellLayer>,
   {"unit=5", "integrate_bias=false", "reset_after=true"}, "3:1:1:7,3:1:1:5",
   "grucell_reset_after_single_step.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto grucell_single_step_act = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRUCellLayer>,
   {"unit=5", "integrate_bias=true", "reset_after=false",
    "hidden_state_activation=sigmoid", "recurrent_activation=tanh"},
   "3:1:1:7,3:1:1:5", "grucell_single_step_act.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(GRUCell, LayerGoldenTest,
                      ::testing::Values(grucell_single_step,
index df8c93b..c2f537d 100644 (file)
@@ -27,37 +27,37 @@ GTEST_PARAMETER_TEST(LayerNormalization, LayerSemantics,
 auto ln_axis_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1"},
   "2:4:2:3", "ln_axis_1.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto ln_axis_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=2"},
   "2:4:2:3", "ln_axis_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto ln_axis_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=3"},
   "2:4:2:3", "ln_axis_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto ln_axis_1_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 2"},
   "2:4:2:3", "ln_axis_1_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto ln_axis_2_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=2, 3"},
   "2:4:2:3", "ln_axis_2_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto ln_axis_1_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 3"},
   "2:4:2:3", "ln_axis_1_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto ln_axis_1_2_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 2, 3"},
   "2:4:2:3", "ln_axis_1_2_3.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(LayerNormalization, LayerGoldenTest,
                      ::testing::Values(ln_axis_1, ln_axis_2, ln_axis_3,
index 7aea779..a67f610 100644 (file)
@@ -23,42 +23,42 @@ auto semantic_lstm = LayerSemanticsParamType(
 
 GTEST_PARAMETER_TEST(LSTM, LayerSemantics, ::testing::Values(semantic_lstm));
 
-auto lstm_single_step =
-  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::LSTMLayer>,
-                           {"unit=5", "integrate_bias=true"}, "3:1:1:7",
-                           "lstm_single_step.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+auto lstm_single_step = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::LSTMLayer>,
+  {"unit=5", "integrate_bias=true"}, "3:1:1:7",
+  "lstm_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw", "fp32", "fp32");
 
 auto lstm_multi_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMLayer>,
   {"unit=5", "integrate_bias=true"}, "3:1:4:7", "lstm_multi_step.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto lstm_single_step_seq = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMLayer>,
   {"unit=5", "integrate_bias=true", "return_sequences=true"}, "3:1:1:7",
   "lstm_single_step_seq.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto lstm_multi_step_seq = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMLayer>,
   {"unit=5", "integrate_bias=true", "return_sequences=true"}, "3:1:4:7",
   "lstm_multi_step_seq.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto lstm_multi_step_seq_act_orig = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMLayer>,
   {"unit=5", "integrate_bias=true", "return_sequences=true",
    "hidden_state_activation=tanh", "recurrent_activation=sigmoid"},
   "3:1:4:7", "lstm_multi_step_seq.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto lstm_multi_step_seq_act = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMLayer>,
   {"unit=5", "integrate_bias=true", "return_sequences=true",
    "hidden_state_activation=sigmoid", "recurrent_activation=tanh"},
   "3:1:4:7", "lstm_multi_step_seq_act.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(LSTM, LayerGoldenTest,
                      ::testing::Values(lstm_single_step, lstm_multi_step,
index 04059c0..e4423c0 100644 (file)
@@ -28,7 +28,7 @@ auto lstmcell_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMCellLayer>,
   {"unit=5", "integrate_bias=true"}, "3:1:1:7,3:1:1:5,3:1:1:5",
   "lstmcell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(LSTMCell, LayerGoldenTest,
                      ::testing::Values(lstmcell_single_step));
index c6736af..1dfb335 100644 (file)
@@ -37,13 +37,13 @@ auto multi_head_attention_single_batch = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3"}, "1:1:5:7,1:1:3:7,1:1:3:7",
   "multi_head_attention_single_batch.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto multi_head_attention = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3"}, "2:1:5:7,2:1:3:7,2:1:3:7",
   "multi_head_attention.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "nchw", "fp32", "fp32");
 
 auto multi_head_attention_return_attention_scores = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
@@ -51,19 +51,19 @@ auto multi_head_attention_return_attention_scores = LayerGoldenTestParamType(
    "average_attention_weight=false"},
   "2:1:5:7,2:1:3:7,2:1:3:7",
   "multi_head_attention_return_attention_scores.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto multi_head_attention_value_dim = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3", "projected_value_dim=5"},
   "2:1:5:7,2:1:3:7,2:1:3:7", "multi_head_attention_value_dim.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto multi_head_attention_output_shape = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3", "output_shape=5"},
   "2:1:5:7,2:1:3:7,2:1:3:7", "multi_head_attention_output_shape.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(
   MultiHeadAttention, LayerGoldenTest,
index caffe8a..95c7660 100644 (file)
@@ -28,12 +28,12 @@ INSTANTIATE_TEST_CASE_P(PositionalEncoding, LayerSemantics,
 auto positional_encoding_partial = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::PositionalEncodingLayer>,
   {"max_timestep=10"}, "3:1:7:6", "positional_encoding_partial.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 auto positional_encoding = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::PositionalEncodingLayer>,
   {"max_timestep=10"}, "3:1:10:6", "positional_encoding.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
 
 INSTANTIATE_TEST_CASE_P(PositionalEncoding, LayerGoldenTest,
                         ::testing::Values(positional_encoding_partial,
index 4534c38..a1924ac 100644 (file)
@@ -26,7 +26,7 @@ GTEST_PARAMETER_TEST(RNN, LayerSemantics, ::testing::Values(semantic_rnn));
 auto rnn_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::RNNLayer>,
   {"unit=5", "return_sequences=false", "integrate_bias=true"}, "3:1:1:7",
-  "rnn_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
-  "nchw", "fp32");
+  "rnn_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw",
+  "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(RNN, LayerGoldenTest, ::testing::Values(rnn_single_step));
index d8ff272..ca8337a 100644 (file)
@@ -24,11 +24,11 @@ auto semantic_rnncell = LayerSemanticsParamType(
 GTEST_PARAMETER_TEST(RNNCell, LayerSemantics,
                      ::testing::Values(semantic_rnncell));
 
-auto rnncell_single_step =
-  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::RNNCellLayer>,
-                           {"unit=5", "integrate_bias=true"}, "3:1:1:7,3:1:1:5",
-                           "rnncell_single_step.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32");
+auto rnncell_single_step = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::RNNCellLayer>,
+  {"unit=5", "integrate_bias=true"}, "3:1:1:7,3:1:1:5",
+  "rnncell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw", "fp32", "fp32");
 
 GTEST_PARAMETER_TEST(RNNCell, LayerGoldenTest,
                      ::testing::Values(rnncell_single_step));