[ Property ] Add Tensor Type property in model
authorjijoong.moon <jijoong.moon@samsung.com>
Thu, 29 Jun 2023 12:36:30 +0000 (21:36 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 17 Jul 2023 00:01:56 +0000 (09:01 +0900)
This PR enables the tensor type in model property as
"tensor_type=NHWC" or "tensor_type=NCHW". This information goes to
network_grap and layer node & manager.

Then, each layer can get the model tensor type information and it can
be used to request tensor or just using temporal tensor.

Resolves:

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
32 files changed:
api/ccapi/include/layer.h
nntrainer/graph/network_graph.cpp
nntrainer/graph/network_graph.h
nntrainer/layers/fc_layer.cpp
nntrainer/layers/layer_devel.h
nntrainer/layers/layer_node.cpp
nntrainer/layers/layer_node.h
nntrainer/models/model_common_properties.cpp
nntrainer/models/model_common_properties.h
nntrainer/models/neuralnet.cpp
nntrainer/models/neuralnet.h
nntrainer/tensor/manager.h
test/unittest/layers/layers_common_tests.h
test/unittest/layers/layers_golden_tests.cpp
test/unittest/layers/unittest_layers_attention.cpp
test/unittest/layers/unittest_layers_batch_normalization.cpp
test/unittest/layers/unittest_layers_concat.cpp
test/unittest/layers/unittest_layers_convolution1d.cpp
test/unittest/layers/unittest_layers_convolution2d.cpp
test/unittest/layers/unittest_layers_dropout.cpp
test/unittest/layers/unittest_layers_fully_connected.cpp
test/unittest/layers/unittest_layers_gru.cpp
test/unittest/layers/unittest_layers_grucell.cpp
test/unittest/layers/unittest_layers_layer_normalization.cpp
test/unittest/layers/unittest_layers_lstm.cpp
test/unittest/layers/unittest_layers_lstmcell.cpp
test/unittest/layers/unittest_layers_multi_head_attention.cpp
test/unittest/layers/unittest_layers_positional_encoding.cpp
test/unittest/layers/unittest_layers_rnn.cpp
test/unittest/layers/unittest_layers_rnncell.cpp
test/unittest/layers/unittest_layers_zoneout_lstmcell.cpp
test/unittest/unittest_nntrainer_models.cpp

index 9690b74..3bd4683 100644 (file)
@@ -168,6 +168,13 @@ public:
   virtual void setProperty(const std::vector<std::string> &values) = 0;
 
   /**
+   * @brief     Set Tensor Type : NCHW or NHWC
+   * @note      This is used mainly for the unittest case which does not have
+   * model.
+   */
+  virtual void setTensorType(const std::string &values){};
+
+  /**
    * @brief     Get name of the layer
    * @retval    name of the layer
    * @note      This name is unique to this layer in a model
index 033af8a..18f1256 100644 (file)
@@ -81,6 +81,19 @@ int NetworkGraph::compile(const std::string &loss_type) {
 
   inPlaceOptimize();
 
+  TensorDim::Format type =
+    (getModelTensorType().compare("NCHW") ||
+     getModelTensorType().compare("nchw") || getModelTensorType().compare(""))
+      ? TensorDim::Format::NCHW
+      : TensorDim::Format::NHWC;
+
+  for (auto iter = cbegin(); iter != cend(); iter++) {
+    auto lnode = (*iter);
+    /// @todo  later, we can set layer tensor type differenctly with model
+    /// tensor type
+    lnode->setTensorType(getModelTensorType());
+  }
+
   status = checkCompiledGraph();
   NN_RETURN_STATUS();
 
index 6a7996c..293e931 100644 (file)
@@ -48,7 +48,8 @@ public:
     backward_iter_end(nullptr),
     forward_iter_end(nullptr),
     optimize_memory(true),
-    exec_mode(ExecutionMode::TRAIN) {}
+    exec_mode(ExecutionMode::TRAIN),
+    model_tensor_type("NCHW") {}
 
   /**
    * @brief     Constructor of NeuralNetwork Graph Class
@@ -56,9 +57,10 @@ public:
    * @param[in] swap_path memory swap file path when the swap is enabled
    */
   NetworkGraph(bool enable_swap, const std::string &swap_path = "",
-               unsigned int lookahead = 0) :
-    tensor_manager(
-      std::make_shared<Manager>(enable_swap, swap_path, lookahead)),
+               unsigned int lookahead = 0,
+               const std::string &tensor_type = "NCHW") :
+    tensor_manager(std::make_shared<Manager>(enable_swap, swap_path, lookahead,
+                                             tensor_type)),
     graph(),
     compiled(false),
     batch_size(0),
@@ -66,7 +68,8 @@ public:
     backward_iter_end(nullptr),
     forward_iter_end(nullptr),
     optimize_memory(true),
-    exec_mode(ExecutionMode::TRAIN) {}
+    exec_mode(ExecutionMode::TRAIN),
+    model_tensor_type(tensor_type) {}
 
   /**
    * @brief   Destructor of the NeuralNetwork Graph class
@@ -366,6 +369,13 @@ public:
   std::vector<Tensor> getOutputTensors() const;
 
   /**
+   * @brief return model tensor type
+   *
+   * @return TensorDim::Format NCHW or NHWC
+   */
+  std::string getModelTensorType() const { return model_tensor_type; };
+
+  /**
    * @brief Flush data to the device
    *
    */
