[batchsize] Change the semantics of batch size
authorParichay Kapoor <pk.kapoor@samsung.com>
Thu, 24 Sep 2020 05:09:50 +0000 (14:09 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Sun, 4 Oct 2020 23:19:00 +0000 (08:19 +0900)
This patch changes the semantics of the way batchsize is used in the library
1. batch_size is no longer property of the layer. It can still be set externally by the model.
The method to set is still public but will soon be changed to private.
2. batch_size of the input/label/derivative tensor provided to the forwarding/backwarding function
can no longer be arbitrary. It must match the batch_size set to the model and the layer.
This change in semantics to follow for a long-term design where memory for input/output
is pre-allocated.
3. batch_size can now be set at train time than the earlier design where the batch_size
had to be set at init time. This comes from the design change that the memory for the model weights
can be allocated init time, which is not dependent on the batch size. However, the memory for
input/output should be allocated at train/inference as the batch size can be different for these.
In the current design, memory is allocated every iteration. Later, when memory is allocated at once
and reused, change in batch size will change the memory at once at the first iteration (rather doing
this at init times). This change is necessary to allow running inference without the need to initialize
the model again.

V2:
Updated validation to run on a whole batch of data at once
Also updated Tensor.argmax() to perform on batch of data than on the whole data

**Self evaluation:**
1. Build test: [x]Passed [ ]Failed [ ]Skipped
2. Run test: [x]Passed [ ]Failed [ ]Skipped

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
13 files changed:
Applications/MNIST/Tensorflow/model.bin [new file with mode: 0644]
nntrainer/include/layer.h
nntrainer/include/neuralnet.h
nntrainer/include/tensor.h
nntrainer/src/addition_layer.cpp
nntrainer/src/conv2d_layer.cpp
nntrainer/src/flatten_layer.cpp
nntrainer/src/layer.cpp
nntrainer/src/neuralnet.cpp
nntrainer/src/parse_util.cpp
nntrainer/src/pooling2d_layer.cpp
nntrainer/src/tensor.cpp
test/unittest/unittest_nntrainer_layers.cpp

diff --git a/Applications/MNIST/Tensorflow/model.bin b/Applications/MNIST/Tensorflow/model.bin
new file mode 100644 (file)
index 0000000..9281731
Binary files /dev/null and b/Applications/MNIST/Tensorflow/model.bin differ
index 1e5246b28a95b9e616b58d49f7205ec3301b99d3..a2233d6e65a9b428be99ec3715dbcf8cfec56537 100644 (file)
@@ -183,12 +183,11 @@ public:
    *            17. name : string (type)
    *            18. num_inputs : unsigned int (minimum 1)
    *            19. num_outputs : unsigned int (minimum 1)
-   *            20. batch_size : unsigned int (minimum 1)
-   *            21. momentum : float,
-   *            22. moving_mean_initializer : string (type),
-   *            23. moving_variance_initializer : string (type),
-   *            24. gamma_initializer : string (type),
-   *            25. beta_initializer" : string (type)
+   *            20. momentum : float,
+   *            21. moving_mean_initializer : string (type),
+   *            22. moving_variance_initializer : string (type),
+   *            23. gamma_initializer : string (type),
+   *            24. beta_initializer" : string (type)
    */
   enum class PropertyType {
     input_shape = 0,
@@ -211,12 +210,11 @@ public:
     name = 17,
     num_inputs = 18,
     num_outputs = 19,
-    batch_size = 20,
-    momentum = 21,
-    moving_mean_initializer = 22,
-    moving_variance_initializer = 23,
-    gamma_initializer = 24,
-    beta_initializer = 25,
+    momentum = 20,
+    moving_mean_initializer = 21,
+    moving_variance_initializer = 22,
+    gamma_initializer = 23,
+    beta_initializer = 24,
     unknown
   };
 
@@ -315,8 +313,7 @@ public:
   /**
    * @brief Set the batch for the layer
    * @param batch Batch value to be set
-   * @note This denotes the maximum batch size of input. The actual batchsize
-   * of the data can be smaller in case of validation or testing
+   * @todo Make this private. Only model should be able to do this.
    */
   void setBatch(unsigned int batch) {
     input_dim.setTensorDim(0, batch);
index 32b3f8b47bfee31800c9afaf4bb59a021e07dddc..4035c68a9d4d14b03679f14d6184a2c3b26a8bf2 100644 (file)
@@ -422,6 +422,11 @@ private:
    * @retval #ML_ERROR_INVALID_PARAMETER invalid parameter.
    */
   int setTrainConfig(std::vector<std::string> values);
+
+  /**
+   * @brief     Update batch size of the model as well as its layers/dataset
+   */
+  void setBatchSize(unsigned int batch_size);
 };
 
 } /* namespace nntrainer */
