[Test] Refactor recorder.py
authorJihoon Lee <jhoon.it.lee@samsung.com>
Fri, 23 Oct 2020 06:36:39 +0000 (15:36 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 2 Nov 2020 06:40:51 +0000 (15:40 +0900)
This commit mainly patches `KerasRecorder` to be more flexible.

**Changes proposed in this PR:**
- Pass file and label info at the `KerasRecorder.run` phase instead of
  `__init__`, leaving room to reuse the model
- Allow initialization with a Sequential model for usability
- Handle cross_sigmoid and cross_softmax losses
- Move some functions out of the class

**V2**
Since the `KerasRecorder` class was tightly coupled to a particular
model, deriving variations from it was hard and error-prone (e.g. using
"mse" instead of "cross" would require writing a whole new class).
This patch moves the class implementation into several functions.

These functions are meant to be used with `functools.partial` to easily
generate loss and optimizer variations.
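
A rough sketch of the intended usage pattern (it mirrors the updated
genModelTests.py in this patch; the model, shapes, file name and
learning rate below are illustrative):

    from functools import partial

    import tensorflow as tf
    from tensorflow.python import keras as K

    from recorder import record

    # base model, built once
    fc_sigmoid = [
        K.Input(shape=(3, 3)),
        K.layers.Dense(5),
        K.layers.Activation("sigmoid"),
        K.layers.Dense(10),
        K.layers.Activation("softmax"),
    ]

    # bind the shared arguments once ...
    fc_sigmoid_tc = partial(
        record, model=fc_sigmoid, input_shape=(3, 3), label_shape=(3, 10), iteration=10
    )

    # ... then generate loss/optimizer variations from the same partial
    fc_sigmoid_tc(
        file_name="fc_sigmoid_mse_sgd.info",
        loss_fn_str="mse",
        optimizer=tf.keras.optimizers.SGD(learning_rate=1.0),
    )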

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
[Test/Refactor] Restructure data format

Restructure the golden data to remove redundancy and check only the
updated weights, making the code more readable.
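
Roughly, the restructured golden file is laid out as follows (a sketch
of the field order inferred from prepare_data()/train_step() in
recorder.py below, not an exact spec):

    # per golden *.info file (sketch)
    # initial_input, label, initial weights of every layer   <- prepare_data()
    # repeated for each iteration:                           <- train_step()
    #     per layer: forward output, dx, weight gradients, updated weights
    #     loss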

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Jihoon Lee <jhoon.it.lee@samsung.com>
packaging/unittest_models.tar.gz
test/input_gen/genModelTests.py
test/input_gen/recorder.py
test/unittest/unittest_nntrainer_models.cpp

index 6382fa520eaf88f533d88f0a221c5ce150a37123..f4859badee38cc00b38021f634486640a2a6ebbe 100644 (file)
Binary files a/packaging/unittest_models.tar.gz and b/packaging/unittest_models.tar.gz differ
index 7c33c3713702bdecef3760c50efb17cf9e329d14..a000a770d8e2dcffa055b76baf48480ba3194254 100644 (file)
@@ -9,8 +9,9 @@
 # @author Jihoon lee <jhoon.it.lee@samsung.com>
 
 import warnings
+from functools import partial
 
-from recorder import KerasRecorder
+from recorder import record
 
 with warnings.catch_warnings():
     warnings.filterwarnings("ignore", category=FutureWarning)
@@ -18,36 +19,43 @@ with warnings.catch_warnings():
     import tensorflow as tf
     from tensorflow.python import keras as K
 
+
+opt = tf.keras.optimizers
+
 if __name__ == "__main__":
-    inp = K.Input(shape=(3, 3))
-    a = K.layers.Dense(5)(inp)
-    b = K.layers.Dense(5)(a)
-    c = K.layers.Dense(10)(b)
-    d = K.layers.Activation("softmax")(c)
-
-    KerasRecorder(
-        file_name="fc_softmax_mse.info",
-        inputs=inp,
-        outputs=[inp, a, b, c, d],
-        input_shape=(3, 3),
-        label_shape=(3, 10),
-        loss_fn=tf.keras.losses.MeanSquaredError(),
-    ).run(10)
-
-    inp = K.Input(shape=(3, 3))
-    a = K.layers.Dense(10)(inp)
-    b = K.layers.Activation("relu")(a)
-    c = K.layers.Dense(10)(b)
-    d = K.layers.Activation("relu")(c)
-    e = K.layers.Dense(2)(d)
-    f = K.layers.Activation("relu")(e)
-
-    KerasRecorder(
-        file_name="fc_relu_mse.info",
-        inputs=inp,
-        outputs=[inp, a, b, c, d, e, f],
-        input_shape=(3, 3),
-        label_shape=(3, 2),
-        loss_fn=tf.keras.losses.MeanSquaredError(),
-        optimizer=tf.keras.optimizers.SGD(lr=0.001)
-    ).run(10)
+    fc_sigmoid = [
+        K.Input(shape=(3, 3)),
+        K.layers.Dense(5),
+        K.layers.Activation("sigmoid"),
+        K.layers.Dense(10),
+        K.layers.Activation("softmax"),
+    ]
+
+    fc_sigmoid_tc = partial(
+        record, model=fc_sigmoid, input_shape=(3, 3), label_shape=(3, 10), iteration=10
+    )
+
+    fc_sigmoid_tc(
+        file_name="fc_sigmoid_mse_sgd.info",
+        loss_fn_str="mse",
+        optimizer=opt.SGD(learning_rate=1.0),
+    )
+
+    fc_relu = [
+        K.Input(shape=(3)),
+        K.layers.Dense(10),
+        K.layers.Activation("relu"),
+        K.layers.Dense(2),
+        K.layers.Activation("sigmoid"),
+    ]
+
+    fc_relu_tc = partial(
+        record, model=fc_relu, input_shape=(3, 3), label_shape=(3, 2), iteration=10
+    )
+
+    fc_relu_tc(
+        file_name="fc_relu_mse_sgd.info",
+        loss_fn_str="mse",
+        optimizer=opt.SGD(learning_rate=0.1),
+        debug="initial_input"
+    )
index 20feb80c3435a91b04618a743474f5a1c9762a2f..fbc95430f5c367011ebe923431fe16d410b6bc45 100644 (file)
@@ -20,6 +20,8 @@ with warnings.catch_warnings():
     import tensorflow as tf
     from tensorflow.python import keras as K
 
+__all__ = ["record"]
+
 tf.compat.v1.enable_eager_execution()
 # Fix the seeds across frameworks
 SEED = 1234
@@ -27,102 +29,224 @@ random.seed(SEED)
 tf.compat.v1.set_random_seed(SEED)
 np.random.seed(SEED)
 
+LOSS_FN = {
+    "mse": lambda: tf.keras.losses.MeanSquaredError(),
+    "cross_sigmoid": lambda: tf.keras.losses.BinaryCrossentropy(from_logits=True),
+    "cross_softmax": lambda: tf.keras.losses.CategoricalCrossentropy(from_logits=True),
+}
 
-##
-# Keras Recorder
 
-##
-# @brief Record Keras model with some watchers attached
-# @note  The class might need to go through some rework for non-sequential model
-# in case of the order of graph traversal is diffrent from NNTrainer
-class KerasRecorder:
-    def __init__(
-        self,
-        file_name,
-        inputs,
-        outputs,
-        input_shape,
-        label_shape,
-        loss_fn=None,
-        optimizer=tf.keras.optimizers.SGD(lr=1.0),
-    ):
-        self.inputs = inputs
-        self.outputs = outputs
-        self.model = K.Model(inputs=inputs, outputs=outputs)
-        self.loss_fn = loss_fn
-        self.optimizer = optimizer
-        if os.path.isfile(file_name):
-          print("Warning: the file %s is being truncated and overwritten" % file_name)
-        self.file = open(file_name, "wb")
-        self.generate_data(input_shape, label_shape)
-
-    def __del__(self):
-        self.file.close()
-
-    def _rand_like(self, tensorOrShape, scale=10):
-        try:
-            t =  np.random.randint(1, 10, size=tensorOrShape.shape).astype(dtype=np.float32)
-        except AttributeError:
-            t = np.random.randint(1, 10, size=tensorOrShape).astype(dtype=np.float32)
-        return tf.convert_to_tensor(t)
-
-    ##
-    # @brief generate data using uniform data from a function and save to the file.
-    # @note one-hot label is supported for now, this could be extended if needed.
-    def generate_data(self, input_shape, label_shape):
-        """This part loads data, should be changed if you are gonna load real data"""
-        self.initial_input = self._rand_like(input_shape)
-        self.label = tf.one_hot(
-          indices=np.random.randint(0, label_shape[1] - 1, label_shape[0]),
-          depth=label_shape[1]
-        )
+def _get_loss_fn(loss_fn_representation):
+    try:
+        return LOSS_FN[loss_fn_representation]()
+    except KeyError:
+        raise ValueError("given loss fn representation is not available")
 
-        self.initial_input.numpy().tofile(self.file)
-        self.label.numpy().tofile(self.file)
 
-    def _write_items(self, *items):
+def _get_writer(file):
+    def write_fn(*items):
         for item in items:
             try:
-                item.numpy().tofile(self.file)
+                item.numpy().tofile(file)
             except AttributeError:
                 pass
+        return items
 
-    ##
-    # @brief model iteration wrapper that listen to the gradient and outputs of the model
-    # each results are recorded.
-    def step(self):
-        with tf.GradientTape(persistent=True) as tape:
-            tape.watch(self.initial_input)
-            outputs = self.model(self.initial_input)
+    return write_fn
 
-            if self.loss_fn:
-                loss = self.loss_fn(self.label, outputs[-1])
-                outputs.append(loss)
 
-        results = [self.initial_input] + outputs
+def _rand_like(tensorOrShape, scale=1):
+    try:
+        shape = tensorOrShape.shape
+    except AttributeError:
+        shape = tensorOrShape
 
-        for idx, layer in enumerate(self.model.layers):
-            # print("generating for %s" % layer.name)
+    t = np.random.randint(-10, 10, shape).astype(dtype=np.float32)
+    return tf.convert_to_tensor(t) * scale
 
-            weights = layer.trainable_weights.copy()
-            gradients = tape.gradient(results[-1], layer.trainable_weights)
-            dweights = tape.gradient(results[-1], results[idx])
 
-            # input, weights, gradients, output, dx
-            # you should take weight order to account (eg. I think conv2d has different order)
-            self._write_items(
-                *[results[idx], *weights, *gradients, results[idx + 1], dweights]
-            )
+##
+# For some layers, nntrainer uses a different weight layout.
+# This function reorders lists from the Keras layout to the nntrainer layout.
+def _get_relayout_weight_fn(layer):
+    if isinstance(layer, K.layers.Conv2D):
+        raise NotImplementedError("Not implemented yet")
 
-            self.optimizer.apply_gradients(zip(gradients, layer.trainable_weights))
+    if isinstance(layer, K.layers.BatchNormalization):
+        raise NotImplementedError("Not implemented yet")
 
-        self._write_items(results[-1])
-        print("loss is %s" % results[-1])
+    return lambda x: x
+
+
+_debug_default_formatter = lambda key, value: "key: {}\n {}".format(key, value)
+##
+# @brief Print debug information from the record
+# @param debug list or string that filters debug information from @a data
+# @param print_option print option for the print function
+# @param print_format print formatter. a callable that takes key and value should be passed
+# @param data data to be passed to _debug_print
+def _debug_print(
+    debug=None,
+    print_option={"end": "\n"},
+    print_format=_debug_default_formatter,
+    **data
+):
+    if not debug:
+        return
+    elif isinstance(debug, str):
+        debug = [debug]
+
+    for target in debug:
+        try:
+            print(print_format(target, data[target]), **print_option)
+        except KeyError:
+            pass
+
+
+##
+# @brief generate uniform random data and save it to the file.
+# @note only one-hot labels are supported for now; this could be extended if needed.
+def prepare_data(model, input_shape, label_shape, writer_fn, **kwargs):
+    initial_input = _rand_like(input_shape)
+    label = tf.one_hot(
+        indices=np.random.randint(0, label_shape[1] - 1, label_shape[0]),
+        depth=label_shape[1],
+    )
+
+    initial_weights = [
+        weight for l in model.layers for weight in l.trainable_weights.copy()
+    ]
+    writer_fn(initial_input, label, *initial_weights)
+    _debug_print(
+        initial_input=initial_input,
+        label=label,
+        initial_weights=initial_weights,
+        **kwargs
+    )
+
+    return initial_input, label
+
+
+##
+# @brief model iteration wrapper that listens to the gradients and outputs of the model;
+# each result is recorded.
+def train_step(model, optimizer, loss_fn, initial_input, label, writer_fn, **kwargs):
+    with tf.GradientTape(persistent=True) as tape:
+        tape.watch(initial_input)
+        outputs = model(initial_input)
+
+        loss = loss_fn(label, outputs[-1])
+        outputs.append(loss)
+
+    layer_input = initial_input
+    for layer_output, layer in zip(outputs, model.layers):
+        # print("generating for %s" % layer.name)
+        to_nntr_layout = _get_relayout_weight_fn(layer)
+
+        gradients = tape.gradient(loss, layer.trainable_weights)
+        optimizer.apply_gradients(zip(gradients, layer.trainable_weights))
+
+        weights = layer.trainable_weights.copy()
+        dx = tape.gradient(loss, layer_input)
+
+        writer_fn(
+            layer_output,  # output of forward
+            dx,  # output of backward
+            *to_nntr_layout(gradients),  # weight gradient output from backward
+            *to_nntr_layout(weights)  # updated weight after optimization
+        )
+
+        _debug_print(
+            output=layer_output, dx=dx, weights=weights, gradients=gradients, **kwargs
+        )
+
+        layer_input = layer_output
+
+    writer_fn(loss)
+    _debug_print(loss=loss, **kwargs)
+
+
+##
+# @brief inference_step of the result
+def inference_step(loss_fn_str, initial_input, label, writer_fn):
+    # Not yet implemented
+    # because a loss function with from_logits is used, the last activation layer should be added for the inference step
+    if loss_fn_str == "cross_sigmoid" or loss_fn_str == "cross_entropy":
+        # add last activation layer
+        pass
+    raise NotImplementedError("Not Implemented yet")
+
+value_only_formatter = lambda key, value: value
+
+##
+# @brief generate recordable model
+# @param loss_fn_str one of the LOSS_FN strings, otherwise an error is raised
+# @param model base model to record; if a model is present, @a inputs and @a outputs are ignored
+# @param inputs keras inputs to build a model
+# @param outputs keras outputs to build a model
+def generate_recordable_model(loss_fn_str, model=None, inputs=None, outputs=None, **kwargs):
+    if model is not None:
+        if isinstance(model, list):
+            model = K.Sequential(model)
+        inputs = model.input
+        outputs = [model.input] + [layer.output for layer in model.layers]
+
+    # omit the last activation layer if loss is cross_softmax or cross_sigmoid
+    if loss_fn_str == "cross_softmax" or loss_fn_str == "cross_sigmoid":
+        if isinstance(model.layers[-1], K.layers.Activation):
+            outputs = outputs[:-1]
+
+    model = K.Model(inputs=inputs, outputs=outputs)
+
+    model.summary(
+        print_fn=lambda x: _debug_print(summary=x, print_format=value_only_formatter, **kwargs)
+    )
+
+    return model
+
+##
+# @brief record function that records weights, gradients, inputs and outputs for @a iteration
+# @param loss_fn_str loss function representation
+# @param optimizer keras optimizer
+# @param file_name name of the file to save the recorded data to
+# @param input_shape shape of the input data to generate
+# @param label_shape shape of the label data to generate
+# @param iteration number of iterations to run
+# @param model base model to record; if a model is present, @a inputs and @a outputs are ignored
+# @param inputs keras inputs to build a model
+# @param outputs keras outputs to build a model
+def record(
+    loss_fn_str,
+    optimizer,
+    file_name,
+    input_shape,
+    label_shape,
+    iteration=1,
+    model=None,
+    inputs=None,
+    outputs=None,
+    **kwargs
+):
+    if os.path.isfile(file_name):
+        print("Warning: the file %s is being truncated and overwritten" % file_name)
+
+    loss_fn = _get_loss_fn(loss_fn_str)
+    model = generate_recordable_model(loss_fn_str, model, inputs, outputs, **kwargs)
+
+    with open(file_name, "wb") as f:
+        write = _get_writer(f)
+
+        initial_input, label = prepare_data(
+            model, input_shape, label_shape, write, **kwargs
+        )
 
-    ##
-    # @brief run function
-    # @param iteration number of iteration to run
-    def run(self, iteration = 1):
-        print(self.model.summary())
         for _ in range(iteration):
-            self.step()
+            _debug_print(
+                iteration="[%d/%d]" % (_, iteration),
+                print_option={"end": " "},
+                print_format=value_only_formatter,
+                **kwargs
+            )
+            train_step(model, optimizer, loss_fn, initial_input, label, write, **kwargs)
+
+        # self.inference_step(initial_input, label, write)
index 81566580c22b88617134c9b867e5b721c9fb7acd..dcc63e2c8e355c0cfaeb1d9383dba62ce0d1febe 100644 (file)
@@ -69,8 +69,6 @@ public:
     unsigned int num_weights = node->getNumWeights();
     node->setTrainable(false);
 
-    expected_input = nntrainer::Tensor(node->getInputDimension()[0]);
-
     for (unsigned int i = 0; i < num_weights; ++i) {
       const nntrainer::Weight &w = node->weightAt(i);
       expected_weights.push_back(w);
@@ -84,9 +82,10 @@ public:
    * @brief clones from expected weights to node->weights
    *
    */
-  void cloneWeightsFromExpected() {
-    for (unsigned int i = 0; i < expected_weights.size(); ++i) {
-      node->weightAt(i) = expected_weights[i];
+  void readLayerWeight(std::ifstream &f) {
+    for (unsigned int i = 0; i < node->getNumWeights(); ++i) {
+      /// @note below is exploiting the fact that the tensor shares the same base memory
+      node->weightAt(i).getVariable().read(f);
     }
   }
 
@@ -97,8 +96,8 @@ public:
    * @param iteration iteration
    * @return nntrainer::sharedConstTensor
    */
-  nntrainer::sharedConstTensor forward(nntrainer::sharedConstTensor in,
-                                       int iteration);
+  nntrainer::sharedConstTensors forward(nntrainer::sharedConstTensors in,
+                                        int iteration);
 
   /**
    * @brief forward loss node with verifying inputs/weights/outputs
@@ -108,9 +107,9 @@ public:
    * @param iteration iteration
    * @return nntrainer::sharedConstTensor
    */
-  nntrainer::sharedConstTensor lossForward(nntrainer::sharedConstTensor pred,
-                                           nntrainer::sharedConstTensor answer,
-                                           int iteration);
+  nntrainer::sharedConstTensors
+  lossForward(nntrainer::sharedConstTensors pred,
+              nntrainer::sharedConstTensors answer, int iteration);
 
   /**
    * @brief backward pass of the node with verifying inputs/gradients/outputs
@@ -120,9 +119,9 @@ public:
    * @param should_verify should verify the inputs/gradients/outputs
    * @return nntrainer::sharedConstTensor
    */
-  nntrainer::sharedConstTensor backward(nntrainer::sharedConstTensor deriv,
-                                        int iteration,
-                                        bool should_verify = true);
+  nntrainer::sharedConstTensors backward(nntrainer::sharedConstTensors deriv,
+                                         int iteration,
+                                         bool should_verify = true);
 
   /**
    * @brief verify weights of the current node
@@ -154,7 +153,6 @@ public:
 
 private:
   NodeType node;
-  nntrainer::Tensor expected_input;
   nntrainer::Tensor expected_output;
   nntrainer::Tensor expected_dx;
   std::vector<nntrainer::Weight> expected_weights;
@@ -170,6 +168,9 @@ public:
                   unsigned int iterations);
 
 private:
+  std::array<nntrainer::Tensor, 2>
+  prepareData(std::ifstream &f, const nntrainer::TensorDim &label_dim);
+
   void readIteration(std::ifstream &f);
 
   nntrainer::NeuralNetwork nn;
@@ -179,83 +180,81 @@ private:
 };
 
 void NodeWatcher::read(std::ifstream &in) {
-  expected_input.read(in);
+  expected_output.read(in);
+  expected_dx.read(in);
 
   /// @note below is harrasing the fact the tensor shares same base memory
-  /// it should better be getVariableRef() or somewhat equivalent
+  /// it should better be getGradientRef() or somewhat equivalent
   for (auto &i : expected_weights) {
-    i.getVariable().read(in);
+    i.getGradient().read(in);
   }
 
   for (auto &i : expected_weights) {
-    i.getGradient().read(in);
+    i.getVariable().read(in);
   }
-
-  expected_output.read(in);
-  expected_dx.read(in);
 }
 
 void NodeWatcher::verifyWeight(const std::string &error_msg) {
   for (unsigned int i = 0; i < expected_weights.size(); ++i) {
     verify(node->weightAt(i).getVariable(), expected_weights[i].getVariable(),
-           error_msg + " " + node->weightAt(i).getName() + "weight");
+           error_msg + " " + node->weightAt(i).getName() + " weight");
   }
 }
 
 void NodeWatcher::verifyGrad(const std::string &error_msg) {
   for (unsigned int i = 0; i < expected_weights.size(); ++i) {
     verify(node->weightAt(i).getGradient(), expected_weights[i].getGradient(),
-           error_msg + " " + node->weightAt(i).getName() + "grad");
+           error_msg + " " + node->weightAt(i).getName() + " grad");
   }
 }
 
-nntrainer::sharedConstTensor
-NodeWatcher::forward(nntrainer::sharedConstTensor in, int iteration) {
+nntrainer::sharedConstTensors
+NodeWatcher::forward(nntrainer::sharedConstTensors in, int iteration) {
   std::stringstream ss;
   ss << "forward failed at " << node->getName() << " at iteration "
      << iteration;
   std::string err_msg = ss.str();
 
-  verify(*in, expected_input, err_msg + " at input ");
-  nntrainer::sharedConstTensor out = node->forwarding({in})[0];
-  verify(*out, expected_output, err_msg + " at output ");
+  nntrainer::sharedConstTensors out = node->forwarding(in);
+  verify(*out[0], expected_output, err_msg + " at output");
   return out;
 }
 
-nntrainer::sharedConstTensor
-NodeWatcher::lossForward(nntrainer::sharedConstTensor pred,
-                         nntrainer::sharedConstTensor answer, int iteration) {
+nntrainer::sharedConstTensors
+NodeWatcher::lossForward(nntrainer::sharedConstTensors pred,
+                         nntrainer::sharedConstTensors answer, int iteration) {
   std::stringstream ss;
   ss << "loss failed at " << node->getName() << " at iteration " << iteration;
   std::string err_msg = ss.str();
 
-  nntrainer::sharedConstTensor out =
-    std::static_pointer_cast<nntrainer::LossLayer>(node)->forwarding(
-      {pred}, {answer})[0];
+  nntrainer::sharedConstTensors out =
+    std::static_pointer_cast<nntrainer::LossLayer>(node)->forwarding(pred,
+                                                                     answer);
 
   return out;
 }
 
-nntrainer::sharedConstTensor
-NodeWatcher::backward(nntrainer::sharedConstTensor deriv, int iteration,
+nntrainer::sharedConstTensors
+NodeWatcher::backward(nntrainer::sharedConstTensors deriv, int iteration,
                       bool should_verify) {
   std::stringstream ss;
   ss << "backward failed at " << node->getName() << " at iteration "
      << iteration;
   std::string err_msg = ss.str();
 
-  nntrainer::sharedConstTensor out = node->backwarding({deriv}, iteration)[0];
-
-  if (should_verify) {
-    verify(*out, expected_dx, err_msg);
-    verifyGrad(err_msg);
-  }
+  nntrainer::sharedConstTensors out = node->backwarding(deriv, iteration);
 
   auto opt = node->getOptimizer();
   if (opt) {
     opt->apply_gradients(node->getWeights(), node->getNumWeights(), iteration);
   }
 
+  if (should_verify) {
+    verify(*out[0], expected_dx, err_msg);
+    verifyGrad(err_msg);
+    verifyWeight(err_msg);
+  }
+
   return out;
 }
 
@@ -278,33 +277,21 @@ void GraphWatcher::compareFor(const std::string &reference,
                               const unsigned int iterations) {
   std::ifstream ref(reference, std::ios_base::in | std::ios_base::binary);
 
-  if (ref.bad()) {
-    throw std::runtime_error("ref is bad!");
+  if (!ref.good()) {
+    std::stringstream ss;
+    ss << "ref is bad! ref path: " << reference;
+    throw std::runtime_error(ss.str().c_str());
   }
 
-  nntrainer::Tensor in(nn.getInputDimension()[0]);
-  nntrainer::Tensor lb(label_shape);
-
-  in.read(ref);
-  lb.read(ref);
-
-  auto prepareInitialWeight = [this]() {
-    std::for_each(nodes.begin(), nodes.end(),
-                  [](NodeWatcher &n) { n.cloneWeightsFromExpected(); });
-  };
-
-  auto matchWeightAfterUpdation = [this]() {
-    std::for_each(nodes.begin(), nodes.end(), [](NodeWatcher &n) {
-      n.verifyWeight("weight is diffrent after updation, check optimizer");
-    });
-  };
+  auto data = prepareData(ref, label_shape);
 
   for (unsigned int iteration = 1; iteration <= iterations; ++iteration) {
-    nntrainer::sharedConstTensor input = MAKE_SHARED_TENSOR(in.clone());
-    nntrainer::sharedConstTensor label = MAKE_SHARED_TENSOR(lb.clone());
+    nntrainer::sharedConstTensors input = {
+      MAKE_SHARED_TENSOR(std::get<0>(data).clone())};
+    nntrainer::sharedConstTensors label = {
+      MAKE_SHARED_TENSOR(std::get<1>(data).clone())};
 
     readIteration(ref);
-    iteration == 1 ? prepareInitialWeight() : matchWeightAfterUpdation();
 
     /// forward pass
     for (auto &i : nodes)
@@ -314,13 +301,27 @@ void GraphWatcher::compareFor(const std::string &reference,
     EXPECT_NEAR(expected_loss, loss_node.getLoss(), nntrainer::Tensor::epsilon);
 
     /// backward pass and update weights
-    nntrainer::sharedConstTensor output =
+    nntrainer::sharedConstTensors output =
       loss_node.backward(label, iteration, false);
     for (auto it = nodes.rbegin(); it != nodes.rend(); it++)
       output = it->backward(output, iteration);
   }
+}
+
+std::array<nntrainer::Tensor, 2>
+GraphWatcher::prepareData(std::ifstream &f,
+                          const nntrainer::TensorDim &label_dim) {
+  nntrainer::Tensor in(nn.getInputDimension()[0]);
+  nntrainer::Tensor lb(label_dim);
+
+  in.read(f);
+  lb.read(f);
 
-  /// note that last weight update is not checked up. this need to be fixed
+  for (auto &i : nodes) {
+    i.readLayerWeight(f);
+  }
+
+  return {in, lb};
 }
 
 void GraphWatcher::readIteration(std::ifstream &f) {
@@ -411,6 +412,9 @@ static IniSection nn_base("model", "Type = NeuralNetwork");
 static IniSection input_base("input", "Type = input");
 static IniSection fc_base("fc", "Type = Fully_connected");
 
+static IniSection adam("_", "Optimizer=adam | beta1 = 0.9 | beta2 = 0.999 | "
+                            "epsilon = 1e-7");
+
 static IniSection act_base("activation", "Type = Activation");
 static IniSection softmax = act_base + "Activation = softmax";
 static IniSection sigmoid = act_base + "Activation = sigmoid";
@@ -436,13 +440,13 @@ using I = IniSection;
  *
  * [dense]
  * Type = fully_connected
- * Unit = 10
+ * Unit = 5
  *
- * [dense_1]
- * Type = fully_connected
- * Unit = 10
+ * [activation]
+ * Type = Activation
+ * Activation = softmax
  *
- * [dense_2]
+ * [dense]
  * Type = fully_connected
  * Unit = 10
  *
@@ -450,37 +454,28 @@ using I = IniSection;
  * Type = Activation
  * Activation = softmax
  */
-IniTestWrapper::Sections fc_softmax_mse{
+IniTestWrapper::Sections fc_sigmoid_mse_sgd{
   nn_base + "Learning_rate=1 | Optimizer=sgd | Loss=mse | batch_size = 3",
   I("input") + input_base + "input_shape = 1:1:3",
   I("dense") + fc_base + "unit = 5",
-  I("dense_1") + fc_base + "unit = 5",
-  I("dense_2") + fc_base + "unit = 10",
-  softmax};
-
-IniTestWrapper::Sections fc_sigmoid_mse{
-  nn_base + "Learning_rate=1 | Optimizer=sgd | Loss=mse | batch_size = 3",
-  I("input") + input_base + "input_shape = 1:1:3",
-  I("dense") + fc_base + "unit = 10",
+  I("act") + sigmoid,
   I("dense_1") + fc_base + "unit = 10",
-  I("dense_2") + fc_base + "unit = 2",
-  sigmoid};
+  I("act_1") + softmax};
 
-IniTestWrapper::Sections fc_relu_mse{
-  nn_base + "Learning_rate=0.001 | Optimizer=sgd | Loss=mse | batch_size = 3",
+IniTestWrapper::Sections fc_relu_mse_sgd{
+  nn_base + "Learning_rate=0.1 | Optimizer=sgd | Loss=mse | batch_size = 3",
   I("input") + input_base + "input_shape = 1:1:3",
   I("dense") + fc_base + "unit = 10",
   I("act") + relu,
-  I("dense_1") + fc_base + "unit = 10",
-  I("act_1") + relu,
-  I("dense_2") + fc_base + "unit = 2",
-  I("act_2") + relu};
+  I("dense_1") + fc_base + "unit = 2",
+  I("act_1") + sigmoid};
 
 // clang-format off
 INSTANTIATE_TEST_CASE_P(
   nntrainerModelAutoTests, nntrainerModelTest, ::testing::Values(
-mkModelTc("fc_softmax_mse", fc_softmax_mse, "3:1:1:10", 10),
-mkModelTc("fc_relu_mse", fc_relu_mse, "3:1:1:2", 10)
+mkModelTc("fc_sigmoid_mse_sgd", fc_sigmoid_mse_sgd, "3:1:1:10", 10),
+mkModelTc("fc_relu_mse_sgd", fc_relu_mse_sgd, "3:1:1:2", 10)
+// mkModelTc("cifar_classification", cifar_classification, "10:1:1:10", 10)
 /// #if gtest_version <= 1.7.0
 ));
 /// #else gtest_version > 1.8.0