@@ -418,6 +428,8 @@ private:
   ExecutionMode exec_mode; /**< execution mode with which the graph has been
                               currently set or previously set */
 
+  std::string model_tensor_type; /**< Model Tensor Type: NCHW or NHWC */
+
   std::unordered_map<std::string, int>
     profile_keys; /**< profile keys based on the layer type */
   std::vector<Weight *>
index a2044c8..3b7cdeb 100644 (file)
@@ -75,9 +75,8 @@ void FullyConnectedLayer::finalize(InitLayerContext &context) {
   /** set weight specifications */
   // @todo : This NCHW format setting is just temporal, it needs to be set by
   // global configuration
-  TensorDim bias_dim(1, 1, 1, unit, ml::train::TensorDim::Format::NCHW, 0b0001);
-  TensorDim weight_dim(1, 1, in_dim.width(), unit,
-                       ml::train::TensorDim::Format::NCHW, 0b0011);
+  TensorDim bias_dim(1, 1, 1, unit, getTensorType(), 0b0001);
+  TensorDim weight_dim(1, 1, in_dim.width(), unit, getTensorType(), 0b0011);
 
   weight_idx[FCParams::weight] = context.requestWeight(
     weight_dim, weight_initializer, weight_regularizer,
index 72bc330..23cf09f 100644 (file)
@@ -28,6 +28,7 @@
 #include <vector>
 
 #include <common.h>
+#include <tensor_dim.h>
 
 namespace ml::train {
 class Layer;
@@ -239,6 +240,37 @@ public:
    * @return true if supports backwarding, else false
    */
   virtual bool supportBackwarding() const = 0;
+
+  /**
+   * @brief Set the Tensor Type for the layer
+   * @param     Tensor Type : TensorDim::Format::NCHW or TneosrDim::Format::NHWC
+   */
+  virtual void setTensorType(
+    ml::train::TensorDim::Format type = ml::train::TensorDim::Format::NCHW) {
+    tensor_type = type;
+  }
+
+  /**
+   * @brief set the Tensor Type for the layer
+   * @param     Tensor Type : NCHW or NHWC
+   */
+  void setTensorType(const std::string &values) {
+    if (values.compare("NCHW") || values.compare("nchw")) {
+      tensor_type = ml::train::TensorDim::Format::NCHW;
+    } else {
+      tensor_type = ml::train::TensorDim::Format::NHWC;
+    }
+  }
+
+  /**
+   * @brief get the Tensor Type for the layer
+   * @return     Tensor Type : TensorDim::Format::NCHW or
+   * TneosrDim::Format::NHWC
+   */
+  virtual ml::train::TensorDim::Format getTensorType() { return tensor_type; }
+
+private:
+  ml::train::TensorDim::Format tensor_type;
 };
 
 /// @todo Decide where to put and how to implement(#986)
index 38154cf..2ef2da6 100644 (file)
@@ -247,6 +247,13 @@ void LayerNode::setOutputConnection(unsigned nth, const std::string &name,
   con = std::make_unique<Connection>(name, index);
 }
 
+void LayerNode::setTensorType(const std::string type_) {
+  TensorDim::Format type = (type_.compare("NCHW") || type_.compare("nchw"))
+                             ? TensorDim::Format::NCHW
+                             : TensorDim::Format::NHWC;
+  getLayer()->setTensorType(type);
+}
+
 const std::string LayerNode::getName() const noexcept {
   auto &name = std::get<props::Name>(*layer_node_props);
   return name.empty() ? "" : name.get();
index 924c746..07a6da8 100644 (file)
@@ -801,6 +801,13 @@ public:
    */
   bool needsCalcGradient() { return needs_calc_gradient; }
 
+  /**
+   * @brief Set Tensor type for layer
+   *
+   * @param type NCHW : NHWC
+   */
+  void setTensorType(const std::string type_ = "NCHW");
+
 private:
   /**
    * @brief     Get the Input Layers object
@@ -828,6 +835,8 @@ private:
   std::vector<std::unique_ptr<Connection>>
     output_connections; /**< output layer names */
 
+  TensorDim::Format tensor_type;
+
 #ifdef ENABLE_TEST
   /**
    * @brief   Init context which is stored for debugging issue
index f47e75b..534620a 100644 (file)
@@ -37,4 +37,6 @@ MemorySwapLookahead::MemorySwapLookahead(const unsigned int &value) {
   set(value);
 }
 
+ModelTensorType::ModelTensorType(const std::string &value) { set(value); }
+
 } // namespace nntrainer::props
index 3792aff..dfb46f3 100644 (file)
@@ -179,6 +179,23 @@ public:
   MemorySwapLookahead(const unsigned int &value = 0);
 };
 
+/**
+ * @brief model tensor type : NCHW or NHWC
+ *
+ */
+class ModelTensorType : public Property<std::string> {
+public:
+  static constexpr const char *key = "tensor_type"; /**< unique key to access */
+  using prop_tag = str_prop_tag;                    /**< property type */
+
+  /**
+   * @brief Constructor
+   *
+   * @param value value to set, defaults to false
+   */
+  ModelTensorType(const std::string &value = "NCHW");
+};
+
 } // namespace nntrainer::props
 
 #endif
index 57bb8eb..4d61932 100644 (file)
@@ -66,10 +66,11 @@ namespace nntrainer {
 
 NeuralNetwork::NeuralNetwork() :
   model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
-  model_flex_props(
-    props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
-    props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
-    props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead()),
+  model_flex_props(props::Epochs(), props::TrainingBatchSize(),
+                   props::SavePath(), props::ContinueTrain(),
+                   props::SaveBestPath(), props::MemoryOptimization(),
+                   props::MemorySwap(), props::MemorySwapPath(),
+                   props::MemorySwapLookahead(), props::ModelTensorType()),
   load_path(std::string()),
   epoch_idx(0),
   iter(0),
@@ -83,10 +84,11 @@ NeuralNetwork::NeuralNetwork() :
 
 NeuralNetwork::NeuralNetwork(AppContext app_context_) :
   model_props(props::LossType(), {}, {}, props::ClipGradByGlobalNorm()),
-  model_flex_props(
-    props::Epochs(), props::TrainingBatchSize(), props::SavePath(),
-    props::ContinueTrain(), props::SaveBestPath(), props::MemoryOptimization(),
-    props::MemorySwap(), props::MemorySwapPath(), props::MemorySwapLookahead()),
+  model_flex_props(props::Epochs(), props::TrainingBatchSize(),
+                   props::SavePath(), props::ContinueTrain(),
+                   props::SaveBestPath(), props::MemoryOptimization(),
+                   props::MemorySwap(), props::MemorySwapPath(),
+                   props::MemorySwapLookahead(), props::ModelTensorType()),
   load_path(std::string()),
   epoch_idx(0),
   iter(0),
@@ -170,7 +172,12 @@ int NeuralNetwork::compile() {
     std::get<props::MemorySwapPath>(model_flex_props);
   unsigned int lookahead =
     std::get<props::MemorySwapLookahead>(model_flex_props);
-  model_graph = NetworkGraph(memory_swap, memory_swap_path, lookahead);
+
+  const std::string tensor_type =
+    std::get<props::ModelTensorType>(model_flex_props);
+
+  model_graph =
+    NetworkGraph(memory_swap, memory_swap_path, lookahead, tensor_type);
 
   model_graph.setMemoryOptimizations(
     std::get<props::MemoryOptimization>(model_flex_props));
index 0e8820b..e63317f 100644 (file)
@@ -556,7 +556,8 @@ private:
     std::tuple<props::Epochs, props::TrainingBatchSize, props::SavePath,
                props::ContinueTrain, props::SaveBestPath,
                props::MemoryOptimization, props::MemorySwap,
-               props::MemorySwapPath, props::MemorySwapLookahead>;
+               props::MemorySwapPath, props::MemorySwapLookahead,
+               props::ModelTensorType>;
   using RigidPropTypes =
     std::tuple<props::LossType, std::vector<props::InputConnection>,
                std::vector<props::LabelLayer>, props::ClipGradByGlobalNorm>;
index 5f287ac..a13cefc 100644 (file)
@@ -130,17 +130,21 @@ public:
   /**
    * @brief     Constructor of Manager
    */
-  Manager() : enable_optimizations(true), swap_lookahead(0) {}
+  Manager() :
+    enable_optimizations(true),
+    swap_lookahead(0),
+    tensor_type("nchw") {}
 
   /**
    * @brief     Constructor of Manager
    */
   Manager(bool enable_swap, const std::string &swap_path = "",
-          unsigned int lookahead = 0) :
+          unsigned int lookahead = 0, const std::string tensor_type_ = "nchw") :
     weight_pool(enable_swap, swap_path, "weight_pool"),
     tensor_pool(enable_swap, swap_path, "tensor_pool"),
     enable_optimizations(true),
-    swap_lookahead(lookahead) {}
+    swap_lookahead(lookahead),
+    tensor_type(tensor_type_) {}
 
   /**
    * @brief Construct a new Manager object (deleted)
@@ -501,6 +505,8 @@ private:
 
   unsigned int swap_lookahead; /** lookahead for memory swap */
 
+  std::string tensor_type;
+
   /**
    * @brief Finalize the given tensor pool
    *
index 54dac43..b41d232 100644 (file)
@@ -108,7 +108,8 @@ using LayerGoldenTestParamType =
              std::vector<std::string> /**< Properties */,
              const char *, /**< Input Tensor dimensions representation */
              const char * /**< Golden file name */,
-             int /**< LayerGoldenTestParamOptions */>;
+             int /**< LayerGoldenTestParamOptions */,
+             std::string /**<TensorType*/>;
 
 /**
  * @brief Golden Layer Test with designated format
index 6d6f4a8..8189468 100644 (file)
@@ -257,6 +257,8 @@ bool LayerGoldenTest::shouldSkipCalcGrad() {
 TEST_P(LayerGoldenTest, run) {
   auto f = std::get<0>(GetParam());
   auto layer = f(std::get<1>(GetParam()));
+  std::string type = std::get<5>(GetParam());
+  layer->setTensorType(type);
   auto golden_file = checkedOpenStream<std::ifstream>(
     getGoldenPath(std::get<3>(GetParam())), std::ios::in | std::ios::binary);
   auto &input_dims = std::get<2>(GetParam());
index 04ae2af..00910f2 100644 (file)
@@ -26,17 +26,18 @@ GTEST_PARAMETER_TEST(Attention, LayerSemantics,
 
 auto attention_shared_kv = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::AttentionLayer>, {}, "1:1:5:7,1:1:3:7",
-  "attention_shared_kv.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "attention_shared_kv.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto attention_shared_kv_batched = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::AttentionLayer>, {}, "2:1:5:7,2:1:3:7",
   "attention_shared_kv_batched.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto attention_batched = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::AttentionLayer>, {},
   "2:1:5:7,2:1:3:7,2:1:3:7", "attention_batched.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(Attention, LayerGoldenTest,
                      ::testing::Values(attention_shared_kv,
index a9c994d..5a79b2f 100644 (file)
@@ -30,19 +30,21 @@ auto bn_inference_option = LayerGoldenTestParamOptions::SKIP_CALC_GRAD |
 
 auto bn_basic_channels_training = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:4:2:3",
-  "bn_channels_training.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "bn_channels_training.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto bn_basic_channels_inference = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:4:2:3",
-  "bn_channels_inference.nnlayergolden", bn_inference_option);
+  "bn_channels_inference.nnlayergolden", bn_inference_option, "nchw");
 
 auto bn_basic_width_training = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:1:1:10",
-  "bn_width_training.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "bn_width_training.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto bn_basic_width_inference = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::BatchNormalizationLayer>, {}, "2:1:1:10",
-  "bn_width_inference.nnlayergolden", bn_inference_option);
+  "bn_width_inference.nnlayergolden", bn_inference_option, "nchw");
 
 GTEST_PARAMETER_TEST(BatchNormalization, LayerGoldenTest,
                      ::testing::Values(bn_basic_channels_training,
index 52c266f..3a81572 100644 (file)
@@ -26,17 +26,17 @@ GTEST_PARAMETER_TEST(Concat, LayerSemantics,
 auto concat_dim3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::ConcatLayer>, {"axis=3"},
   "2:3:3:2, 2:3:3:3", "concat_dim3.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto concat_dim2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::ConcatLayer>, {"axis=2"},
   "2:3:2:3, 2:3:3:3", "concat_dim2.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto concat_dim1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::ConcatLayer>, {"axis=1"},
   "2:2:3:3, 2:3:3:3", "concat_dim1.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(Concat, LayerGoldenTest,
                      ::testing::Values(concat_dim3, concat_dim2, concat_dim1));
index d85c4a1..4193503 100644 (file)
@@ -27,22 +27,24 @@ GTEST_PARAMETER_TEST(Convolution1D, LayerSemantics,
 auto conv1d_sb_minimum = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2"}, "1:1:1:4", "conv1d_sb_minimum.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_mb_minimum = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2"}, "3:1:1:4", "conv1d_mb_minimum.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv1d_sb_same_remain = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv1DLayer>,
-  {"filters=2", "kernel_size=3", "padding=same"}, "1:1:1:4",
-  "conv1d_sb_same_remain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv1d_sb_same_remain =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
+                           {"filters=2", "kernel_size=3", "padding=same"},
+                           "1:1:1:4", "conv1d_sb_same_remain.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv1d_mb_same_remain = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv1DLayer>,
-  {"filters=2", "kernel_size=3", "padding=same"}, "3:1:1:4",
-  "conv1d_mb_same_remain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv1d_mb_same_remain =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
+                           {"filters=2", "kernel_size=3", "padding=same"},
+                           "3:1:1:4", "conv1d_mb_same_remain.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_sb_same_uneven_remain_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -53,7 +55,7 @@ auto conv1d_sb_same_uneven_remain_1 = LayerGoldenTestParamType(
     "padding=same",
   },
   "1:3:1:4", "conv1d_sb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_sb_same_uneven_remain_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -64,7 +66,7 @@ auto conv1d_sb_same_uneven_remain_2 = LayerGoldenTestParamType(
     "padding=0,1",
   },
   "1:3:1:4", "conv1d_sb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_mb_same_uneven_remain_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -75,7 +77,7 @@ auto conv1d_mb_same_uneven_remain_1 = LayerGoldenTestParamType(
     "padding=same",
   },
   "3:3:1:4", "conv1d_mb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_mb_same_uneven_remain_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -86,7 +88,7 @@ auto conv1d_mb_same_uneven_remain_2 = LayerGoldenTestParamType(
     "padding=0,1",
   },
   "3:3:1:4", "conv1d_mb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_sb_valid_drop_last =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -97,7 +99,7 @@ auto conv1d_sb_valid_drop_last =
                              "padding=valid",
                            },
                            "1:3:1:7", "conv1d_sb_valid_drop_last.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_mb_valid_drop_last =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -108,12 +110,13 @@ auto conv1d_mb_valid_drop_last =
                              "padding=valid",
                            },
                            "3:3:1:7", "conv1d_mb_valid_drop_last.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv1d_sb_no_overlap = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv1DLayer>,
-  {"filters=3", "kernel_size=2", "stride=3"}, "1:2:1:5",
-  "conv1d_sb_no_overlap.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv1d_sb_no_overlap =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
+                           {"filters=3", "kernel_size=2", "stride=3"},
+                           "1:2:1:5", "conv1d_sb_no_overlap.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_mb_no_overlap =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -123,22 +126,25 @@ auto conv1d_mb_no_overlap =
                              "stride=3",
                            },
                            "3:2:1:5", "conv1d_mb_no_overlap.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv1d_sb_causal = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv1DLayer>,
-  {"filters=3", "kernel_size=2", "padding=causal"}, "1:1:1:4",
-  "conv1d_sb_causal.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv1d_sb_causal =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
+                           {"filters=3", "kernel_size=2", "padding=causal"},
+                           "1:1:1:4", "conv1d_sb_causal.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv1d_mb_causal = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv1DLayer>,
-  {"filters=3", "kernel_size=2", "padding=causal"}, "3:1:1:4",
-  "conv1d_mb_causal.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv1d_mb_causal =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
+                           {"filters=3", "kernel_size=2", "padding=causal"},
+                           "3:1:1:4", "conv1d_mb_causal.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv1d_sb_1x1_kernel = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv1DLayer>,
-  {"filters=3", "kernel_size=1", "stride=2"}, "1:2:1:5",
-  "conv1d_sb_1x1_kernel.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv1d_sb_1x1_kernel =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
+                           {"filters=3", "kernel_size=1", "stride=2"},
+                           "1:2:1:5", "conv1d_sb_1x1_kernel.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_mb_1x1_kernel =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -148,7 +154,7 @@ auto conv1d_mb_1x1_kernel =
                              "stride=2",
                            },
                            "3:2:1:5", "conv1d_mb_1x1_kernel.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_sb_dilation =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -158,7 +164,7 @@ auto conv1d_sb_dilation =
                              "dilation=2",
                            },
                            "1:3:1:11", "conv1d_sb_dilation.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_mb_dilation =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -168,7 +174,7 @@ auto conv1d_mb_dilation =
                              "dilation=2",
                            },
                            "3:3:1:11", "conv1d_mb_dilation.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_sb_same_dilation =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -179,7 +185,7 @@ auto conv1d_sb_same_dilation =
                              "dilation=2",
                            },
                            "1:3:1:11", "conv1d_sb_same_dilation.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_mb_same_dilation =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv1DLayer>,
@@ -190,19 +196,19 @@ auto conv1d_mb_same_dilation =
                              "dilation=2",
                            },
                            "3:3:1:11", "conv1d_mb_same_dilation.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_sb_causal_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2", "padding=causal", "dilation=2"}, "1:1:1:4",
   "conv1d_sb_causal_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv1d_mb_causal_dilation = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv1DLayer>,
   {"filters=3", "kernel_size=2", "padding=causal", "dilation=2"}, "3:1:1:4",
   "conv1d_mb_causal_dilation.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(
   Convolution1D, LayerGoldenTest,
index 2923201..bc87da3 100644 (file)
@@ -24,25 +24,29 @@ auto semantic_conv2d = LayerSemanticsParamType(
 GTEST_PARAMETER_TEST(Convolution2D, LayerSemantics,
                      ::testing::Values(semantic_conv2d));
 
-auto conv2d_sb_minimum = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv2DLayer>,
-  {"filters=3", "kernel_size=2,2"}, "1:1:4:4",
-  "conv2d_sb_minimum.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv2d_sb_minimum =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
+                           {"filters=3", "kernel_size=2,2"}, "1:1:4:4",
+                           "conv2d_sb_minimum.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv2d_mb_minimum = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv2DLayer>,
-  {"filters=3", "kernel_size=2,2"}, "3:1:4:4",
-  "conv2d_mb_minimum.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv2d_mb_minimum =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
+                           {"filters=3", "kernel_size=2,2"}, "3:1:4:4",
+                           "conv2d_mb_minimum.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv2d_sb_same_remain = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv2DLayer>,
-  {"filters=2", "kernel_size=3,3", "padding=same"}, "1:1:4:4",
-  "conv2d_sb_same_remain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv2d_sb_same_remain =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
+                           {"filters=2", "kernel_size=3,3", "padding=same"},
+                           "1:1:4:4", "conv2d_sb_same_remain.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv2d_mb_same_remain = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv2DLayer>,
-  {"filters=2", "kernel_size=3,3", "padding=same"}, "3:1:4:4",
-  "conv2d_mb_same_remain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv2d_mb_same_remain =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
+                           {"filters=2", "kernel_size=3,3", "padding=same"},
+                           "3:1:4:4", "conv2d_mb_same_remain.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_sb_same_uneven_remain_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -53,7 +57,7 @@ auto conv2d_sb_same_uneven_remain_1 = LayerGoldenTestParamType(
     "padding=same",
   },
   "1:3:4:4", "conv2d_sb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_sb_same_uneven_remain_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -64,7 +68,7 @@ auto conv2d_sb_same_uneven_remain_2 = LayerGoldenTestParamType(
     "padding=0,1,0,1",
   },
   "1:3:4:4", "conv2d_sb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_mb_same_uneven_remain_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -75,7 +79,7 @@ auto conv2d_mb_same_uneven_remain_1 = LayerGoldenTestParamType(
     "padding=same",
   },
   "3:3:4:4", "conv2d_mb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_mb_same_uneven_remain_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -86,7 +90,7 @@ auto conv2d_mb_same_uneven_remain_2 = LayerGoldenTestParamType(
     "padding=0,1,0,1",
   },
   "3:3:4:4", "conv2d_mb_same_uneven_remain.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_sb_valid_drop_last =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -97,7 +101,7 @@ auto conv2d_sb_valid_drop_last =
                              "padding=valid",
                            },
                            "1:3:7:7", "conv2d_sb_valid_drop_last.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_mb_valid_drop_last =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -108,12 +112,13 @@ auto conv2d_mb_valid_drop_last =
                              "padding=valid",
                            },
                            "3:3:7:7", "conv2d_mb_valid_drop_last.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv2d_sb_no_overlap = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv2DLayer>,
-  {"filters=3", "kernel_size=2,2", "stride=3,3"}, "1:2:5:5",
-  "conv2d_sb_no_overlap.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv2d_sb_no_overlap =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
+                           {"filters=3", "kernel_size=2,2", "stride=3,3"},
+                           "1:2:5:5", "conv2d_sb_no_overlap.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_mb_no_overlap =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -123,12 +128,13 @@ auto conv2d_mb_no_overlap =
                              "stride=3,3",
                            },
                            "3:2:5:5", "conv2d_mb_no_overlap.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
-auto conv2d_sb_1x1_kernel = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::Conv2DLayer>,
-  {"filters=3", "kernel_size=1,1", "stride=2,2"}, "1:2:5:5",
-  "conv2d_sb_1x1_kernel.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto conv2d_sb_1x1_kernel =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
+                           {"filters=3", "kernel_size=1,1", "stride=2,2"},
+                           "1:2:5:5", "conv2d_sb_1x1_kernel.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_mb_1x1_kernel =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -138,7 +144,7 @@ auto conv2d_mb_1x1_kernel =
                              "stride=2,2",
                            },
                            "3:2:5:5", "conv2d_mb_1x1_kernel.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_sb_dilation =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -148,7 +154,7 @@ auto conv2d_sb_dilation =
                              "dilation=2,2",
                            },
                            "1:3:11:11", "conv2d_sb_dilation.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_mb_dilation =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -158,7 +164,7 @@ auto conv2d_mb_dilation =
                              "dilation=2,2",
                            },
                            "3:3:11:11", "conv2d_mb_dilation.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_sb_same_dilation =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -169,7 +175,7 @@ auto conv2d_sb_same_dilation =
                              "dilation=2,2",
                            },
                            "1:3:11:11", "conv2d_sb_same_dilation.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto conv2d_mb_same_dilation =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::Conv2DLayer>,
@@ -180,7 +186,7 @@ auto conv2d_mb_same_dilation =
                              "dilation=2,2",
                            },
                            "3:3:11:11", "conv2d_mb_same_dilation.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(
   Convolution2D, LayerGoldenTest,
index 3155002..06efbcb 100644 (file)
@@ -33,21 +33,23 @@ auto dropout_20_training = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::DropOutLayer>, {"dropout_rate=0.2"},
   "2:3:2:3", "dropout_20_training.nnlayergolden",
   LayerGoldenTestParamOptions::DEFAULT |
-    LayerGoldenTestParamOptions::DROPOUT_MATCH_60_PERCENT);
+    LayerGoldenTestParamOptions::DROPOUT_MATCH_60_PERCENT,
+  "nchw");
 
 auto dropout_20_inference = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::DropOutLayer>, {"dropout_rate=0.2"},
-  "2:3:2:3", "dropout_20_inference.nnlayergolden", dropout_inference_option);
+  "2:3:2:3", "dropout_20_inference.nnlayergolden", dropout_inference_option,
+  "nchw");
 
 auto dropout_0_training = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::DropOutLayer>, {"dropout_rate=0.0"},
   "2:3:2:3", "dropout_0_training.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto dropout_100_training = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::DropOutLayer>, {"dropout_rate=1.0"},
   "2:3:2:3", "dropout_100_training.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(Dropout, LayerGoldenTest,
                      ::testing::Values(dropout_20_training, dropout_0_training,
index df929c9..c25168a 100644 (file)
@@ -27,15 +27,16 @@ GTEST_PARAMETER_TEST(FullyConnected, LayerSemantics,
 
 auto fc_basic_plain = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::FullyConnectedLayer>, {"unit=5"},
-  "3:1:1:10", "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "3:1:1:10", "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 auto fc_basic_single_batch = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::FullyConnectedLayer>, {"unit=4"},
   "1:1:1:10", "fc_single_batch.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 auto fc_basic_no_decay = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::FullyConnectedLayer>,
   {"unit=5", "weight_decay=0.0", "bias_decay=0.0"}, "3:1:1:10",
-  "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "fc_plain.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(FullyConnected, LayerGoldenTest,
                      ::testing::Values(fc_basic_plain, fc_basic_single_batch,
index e856b4a..f61d2a6 100644 (file)
@@ -26,81 +26,82 @@ GTEST_PARAMETER_TEST(GRU, LayerSemantics, ::testing::Values(semantic_gru));
 auto gru_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "integrate_bias=true", "reset_after=false"}, "3:1:1:7",
-  "gru_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "gru_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto gru_multi_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "integrate_bias=true", "reset_after=false"}, "3:1:4:7",
-  "gru_multi_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "gru_multi_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto gru_single_step_seq =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::GRULayer>,
                            {"unit=5", "return_sequences=true",
                             "integrate_bias=true", "reset_after=false"},
                            "3:1:1:7", "gru_single_step_seq.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto gru_multi_step_seq =
   LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::GRULayer>,
                            {"unit=5", "return_sequences=true",
                             "integrate_bias=true", "reset_after=false"},
                            "3:1:4:7", "gru_multi_step_seq.nnlayergolden",
-                           LayerGoldenTestParamOptions::DEFAULT);
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto gru_multi_step_seq_act_orig = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "hidden_state_activation=tanh",
    "recurrent_activation=sigmoid", "integrate_bias=true", "reset_after=false"},
   "3:1:4:7", "gru_multi_step_seq.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto gru_multi_step_seq_act = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "hidden_state_activation=sigmoid",
    "recurrent_activation=tanh", "integrate_bias=true", "reset_after=false"},
   "3:1:4:7", "gru_multi_step_seq_act.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 // Check reset_after
 auto gru_reset_after_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "integrate_bias=false", "reset_after=true"}, "3:1:1:7",
   "gru_reset_after_single_step.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto gru_reset_after_multi_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "integrate_bias=false", "reset_after=true"}, "3:1:4:7",
   "gru_reset_after_multi_step.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto gru_reset_after_single_step_seq = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "integrate_bias=false",
    "reset_after=true"},
   "3:1:1:7", "gru_reset_after_single_step_seq.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto gru_reset_after_multi_step_seq = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "integrate_bias=false",
    "reset_after=true"},
   "3:1:4:7", "gru_reset_after_multi_step_seq.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto gru_reset_after_multi_step_seq_act_orig = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "hidden_state_activation=tanh",
    "recurrent_activation=sigmoid", "integrate_bias=false", "reset_after=true"},
   "3:1:4:7", "gru_reset_after_multi_step_seq.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto gru_reset_after_multi_step_seq_act = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRULayer>,
   {"unit=5", "return_sequences=true", "hidden_state_activation=sigmoid",
    "recurrent_activation=tanh", "integrate_bias=false", "reset_after=true"},
   "3:1:4:7", "gru_reset_after_multi_step_seq_act.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(
   GRU, LayerGoldenTest,
index 6b77111..25bfdb4 100644 (file)
@@ -28,20 +28,21 @@ GTEST_PARAMETER_TEST(GRUCell, LayerSemantics,
 auto grucell_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRUCellLayer>,
   {"unit=5", "integrate_bias=true", "reset_after=false"}, "3:1:1:7,3:1:1:5",
-  "grucell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "grucell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto grucell_reset_after_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRUCellLayer>,
   {"unit=5", "integrate_bias=false", "reset_after=true"}, "3:1:1:7,3:1:1:5",
   "grucell_reset_after_single_step.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto grucell_single_step_act = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::GRUCellLayer>,
   {"unit=5", "integrate_bias=true", "reset_after=false",
    "hidden_state_activation=sigmoid", "recurrent_activation=tanh"},
   "3:1:1:7,3:1:1:5", "grucell_single_step_act.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(GRUCell, LayerGoldenTest,
                      ::testing::Values(grucell_single_step,
index eac719b..cb34c13 100644 (file)
@@ -26,32 +26,38 @@ GTEST_PARAMETER_TEST(LayerNormalization, LayerSemantics,
 
 auto ln_axis_1 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1"},
-  "2:4:2:3", "ln_axis_1.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "2:4:2:3", "ln_axis_1.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto ln_axis_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=2"},
-  "2:4:2:3", "ln_axis_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "2:4:2:3", "ln_axis_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto ln_axis_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=3"},
-  "2:4:2:3", "ln_axis_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "2:4:2:3", "ln_axis_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto ln_axis_1_2 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 2"},
-  "2:4:2:3", "ln_axis_1_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "2:4:2:3", "ln_axis_1_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto ln_axis_2_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=2, 3"},
-  "2:4:2:3", "ln_axis_2_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "2:4:2:3", "ln_axis_2_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto ln_axis_1_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 3"},
-  "2:4:2:3", "ln_axis_1_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "2:4:2:3", "ln_axis_1_3.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto ln_axis_1_2_3 = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LayerNormalizationLayer>, {"axis=1, 2, 3"},
   "2:4:2:3", "ln_axis_1_2_3.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(LayerNormalization, LayerGoldenTest,
                      ::testing::Values(ln_axis_1, ln_axis_2, ln_axis_3,
index 16948a6..a8234bf 100644 (file)
@@ -23,39 +23,42 @@ auto semantic_lstm = LayerSemanticsParamType(
 
 GTEST_PARAMETER_TEST(LSTM, LayerSemantics, ::testing::Values(semantic_lstm));
 
-auto lstm_single_step = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::LSTMLayer>,
-  {"unit=5", "integrate_bias=true"}, "3:1:1:7",
-  "lstm_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto lstm_single_step =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::LSTMLayer>,
+                           {"unit=5", "integrate_bias=true"}, "3:1:1:7",
+                           "lstm_single_step.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto lstm_multi_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMLayer>,
   {"unit=5", "integrate_bias=true"}, "3:1:4:7", "lstm_multi_step.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto lstm_single_step_seq = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMLayer>,
   {"unit=5", "integrate_bias=true", "return_sequences=true"}, "3:1:1:7",
-  "lstm_single_step_seq.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "lstm_single_step_seq.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto lstm_multi_step_seq = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMLayer>,
   {"unit=5", "integrate_bias=true", "return_sequences=true"}, "3:1:4:7",
-  "lstm_multi_step_seq.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "lstm_multi_step_seq.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto lstm_multi_step_seq_act_orig = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMLayer>,
   {"unit=5", "integrate_bias=true", "return_sequences=true",
    "hidden_state_activation=tanh", "recurrent_activation=sigmoid"},
   "3:1:4:7", "lstm_multi_step_seq.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto lstm_multi_step_seq_act = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMLayer>,
   {"unit=5", "integrate_bias=true", "return_sequences=true",
    "hidden_state_activation=sigmoid", "recurrent_activation=tanh"},
   "3:1:4:7", "lstm_multi_step_seq_act.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(LSTM, LayerGoldenTest,
                      ::testing::Values(lstm_single_step, lstm_multi_step,
index 5f3d755..a7387b3 100644 (file)
@@ -27,7 +27,8 @@ GTEST_PARAMETER_TEST(LSTMCell, LayerSemantics,
 auto lstmcell_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::LSTMCellLayer>,
   {"unit=5", "integrate_bias=true"}, "3:1:1:7,3:1:1:5,3:1:1:5",
-  "lstmcell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "lstmcell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 GTEST_PARAMETER_TEST(LSTMCell, LayerGoldenTest,
                      ::testing::Values(lstmcell_single_step));
index b5b19b5..6f0aecf 100644 (file)
@@ -37,12 +37,13 @@ auto multi_head_attention_single_batch = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3"}, "1:1:5:7,1:1:3:7,1:1:3:7",
   "multi_head_attention_single_batch.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto multi_head_attention = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3"}, "2:1:5:7,2:1:3:7,2:1:3:7",
-  "multi_head_attention.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "multi_head_attention.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 auto multi_head_attention_return_attention_scores = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
@@ -50,19 +51,19 @@ auto multi_head_attention_return_attention_scores = LayerGoldenTestParamType(
    "average_attention_weight=false"},
   "2:1:5:7,2:1:3:7,2:1:3:7",
   "multi_head_attention_return_attention_scores.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto multi_head_attention_value_dim = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3", "projected_value_dim=5"},
   "2:1:5:7,2:1:3:7,2:1:3:7", "multi_head_attention_value_dim.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto multi_head_attention_output_shape = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::MultiHeadAttentionLayer>,
   {"num_heads=2", "projected_key_dim=3", "output_shape=5"},
   "2:1:5:7,2:1:3:7,2:1:3:7", "multi_head_attention_output_shape.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(
   MultiHeadAttention, LayerGoldenTest,
index 728de02..95ce618 100644 (file)
@@ -28,12 +28,12 @@ INSTANTIATE_TEST_CASE_P(PositionalEncoding, LayerSemantics,
 auto positional_encoding_partial = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::PositionalEncodingLayer>,
   {"max_timestep=10"}, "3:1:7:6", "positional_encoding_partial.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 auto positional_encoding = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::PositionalEncodingLayer>,
   {"max_timestep=10"}, "3:1:10:6", "positional_encoding.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 INSTANTIATE_TEST_CASE_P(PositionalEncoding, LayerGoldenTest,
                         ::testing::Values(positional_encoding_partial,
index a8aba49..9c32b5d 100644 (file)
@@ -26,6 +26,7 @@ GTEST_PARAMETER_TEST(RNN, LayerSemantics, ::testing::Values(semantic_rnn));
 auto rnn_single_step = LayerGoldenTestParamType(
   nntrainer::createLayer<nntrainer::RNNLayer>,
   {"unit=5", "return_sequences=false", "integrate_bias=true"}, "3:1:1:7",
-  "rnn_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+  "rnn_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT,
+  "nchw");
 
 GTEST_PARAMETER_TEST(RNN, LayerGoldenTest, ::testing::Values(rnn_single_step));
index 4cee9ea..97c0ae1 100644 (file)
@@ -24,10 +24,11 @@ auto semantic_rnncell = LayerSemanticsParamType(
 GTEST_PARAMETER_TEST(RNNCell, LayerSemantics,
                      ::testing::Values(semantic_rnncell));
 
-auto rnncell_single_step = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::RNNCellLayer>,
-  {"unit=5", "integrate_bias=true"}, "3:1:1:7,3:1:1:5",
-  "rnncell_single_step.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT);
+auto rnncell_single_step =
+  LayerGoldenTestParamType(nntrainer::createLayer<nntrainer::RNNCellLayer>,
+                           {"unit=5", "integrate_bias=true"}, "3:1:1:7,3:1:1:5",
+                           "rnncell_single_step.nnlayergolden",
+                           LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 GTEST_PARAMETER_TEST(RNNCell, LayerGoldenTest,
                      ::testing::Values(rnncell_single_step));
index 91a18b3..f291486 100644 (file)
@@ -30,7 +30,7 @@ auto zoneout_lstmcell_single_step = LayerGoldenTestParamType(
   {"unit=5", "integrate_bias=true", "hidden_state_zoneout_rate=0.1",
    "cell_state_zoneout_rate=0.0"},
   "3:1:1:7,3:1:1:5,3:1:1:5", "zoneout_lstmcell_single_step.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT);
+  LayerGoldenTestParamOptions::DEFAULT, "nchw");
 
 INSTANTIATE_TEST_CASE_P(LSTMCell, LayerGoldenTest,
                         ::testing::Values(zoneout_lstmcell_single_step));
index a805568..de9afc3 100644 (file)
@@ -50,6 +50,9 @@ static nntrainer::IniSection relu_base = act_base + "Activation = relu";
 static nntrainer::IniSection bn_base("bn", "Type=batch_normalization");
 static nntrainer::IniSection sgd_base("optimizer", "Type = sgd");
 
+static nntrainer::IniSection nn_base_nhwc = nn_base + "tensor_type=NHWC";
+static nntrainer::IniSection nn_base_nchw = nn_base + "tensor_type=NCHW";
+
 using I = nntrainer::IniSection;
 using INI = nntrainer::IniWrapper;
 
@@ -92,7 +95,7 @@ using INI = nntrainer::IniWrapper;
 
 INI fc_sigmoid_baseline(
   "fc_sigmoid",
-  {nn_base + "batch_size = 3",
+  {nn_base_nchw + "batch_size = 3",
    sgd_base + "learning_rate = 1",
    I("input") + input_base + "input_shape = 1:1:3",
    I("dense") + fc_base + "unit = 5",