index 6d1a44ae9aecc703d469d33bd145884e9439526a..260df2de04142ebf48d44394d5c4e017a55fae7b 100644 (file)
@@ -508,10 +508,10 @@ public:
   void read(std::ifstream &file);
 
   /**
-   * @brief     return argument index which value is max
+   * @brief     return argument index which value is max by batch
    * @retval    unsigned int argument index
    */
-  unsigned int argmax() const;
+  std::vector<unsigned int> argmax() const;
 
   /**
    * @brief     return a copy of the Tensor Dim
index 6da5a55b976f04ba0a038bc4a5711dd618fea56a..bb73ea838a926e2934516495d307f3d800a45769 100644 (file)
@@ -38,8 +38,7 @@ int AdditionLayer::initialize() {
 }
 
 sharedConstTensor AdditionLayer::forwarding(sharedConstTensor in) {
-  hidden = Tensor(in->batch(), output_dim.channel(), output_dim.height(),
-                  output_dim.width());
+  hidden = Tensor(input_dim);
   hidden.setZero();
 
   for (unsigned int idx = 0; idx < num_inputs; ++idx) {
index 8a7590cdce97ca57a117471daf61a3a7b74d45ae..6cecc332e7d1036f384a7e404842c51e2c5c7786 100644 (file)
@@ -77,7 +77,6 @@ sharedConstTensor Conv2DLayer::forwarding(sharedConstTensor in) {
   }
 
   TensorDim hidden_dim = output_dim;
-  hidden_dim.batch(in->batch());
   hidden = Tensor(hidden_dim);
   hidden.setZero();
 
@@ -130,7 +129,7 @@ sharedConstTensor Conv2DLayer::forwarding(sharedConstTensor in) {
            kdim.getFeatureLen() * sizeof(float));
   }
 
-  for (unsigned int b = 0; b < input.batch(); ++b) {
+  for (unsigned int b = 0; b < input_dim.batch(); ++b) {
     std::vector<float> out(output_dim.getFeatureLen());
     Tensor inSub(TensorDim(1, input.channel(), input.height(), input.width()),
                  input.getAddress(b * input.getDim().getFeatureLen()));
index 75daa62289adc990f50e7c4c3ba1f2e0c985142d..ff675f3515ec5862df92ee63307c738e854da061 100644 (file)
@@ -38,9 +38,7 @@ sharedConstTensor FlattenLayer::forwarding(sharedConstTensor in) {
   input = *in;
   hidden = input;
 
-  /// @note in->batch can be different from input_dim.batch();
-  hidden.reshape({in->batch(), output_dim.channel(), output_dim.height(),
-                  output_dim.width()});
+  hidden.reshape(output_dim);
 
   return MAKE_SHARED_TENSOR(hidden);
 }
index e3849b778e63722beebd7d36ec004985630a069b..7f1a8984d70a82451ce12444bbae4c1b96a16e3d 100644 (file)
@@ -136,14 +136,6 @@ void Layer::setProperty(const PropertyType type, const std::string &value) {
       throw_status(status);
     }
     break;
-  case PropertyType::batch_size:
-    if (!value.empty()) {
-      unsigned int batch_size;
-      status = setUint(batch_size, value);
-      throw_status(status);
-      input_dim.batch(batch_size);
-    }
-    break;
   case PropertyType::activation:
     if (!value.empty()) {
       status = setActivation((ActivationType)parseType(value, TOKEN_ACTI));
index d010806099b4d9eb9b557ff7c046684280fc97e6..8e34598b374f835c777937ae975c354858844f43 100644 (file)
@@ -121,14 +121,6 @@ int NeuralNetwork::setProperty(std::vector<std::string> values) {
     unsigned int type = parseNetProperty(key);
 
     switch (static_cast<PropertyType>(type)) {
-    case PropertyType::batch_size: {
-      status = setUint(batch_size, value);
-      NN_RETURN_STATUS();
-      /** TODO: increase buffer size if it is smaller than batch size.
-       * also if this is set with default batch size, then make it
-       * smaller/larger
-       */
-    } break;
     case PropertyType::loss: {
       status = setFloat(loss, value);
       NN_RETURN_STATUS();
@@ -172,6 +164,14 @@ int NeuralNetwork::setTrainConfig(std::vector<std::string> values) {
       continue_train = cont_train;
       opt.setProperty({values[i]});
     } break;
+    case PropertyType::batch_size: {
+      status = setUint(batch_size, value);
+      NN_RETURN_STATUS();
+      /** TODO: increase buffer size if it is smaller than batch size.
+       * also if this is set with default batch size, then make it
+       * smaller/larger
+       */
+    } break;
     default:
       ml_loge("Error: Unknown Network Property Key");
       status = ML_ERROR_INVALID_PARAMETER;
@@ -211,7 +211,6 @@ int NeuralNetwork::init() {
       }
     }
 
-    layers[i]->setBatch(batch_size);
     status = layers[i]->initialize();
 
     switch (l.getType()) {
@@ -366,7 +365,24 @@ void NeuralNetwork::readModel() {
   ml_logi("read modelfile: %s", save_path.c_str());
 }
 
+void NeuralNetwork::setBatchSize(unsigned int batch) {
+  batch_size = batch;
+  for (auto const &layer : layers)
+    layer->setBatch(batch_size);
+
+  if (data_buffer && data_buffer->setBatchSize(batch_size) != ML_ERROR_NONE)
+    throw std::invalid_argument("Error setting batchsize for the dataset");
+}
+
 sharedConstTensor NeuralNetwork::inference(const Tensor X) {
+  if (batch_size != X.batch()) {
+    /**
+     * Note that inference resets batch_size of the previous train configuration
+     * Next train must set its batch_size if inference is run with this model.
+     */
+    setBatchSize(X.batch());
+  }
+
   sharedConstTensor out;
   try {
     out = forwarding(MAKE_SHARED_TENSOR(X));
@@ -388,10 +404,13 @@ int NeuralNetwork::train(std::vector<std::string> values) {
     return ML_ERROR_INVALID_PARAMETER;
   }
 
-  /** Setup data buffer properties */
-  status = data_buffer->setBatchSize(batch_size);
+  status = setTrainConfig(values);
   NN_RETURN_STATUS();
 
+  /** set batch size just before training */
+  setBatchSize(batch_size);
+
+  /** Setup data buffer properties */
   status =
     data_buffer->setClassNum(layers.back()->getOutputDimension().width());
   NN_RETURN_STATUS();
@@ -402,9 +421,6 @@ int NeuralNetwork::train(std::vector<std::string> values) {
   status = data_buffer->init();
   NN_RETURN_STATUS();
 
-  status = setTrainConfig(values);
-  NN_RETURN_STATUS();
-
   return train_run();
 }
 
@@ -475,15 +491,17 @@ int NeuralNetwork::train_run() {
       while (true) {
         vec_4d in, label;
         if (data_buffer->getDataFromBuffer(nntrainer::BUF_VAL, in, label)) {
-          for (unsigned int b = 0; b < batch_size; ++b) {
-            sharedTensor X = MAKE_SHARED_TENSOR(Tensor({in[b]}));
-            sharedTensor Y2 = MAKE_SHARED_TENSOR(Tensor({label[b]}));
-            sharedConstTensor Y = forwarding(X, Y2);
-            if (Y->argmax() == Y2->argmax())
+          sharedTensor X = MAKE_SHARED_TENSOR(Tensor({in}));
+          sharedTensor Y2 = MAKE_SHARED_TENSOR(Tensor({label}));
+          sharedConstTensor Y = forwarding(X, Y2);
+          auto model_out = Y->argmax();
+          auto label_out = Y2->argmax();
+          for (unsigned int b = 0; b < batch_size; b++) {
+            if (model_out[b] == label_out[b])
               right++;
-            validation.loss += getLoss();
-            tcases++;
           }
+          validation.loss += getLoss();
+          tcases++;
         } else {
           data_buffer->clear(nntrainer::BUF_VAL);
           break;
@@ -496,7 +514,7 @@ int NeuralNetwork::train_run() {
         return status;
       }
       validation.loss /= (float)(tcases);
-      validation.accuracy = right / (float)(tcases)*100.0f;
+      validation.accuracy = right / (float)(tcases * batch_size) * 100.0f;
       std::cout << " >> [ Accuracy: " << validation.accuracy
                 << "% - Validation Loss : " << validation.loss << " ] ";
     }
index 17169f1ef943620a2863820ff245033fcd9d8436..de78a361f49f107d1539b2c162390057712c2e56 100644 (file)
@@ -270,12 +270,11 @@ unsigned int parseType(std::string ll, InputType t) {
  * name = 17
  * num_inputs = 18
  * num_outputs = 19
- * batch_size = 20
- * momentum = 21
- * moving_mean_initializer = 22
- * moving_variance_initializer = 23
- * gamma_initializer = 24
- * beta_initializer = 25
+ * momentum = 20
+ * moving_mean_initializer = 21
+ * moving_variance_initializer = 22
+ * gamma_initializer = 23
+ * beta_initializer = 24
  *
  * InputLayer has 0, 1, 2, 3 properties.
  * FullyConnectedLayer has 1, 4, 6, 7, 8, 9 properties.
@@ -283,7 +282,7 @@ unsigned int parseType(std::string ll, InputType t) {
  * Pooling2DLayer has 12, 13, 14, 15 properties.
  * BatchNormalizationLayer has 0, 1, 5, 6, 7 properties.
  */
-static std::array<std::string, 27> property_string = {
+static std::array<std::string, 26> property_string = {
   "input_shape",
   "normalization",
   "standardization",
@@ -304,7 +303,6 @@ static std::array<std::string, 27> property_string = {
   "name",
   "num_inputs",
   "num_outputs",
-  "batch_size",
   "momentum",
   "moving_mean_initializer",
   "moving_variance_initializer",
index 2ef813170e32d42621aee5bf3527f98eb3267115..d17801062da5b8458f48c83ac60c91c457a32eda 100644 (file)
@@ -59,11 +59,10 @@ sharedConstTensor Pooling2DLayer::forwarding(sharedConstTensor in) {
   input = *in;
 
   TensorDim hidden_dim = output_dim;
-  hidden_dim.batch(in->batch());
   hidden = Tensor(hidden_dim);
   hidden.setZero();
 
-  for (unsigned int b = 0; b < input.batch(); ++b) {
+  for (unsigned int b = 0; b < input_dim.batch(); ++b) {
     Tensor in_padded = zero_pad(b, input, padding);
     Tensor result = pooling2d(b, in_padded);
     memcpy(hidden.getAddress(b * hidden.getDim().getFeatureLen()),
index 4674910dc1fe0f0a24564ccebc363b3faa1ee338..16c1378403f3d217d8317efee5fa11f45df01563 100644 (file)
@@ -706,10 +706,21 @@ void Tensor::setValue(float val) {
 
 void Tensor::setZero() { setValue(0); }
 
-unsigned int Tensor::argmax() const {
+std::vector<unsigned int> Tensor::argmax() const {
   const float *data = getData();
-  auto max_iter = std::max_element(data, data + length());
-  return std::distance(data, max_iter);
+  std::vector<unsigned int> result;
+  unsigned int batch_size = batch();
+  unsigned int feature_len = dim.getFeatureLen();
+
+  result.reserve(batch_size);
+
+  for (unsigned int b = 0; b < batch_size; b++) {
+    auto max_iter =
+      std::max_element(data + b * feature_len, data + (b + 1) * feature_len);
+    result[b] = std::distance(data, max_iter);
+  }
+
+  return result;
 }
 
 float Tensor::l2norm() const {
index 44d1fd916a21166ce00adf96b03ce52211ec978d..97678b6315d898fd361b2a853094f15f77a24789 100644 (file)
@@ -48,10 +48,11 @@ protected:
     return status;
   }
 
-  virtual int reinitialize(const std::string str) {
+  virtual int reinitialize(const std::string str, int batch_size = 1) {
     resetLayer();
     int status = setProperty(str);
     EXPECT_EQ(status, ML_ERROR_NONE);
+    setBatch(batch_size);
     status = reinitialize();
     EXPECT_EQ(status, ML_ERROR_NONE);
     return status;
@@ -69,6 +70,8 @@ protected:
     layer.setInputDimension(dim);
   }
 
+  void setBatch(unsigned int batch) { layer.setBatch(batch); }
+
   void matchOutput(const nntrainer::Tensor &result,
                    const nntrainer::Tensor &golden) {
     const float *out_ptr, *golden_ptr;
@@ -175,7 +178,7 @@ class nntrainer_InputLayer
 protected:
   virtual void prepareLayer() {
     setInputDim("3:28:28");
-    setProperty("batch_size=1");
+    setBatch(1);
   }
 };
 
@@ -234,7 +237,7 @@ TEST_F(nntrainer_InputLayer, set_property_05_p) {
   nntrainer::TensorDim dim;
   int status = ML_ERROR_NONE;
 
-  status = setProperty("batch_size=5");
+  setBatch(5);
   EXPECT_EQ(status, ML_ERROR_NONE);
 
   dim = layer.getInputDimension();
@@ -308,7 +311,7 @@ class nntrainer_FullyConnectedLayer
 protected:
   virtual void prepareLayer() {
     setInputDim("1:28:28");
-    setProperty("batch_size=32");
+    setBatch(5);
     setProperty("unit=1");
   }
 };
@@ -538,7 +541,7 @@ protected:
 
   virtual void prepareLayer() {
     setInputDim("1:1:12");
-    setProperty("batch_size=3");
+    setBatch(3);
     setProperty("unit=15");
     setProperty("bias_initializer=zeros");
   }
@@ -801,8 +804,8 @@ protected:
   }
 
   virtual void prepareLayer() {
-    setProperty(
-      "input_shape=1:1:12 | epsilon=0.001 | batch_size=3 | momentum=0.90");
+    setProperty("input_shape=1:1:12 | epsilon=0.001 | momentum=0.90");
+    setBatch(3);
     setOptimizer(nntrainer::OptType::sgd, "learning_rate=1");
   }
 };
@@ -881,8 +884,8 @@ protected:
   }
 
   virtual void prepareLayer() {
-    setProperty(
-      "input_shape=2:4:5 | epsilon=0.001 | batch_size=3 | momentum=0.90");
+    setProperty("input_shape=2:4:5 | epsilon=0.001 | momentum=0.90");
+    setBatch(3);
     setOptimizer(nntrainer::OptType::sgd, "learning_rate=1");
   }
 };
@@ -916,8 +919,8 @@ protected:
   }
 
   virtual void prepareLayer() {
-    setProperty(
-      "input_shape=2:4:5 | epsilon=0.001 | batch_size=1 | momentum=0.90");
+    setProperty("input_shape=2:4:5 | epsilon=0.001 | momentum=0.90");
+    setBatch(1);
     setOptimizer(nntrainer::OptType::sgd, "learning_rate=1");
   }
 };
@@ -947,7 +950,7 @@ protected:
 
   virtual void prepareLayer() {
     int status =
-      setProperty("input_shape=3:28:28 | batch_size=32 |"
+      setProperty("input_shape=3:28:28 |"
                   "bias_initializer=zeros |"
                   "activation=sigmoid |"
                   "weight_regularizer=l2norm |"
@@ -957,6 +960,7 @@ protected:
                   "filters=12 | kernel_size= 5,5 | stride=3,3 | padding=1,1");
 
     EXPECT_EQ(status, ML_ERROR_NONE);
+    setBatch(32);
   }
 
   nntrainer::Tensor result;
@@ -1014,7 +1018,7 @@ TEST_F(nntrainer_Conv2DLayer, save_read_01_p) {
  * @brief Convolution 2D Layer
  */
 TEST_F(nntrainer_Conv2DLayer, forwarding_01_p) {
-  reinitialize("input_shape=3:7:7 | batch_size=1 |"
+  reinitialize("input_shape=3:7:7 |"
                "bias_initializer = zeros |"
                "weight_initializer=xavier_uniform |"
                "filters=2 | kernel_size=3,3 | stride=1, 1 | padding=0,0");
@@ -1035,10 +1039,11 @@ TEST_F(nntrainer_Conv2DLayer, forwarding_01_p) {
 
 TEST_F(nntrainer_Conv2DLayer, forwarding_02_p) {
   status =
-    reinitialize("input_shape=3:7:7 | batch_size=2 |"
+    reinitialize("input_shape=3:7:7 |"
                  "bias_initializer = zeros |"
                  "weight_initializer=xavier_uniform |"
-                 "filters=3 | kernel_size=3,3 | stride=1, 1 | padding=0,0");
+                 "filters=3 | kernel_size=3,3 | stride=1, 1 | padding=0,0",
+                 2);
 
   ASSERT_EQ(in.getDim(), nntrainer::TensorDim(2, 3, 7, 7));
   ASSERT_EQ(out.getDim(), nntrainer::TensorDim(2, 3, 5, 5));
@@ -1051,7 +1056,7 @@ TEST_F(nntrainer_Conv2DLayer, forwarding_02_p) {
 }
 
 TEST_F(nntrainer_Conv2DLayer, backwarding_01_p) {
-  status = reinitialize("input_shape=3:7:7 | batch_size=1 |"
+  status = reinitialize("input_shape=3:7:7 |"
                         "bias_initializer=zeros |"
                         "weight_initializer=xavier_uniform |"
                         "filters=2 |"
@@ -1104,7 +1109,7 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_01_p) {
 }
 
 TEST_F(nntrainer_Conv2DLayer, backwarding_04_p) {
-  status = reinitialize("input_shape=6:24:24 | batch_size=1 |"
+  status = reinitialize("input_shape=6:24:24 |"
                         "bias_initializer=zeros |"
                         "weight_initializer=xavier_uniform |"
                         "filters=12 |"
@@ -1157,13 +1162,14 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_04_p) {
 }
 
 TEST_F(nntrainer_Conv2DLayer, backwarding_02_p) {
-  status = reinitialize("input_shape=3:7:7 | batch_size=2 |"
+  status = reinitialize("input_shape=3:7:7 |"
                         "bias_initializer=zeros |"
                         "weight_initializer=xavier_uniform |"
                         "filters=3 |"
                         "kernel_size= 3,3 |"
                         "stride=1, 1 |"
-                        "padding=0,0");
+                        "padding=0,0",
+                        2);
 
   setOptimizer(nntrainer::OptType::sgd, "learning_rate=1.0");
 
@@ -1241,7 +1247,7 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_02_p) {
 
 #ifdef USE_BLAS
 TEST_F(nntrainer_Conv2DLayer, backwarding_03_p) {
-  status = reinitialize("input_shape=3:28:28 | batch_size=1 |"
+  status = reinitialize("input_shape=3:28:28 |"
                         "bias_initializer=zeros |"
                         "weight_initializer=zeros |"
                         "filters=6 |"
@@ -1250,19 +1256,21 @@ TEST_F(nntrainer_Conv2DLayer, backwarding_03_p) {
                         "padding=0, 0");
 
   nntrainer::Conv2DLayer layer1;
-  status = layer1.setProperty(
-    {"input_shape=3:28:28", "batch_size=1", "bias_initializer=zeros",
-     "weight_initializer=zeros", "filters=6", "kernel_size= 5,5", "stride=1, 1",
-     "padding=0, 0"});
+  status =
+    layer1.setProperty({"input_shape=3:28:28", "bias_initializer=zeros",
+                        "weight_initializer=zeros", "filters=6",
+                        "kernel_size= 5,5", "stride=1, 1", "padding=0, 0"});
   EXPECT_EQ(status, ML_ERROR_NONE);
+  layer1.setBatch(1);
   status = layer1.initialize();
   EXPECT_EQ(status, ML_ERROR_NONE);
 
   nntrainer::Conv2DLayer layer2;
   status = layer2.setProperty(
-    {"batch_size=1", "bias_initializer=zeros", "weight_initializer=zeros",
-     "filters=12", "kernel_size= 1,1", "stride=1, 1", "padding=0, 0"});
+    {"bias_initializer=zeros", "weight_initializer=zeros", "filters=12",
+     "kernel_size= 1,1", "stride=1, 1", "padding=0, 0"});
   EXPECT_EQ(status, ML_ERROR_NONE);
+  layer2.setBatch(1);
   layer2.setInputDimension(layer1.getOutputDimension());
   status = layer2.initialize();
   EXPECT_EQ(status, ML_ERROR_NONE);
@@ -1389,13 +1397,13 @@ protected:
 
 TEST_F(nntrainer_Pooling2DLayer, setProperty_01_p) {
   setInputDim("3:5:5");
-  setProperty("batch_size=2");
+  setBatch(2);
   setProperty("pool_size=2,2 | stride=1,1 | padding=0,0 | pooling=average");
 }
 
 TEST_F(nntrainer_Pooling2DLayer, setProperty_02_n) {
   setInputDim("3:5:5");
-  setProperty("batch_size=2");
+  setBatch(2);
   int status = layer.setProperty({"pool_size="});
   EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER);
 }
@@ -1457,7 +1465,7 @@ TEST_F(nntrainer_Pooling2DLayer, forwarding_04_p) {
 TEST_F(nntrainer_Pooling2DLayer, forwarding_05_p) {
   resetLayer();
   setInputDim("2:5:5");
-  setProperty("batch_size=2");
+  setBatch(2);
   setProperty("pooling=global_max");
   reinitialize();
 
@@ -1469,7 +1477,7 @@ TEST_F(nntrainer_Pooling2DLayer, forwarding_05_p) {
 TEST_F(nntrainer_Pooling2DLayer, forwarding_06_p) {
   resetLayer();
   setInputDim("2:5:5");
-  setProperty("batch_size=2");
+  setBatch(2);
   setProperty("pooling=global_average");
   reinitialize();
 
@@ -1566,7 +1574,7 @@ class nntrainer_FlattenLayer
 protected:
   virtual void prepareLayer() {
     setInputDim("2:4:4");
-    layer.setBatch(1);
+    setBatch(1);
   }
 };
 
@@ -1590,7 +1598,7 @@ TEST_F(nntrainer_FlattenLayer, forwarding_01_p) {
  */
 TEST_F(nntrainer_FlattenLayer, forwarding_02_p) {
   setInputDim("2:4:4");
-  layer.setBatch(2);
+  setBatch(2);
   reinitialize();
 
   EXPECT_EQ(out.getDim(), nntrainer::TensorDim(2, 1, 1, 32));
@@ -1623,7 +1631,7 @@ TEST_F(nntrainer_FlattenLayer, backwarding_01_p) {
  */
 TEST_F(nntrainer_FlattenLayer, backwarding_02_p) {
   setInputDim("2:4:4");
-  layer.setBatch(2);
+  setBatch(2);
   reinitialize();
 
   EXPECT_EQ(out.getDim(), nntrainer::TensorDim(2, 1, 1, 32));
@@ -1795,7 +1803,8 @@ class nntrainer_AdditionLayer
   : public nntrainer_abstractLayer<nntrainer::AdditionLayer> {
 protected:
   virtual void prepareLayer() {
-    setInputDim("32:3:28:28");
+    setInputDim("3:28:28");
+    setBatch(32);
     setProperty("num_inputs=1");
   }
 };
@@ -1837,7 +1846,7 @@ TEST_F(nntrainer_AdditionLayer, forwarding_01_n) {
 
   in = nntrainer::Tensor();
 
-  EXPECT_THROW(layer.forwarding(input), std::logic_error);
+  EXPECT_THROW(layer.forwarding(input), std::runtime_error);
 }
 
 TEST_F(nntrainer_AdditionLayer, forwarding_02_n) {