[ Coverity ] Fix Coverity Issues accepted/tizen/unified/20200728.135447 submit/tizen/20200728.023042
authorjijoong.moon <jijoong.moon@samsung.com>
Fri, 24 Jul 2020 12:15:21 +0000 (21:15 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 28 Jul 2020 02:15:34 +0000 (11:15 +0900)
This PR incluide Fixs of Coverity Issues.

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
27 files changed:
Applications/Classification/jni/main.cpp
Applications/Classification/jni/main_func.cpp
Applications/ReinforcementLearning/DeepQ/jni/main.cpp
Applications/Tizen_CAPI/capi_file.c
Applications/Tizen_CAPI/capi_func.c
Applications/Training/jni/bitmap_helpers.cpp
Applications/mnist/jni/main.cpp
api/capi/src/nntrainer.cpp
nntrainer/include/bn_layer.h
nntrainer/include/conv2d_layer.h
nntrainer/include/fc_layer.h
nntrainer/include/flatten_layer.h
nntrainer/include/input_layer.h
nntrainer/include/layer.h
nntrainer/include/optimizer.h
nntrainer/include/pooling2d_layer.h
nntrainer/include/tensor.h
nntrainer/include/tensor_dim.h
nntrainer/src/databuffer_file.cpp
nntrainer/src/databuffer_func.cpp
nntrainer/src/neuralnet.cpp
nntrainer/src/optimizer.cpp
nntrainer/src/tensor_dim.cpp
test/include/nntrainer_test_util.h
test/nntrainer_test_util.cpp
test/tizen_capi/unittest_tizen_capi.cpp
test/tizen_capi/unittest_tizen_capi_dataset.cpp

index ef8f324..889f20d 100644 (file)
@@ -119,10 +119,6 @@ static int rangeRandom(int min, int max) {
  * @param[out] feature_input save output of tflite
  */
 void getFeature(const string filename, vector<float> &feature_input) {
-  int input_size;
-  int output_size;
-  std::vector<int> output_idx_list;
-  std::vector<int> input_idx_list;
   int input_dim[4];
   int output_dim[4];
   std::string model_path = "../../res/mobilenetv2.tflite";
@@ -134,23 +130,8 @@ void getFeature(const string filename, vector<float> &feature_input) {
   std::unique_ptr<tflite::Interpreter> interpreter;
   tflite::InterpreterBuilder(*model.get(), resolver)(&interpreter);
 
-  input_size = interpreter->inputs().size();
-  output_size = interpreter->outputs().size();
-
-  int t_size = interpreter->tensors_size();
-
-  for (int i = 0; i < t_size; i++) {
-    for (int j = 0; j < input_size; j++) {
-      if (strncmp(interpreter->tensor(i)->name, interpreter->GetInputName(j),
-                  sizeof(interpreter->tensor(i)->name)) == 0)
-        input_idx_list.push_back(i);
-    }
-    for (int j = 0; j < output_size; j++) {
-      if (strncmp(interpreter->tensor(i)->name, interpreter->GetOutputName(j),
-                  sizeof(interpreter->tensor(i)->name)) == 0)
-        output_idx_list.push_back(i);
-    }
-  }
+  const std::vector<int> &input_idx_list = interpreter->inputs();
+  const std::vector<int> &output_idx_list = interpreter->outputs();
 
   for (int i = 0; i < 4; i++) {
     input_dim[i] = 1;
@@ -450,7 +431,15 @@ int main(int argc, char *argv[]) {
     std::vector<float> featureVector, resultVector;
     featureVector.resize(feature_size);
     getFeature(img, featureVector);
-    nntrainer::Tensor X = nntrainer::Tensor({featureVector});
+
+    nntrainer::Tensor X;
+    try {
+      X = nntrainer::Tensor({featureVector});
+    } catch (...) {
+      std::cerr << "Error while construct tensor" << std::endl;
+      NN.finalize();
+      return 0;
+    }
     cout << NN.forwarding(X, status).apply(stepFunction) << endl;
   }
   /**
index b4b012f..65111ef 100644 (file)
@@ -118,13 +118,14 @@ static int rangeRandom(int min, int max) {
  * @retval true/false false : end of data
  */
 bool getData(std::ifstream &F, std::vector<float> &outVec,
-             std::vector<float> &outLabel, int id) {
+             std::vector<float> &outLabel, uint64_t id) {
   F.clear();
   F.seekg(0, std::ios_base::end);
   uint64_t file_length = F.tellg();
-  uint64_t position = (feature_size + total_label_size) * id * sizeof(float);
+  uint64_t position =
+    (uint64_t)((feature_size + total_label_size) * id * sizeof(float));
 
-  if (position > file_length || position > ULLONG_MAX) {
+  if (position > file_length) {
     return false;
   }
   F.seekg(position, std::ios::beg);
@@ -296,11 +297,18 @@ int main(int argc, char *argv[]) {
    */
   nntrainer::NeuralNetwork NN;
   NN.setConfig(config);
-  NN.loadFromConfig();
+  try {
+    NN.loadFromConfig();
+  } catch (...) {
+    std::cerr << "Error during loadFromConfig" << std::endl;
+    NN.finalize();
+    return 0;
+  }
   try {
     NN.init();
   } catch (...) {
     std::cerr << "Error during init" << std::endl;
+    NN.finalize();
     return 0;
   }
   NN.readModel();
index 63a5be1..808da4e 100644 (file)
@@ -281,10 +281,34 @@ int main(int argc, char **argv) {
   /**
    * @brief     initialize mainNet & Target Net
    */
-  mainNet.loadFromConfig();
-  mainNet.init();
-  targetNet.loadFromConfig();
-  targetNet.init();
+  try {
+    mainNet.loadFromConfig();
+  } catch (...) {
+    std::cerr << "Error during loadFromConfig" << std::endl;
+    mainNet.finalize();
+    return 0;
+  }
+  try {
+    mainNet.init();
+  } catch (...) {
+    std::cerr << "Error during init" << std::endl;
+    mainNet.finalize();
+    return 0;
+  }
+  try {
+    targetNet.loadFromConfig();
+  } catch (...) {
+    std::cerr << "Error during loadFromConfig" << std::endl;
+    targetNet.finalize();
+    return 0;
+  }
+  try {
+    targetNet.init();
+  } catch (...) {
+    std::cerr << "Error during init" << std::endl;
+    targetNet.finalize();
+    return 0;
+  }
 
   /**
    * @brief     Read Model Data if any
index 9f6c8a8..cb3185d 100644 (file)
@@ -91,7 +91,7 @@ int main(int argc, char *argv[]) {
   NN_RETURN_STATUS();
 
   /* compile model with cross entropy loss function */
-  status = ml_train_model_compile(model, "loss=cross", NULL);
+  status = ml_train_model_compile(model, "loss=cross", "batch_size=32", NULL);
   NN_RETURN_STATUS();
 
   /* create dataset */
@@ -110,8 +110,7 @@ int main(int argc, char *argv[]) {
 
   /* train model with data files : epochs = 10 and store model file named
    * "model.bin" */
-  status = ml_train_model_run(model, "epochs=10", "batch_size=32",
-                              "model_file=model.bin", NULL);
+  status = ml_train_model_run(model, "epochs=10", "model_file=model.bin", NULL);
   NN_RETURN_STATUS();
 
   /* delete model */
index a643da0..2d8493d 100644 (file)
@@ -72,7 +72,7 @@ static int range_random(int min, int max) {
  * @retval true/false false : end of data
  */
 static bool get_data(const char *file_name, float *outVec, float *outLabel,
-                     int id, int file_length) {
+                     uint64_t id, int file_length) {
   uint64_t position;
   FILE *F;
   unsigned int i;
@@ -136,6 +136,7 @@ int gen_data_train(float **outVec, float **outLabel, bool *last) {
   unsigned int data_size = 0;
   unsigned int i, j;
   FILE *file;
+  float *o, *l;
 
   const char *file_name = "trainingSet.dat";
 
@@ -185,10 +186,10 @@ int gen_data_train(float **outVec, float **outLabel, bool *last) {
     }
   }
 
-  for (i = 0; i < count; i++) {
-    float o[feature_size];
-    float l[num_class];
+  o = malloc(sizeof(float) * feature_size);
+  l = malloc(sizeof(float) * num_class);
 
+  for (i = 0; i < count; i++) {
     get_data(file_name, o, l, memI[i], file_size);
 
     for (j = 0; j < feature_size; ++j)
@@ -197,6 +198,8 @@ int gen_data_train(float **outVec, float **outLabel, bool *last) {
       outLabel[0][i * num_class + j] = l[j];
   }
 
+  free(o);
+  free(l);
   *last = false;
   return ML_ERROR_NONE;
 }
@@ -215,6 +218,7 @@ int gen_data_val(float **outVec, float **outLabel, bool *last) {
   unsigned int count = 0;
   unsigned int data_size = 0;
   long file_size;
+  float *o, *l;
 
   const char *file_name = "trainingSet.dat";
 
@@ -255,10 +259,10 @@ int gen_data_val(float **outVec, float **outLabel, bool *last) {
     }
   }
 
-  for (i = 0; i < count; i++) {
-    float o[feature_size];
-    float l[num_class];
+  o = malloc(feature_size * sizeof(float));
+  l = malloc(num_class * sizeof(float));
 
+  for (i = 0; i < count; i++) {
     get_data(file_name, o, l, memI[i], file_size);
 
     for (j = 0; j < feature_size; ++j)
@@ -268,6 +272,10 @@ int gen_data_val(float **outVec, float **outLabel, bool *last) {
   }
 
   *last = false;
+
+  free(o);
+  free(l);
+
   return ML_ERROR_NONE;
 }
 
@@ -328,7 +336,7 @@ int main(int argc, char *argv[]) {
   NN_RETURN_STATUS();
 
   /* compile model with cross entropy loss function */
-  status = ml_train_model_compile(model, "loss=cross", NULL);
+  status = ml_train_model_compile(model, "loss=cross", "batch_size=32", NULL);
   NN_RETURN_STATUS();
 
   /* create dataset */
@@ -346,12 +354,12 @@ int main(int argc, char *argv[]) {
 
   /* train model with data files : epochs = 10 and store model file named
    * "model.bin" */
-  status = ml_train_model_run(model, "epochs=10", "batch_size=32",
-                              "model_file=model.bin", NULL);
+  status = ml_train_model_run(model, "epochs=10", "model_file=model.bin", NULL);
   NN_RETURN_STATUS();
 
   /* delete model */
   status = ml_train_model_destroy(model);
   NN_RETURN_STATUS();
+
   return 0;
 }
index 86659f5..ca7dc09 100644 (file)
@@ -106,9 +106,15 @@ uint8_t *read_bmp(const std::string &input_bmp_name, int *width, int *height,
 
   // Decode image, allocating tensor once the image size is known
   uint8_t *output = new uint8_t[abs(*height) * *width * *channels];
+
   const uint8_t *bmp_pixels = &img_bytes[header_size];
-  return decode_bmp(bmp_pixels, row_size, output, *width, abs(*height),
-                    *channels, top_down);
+
+  decode_bmp(bmp_pixels, row_size, output, *width, abs(*height), *channels,
+             top_down);
+
+  delete (img_bytes);
+
+  return output;
 }
 
 } // namespace label_image
index 0d63a0d..68f6129 100644 (file)
@@ -291,11 +291,19 @@ int main(int argc, char *argv[]) {
    */
   nntrainer::NeuralNetwork NN;
   NN.setConfig(config);
-  NN.loadFromConfig();
+  try {
+    NN.loadFromConfig();
+  } catch (...) {
+    std::cerr << "Error during loadFromConfig" << std::endl;
+    NN.finalize();
+    return 0;
+  }
+
   try {
     NN.init();
   } catch (...) {
     std::cerr << "Error during init" << std::endl;
+    NN.finalize();
     return 0;
   }
 
index 318065f..bae5156 100644 (file)
@@ -496,12 +496,13 @@ int ml_train_layer_create(ml_train_layer_h *layer, ml_train_layer_type_e type) {
       delete nnlayer;
       ml_loge("Error: Unknown layer type");
       status = ML_ERROR_INVALID_PARAMETER;
-      break;
+      return status;
     }
   } catch (std::bad_alloc &e) {
     ml_loge("Error: heap exception: %s", e.what());
     status = ML_ERROR_OUT_OF_MEMORY;
     delete nnlayer;
+    return status;
   }
 
   nnlayer->in_use = false;
index 5ddf158..6388918 100644 (file)
@@ -50,6 +50,18 @@ public:
   ~BatchNormalizationLayer(){};
 
   /**
+   *  @brief  Move constructor of Pooling 2D Layer.
+   *  @param[in] BatchNormalization &&
+   */
+  BatchNormalizationLayer(BatchNormalizationLayer &&rhs) = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @parma[in] rhs BatchNormalizationLayer to be moved.
+   */
+  BatchNormalizationLayer &operator=(BatchNormalizationLayer &&rhs) = default;
+
+  /**
    * @brief     forward propagation with input
    * @param[in] in Input Tensor from upper layer
    * @retval    normalized input tensor using scaling factor
index 0a4f13d..6b2e841 100644 (file)
@@ -52,6 +52,18 @@ public:
   ~Conv2DLayer(){};
 
   /**
+   *  @brief  Move constructor of Conv 2D Layer.
+   *  @param[in] Conv2dLayer &&
+   */
+  Conv2DLayer(Conv2DLayer &&rhs) = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @parma[in] rhs Conv2DLayer to be moved.
+   */
+  Conv2DLayer &operator=(Conv2DLayer &&rhs) = default;
+
+  /**
    * @brief     initialize layer
    * @param[in] last last layer
    * @retval #ML_ERROR_NONE Successful.
index 7c10110..b438b11 100644 (file)
@@ -42,6 +42,18 @@ public:
   ~FullyConnectedLayer(){};
 
   /**
+   *  @brief  Move constructor of Pooling 2D Layer.
+   *  @param[in] FullyConnected &&
+   */
+  FullyConnectedLayer(FullyConnectedLayer &&rhs) = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @parma[in] rhs FullyConnectedLayer to be moved.
+   */
+  FullyConnectedLayer &operator=(FullyConnectedLayer &&rhs) = default;
+
+  /**
    * @brief     Read Weight & Bias Data from file
    * @param[in] file input stream file
    */
index 9b16eda..75f0ec1 100644 (file)
@@ -41,6 +41,18 @@ public:
   ~FlattenLayer(){};
 
   /**
+   *  @brief  Move constructor of FlattenLayer.
+   *  @param[in] FlattenLayer &&
+   */
+  FlattenLayer(FlattenLayer &&rhs) = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @parma[in] rhs FlattenLayer to be moved.
+   */
+  FlattenLayer &operator=(FlattenLayer &&rhs) = default;
+
+  /**
    * @brief     initialize layer
    * @param[in] last last layer
    * @retval #ML_ERROR_NONE Successful.
index 59191e4..d34f516 100644 (file)
@@ -52,6 +52,18 @@ public:
   ~InputLayer(){};
 
   /**
+   *  @brief  Move constructor of Pooling 2D Layer.
+   *  @param[in] Input &&
+   */
+  InputLayer(InputLayer &&rhs) = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @parma[in] rhs InputLayer to be moved.
+   */
+  InputLayer &operator=(InputLayer &&rhs) = default;
+
+  /**
    * @brief     No Weight data for this Input Layer
    */
   void read(std::ifstream &file){};
index 50186b7..ac8264e 100644 (file)
@@ -155,6 +155,18 @@ public:
   virtual ~Layer(){};
 
   /**
+   *  @brief  Move constructor of Layer.
+   *  @param[in] Layer &&
+   */
+  Layer(Layer &&rhs) noexcept = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @parma[in] rhs Layer to be moved.
+   */
+  virtual Layer &operator=(Layer &&rhs) = default;
+
+  /**
    * @brief     Forward Propation of neural Network
    * @param[in] in Input Tensor taken by upper layer
    * @retval    Output Tensor
index 1858fd4..498f467 100644 (file)
@@ -97,12 +97,32 @@ public:
    */
   Optimizer() : type(OptType::unknown), popt() {}
 
+  Optimizer(const OptType type, OptParam popt);
+
   /**
    * @brief     Destructor of Optimizer Class
    */
   ~Optimizer() {}
 
   /**
+   * @brief  copy assignment operator
+   * @parma[in] rhs Optimizer to be copied
+   */
+  Optimizer &operator=(const Optimizer &rhs) = default;
+
+  /**
+   *  @brief  Move constructor of Conv 2D Layer.
+   *  @param[in] Conv2dLayer &&
+   */
+  Optimizer(Optimizer &&rhs) = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @parma[in] rhs Optimizer to be moved.
+   */
+  Optimizer &operator=(Optimizer &&rhs) = default;
+
+  /**
    * @brief     set Optimizer Type
    * @param[in] t Optimizer type
    * @retval #ML_ERROR_NONE Successful.
index fd9369e..3ae0113 100644 (file)
@@ -57,6 +57,18 @@ public:
   ~Pooling2DLayer(){};
 
   /**
+   *  @brief  Move constructor of Pooling 2D Layer.
+   *  @param[in] Pooling2D &&
+   */
+  Pooling2DLayer(Pooling2DLayer &&rhs) = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @parma[in] rhs Pooling2DLayer to be moved.
+   */
+  Pooling2DLayer &operator=(Pooling2DLayer &&rhs) = default;
+
+  /**
    * @brief     initialize layer
    * @param[in] last last layer
    * @retval #ML_ERROR_NONE Successful.
index 20fee1c..9501824 100644 (file)
@@ -523,13 +523,6 @@ public:
   int setDim(TensorDim d);
 
   /**
-   * @brief     return if current tensor is contiguous, if not, you can't write
-   *            on this tensor
-   * @retval    bool is contigous
-   */
-  const bool isContiguous() const noexcept { return is_contiguous; }
-
-  /**
    * @brief     return current stride of tensor.
    * @retval    int[MAXDIM] strides
    */
index ebea783..439fb0b 100644 (file)
@@ -44,7 +44,30 @@ public:
     len = b * feature_len;
   }
 
+  TensorDim(const TensorDim &rhs) :
+    TensorDim(rhs.batch(), rhs.channel(), rhs.height(), rhs.width()){};
+
   ~TensorDim(){};
+
+  /**
+   *  @brief  Move constructor of Conv 2D Layer.
+   *  @param[in] Conv2dLayer &&
+   */
+  TensorDim(TensorDim &&rhs) noexcept = default;
+
+  /**
+   * @brief  Move assignment operator.
+   * @parma[in] rhs Optimizer to be moved.
+   */
+  TensorDim &operator=(TensorDim &&rhs) noexcept;
+
+  /**
+   * @brief  swap variable of Conv2D Layer
+   * @parma[out] lhs Optimizer
+   * @parma[in] rhs Optimizer
+   */
+  void swap(TensorDim &lhs, TensorDim &rhs) noexcept;
+
   unsigned int batch() const { return dim[0]; };
   unsigned int channel() const { return dim[1]; };
   unsigned int height() const { return dim[2]; };
@@ -59,12 +82,12 @@ public:
   void width(unsigned int w) { setTensorDim(3, w); }
 
   const unsigned int *getDim() const { return dim; }
-  const unsigned int getNumDim() const { return MAXDIM; }
+  unsigned int getNumDim() const { return MAXDIM; }
 
   void setTensorDim(unsigned int idx, unsigned int value);
   int setTensorDim(std::string input_shape);
 
-  void operator=(const TensorDim &from);
+  TensorDim &operator=(const TensorDim &rhs);
   bool operator==(const TensorDim &rhs) const;
   bool operator!=(const TensorDim &rhs) const { return !(*this == rhs); }
 
index ab483c1..84a726e 100644 (file)
@@ -191,7 +191,7 @@ void DataBufferFromDataFile::updateData(BufferType type) {
     break;
   }
 
-  unsigned int I;
+  uint64_t I;
   std::vector<unsigned int> mark;
   mark.resize(max_size);
   file.clear();
index d3ab490..01aeb6c 100644 (file)
@@ -316,6 +316,8 @@ void DataBufferFromCallback::updateData(BufferType type) {
   }
   free(vec);
   free(veclabel);
+  free(vec_arr);
+  free(veclabel_arr);
 }
 
 int DataBufferFromCallback::setProperty(const PropertyType type,
index e4cdf3a..0cfc198 100644 (file)
@@ -233,7 +233,8 @@ int NeuralNetwork::loadFromConfig() {
 
     if (!sec_name) {
       ml_loge("Error: Unable to retrieve section names from ini.");
-      return ML_ERROR_INVALID_PARAMETER;
+      status = ML_ERROR_INVALID_PARAMETER;
+      NN_RETURN_STATUS();
     }
 
     if (strncasecmp(network_str, sec_name, network_len) == 0) {
@@ -979,7 +980,7 @@ static unsigned int getLayerFlag(ml_train_summary_type_e verbosity,
     /// no break intended
 
   case ML_TRAIN_SUMMARY_MODEL:
-    flag =
+    flag |=
       LayerPrintOption::PRINT_INST_INFO | LayerPrintOption::PRINT_SHAPE_INFO;
     break;
 
index b953c8e..83dc703 100644 (file)
 
 namespace nntrainer {
 
+Optimizer::Optimizer(const OptType t, const OptParam p) {
+  type = t;
+  popt = p;
+}
+
 int Optimizer::setType(OptType t) {
   int status = ML_ERROR_NONE;
   if (t == OptType::unknown) {
index 376965a..e38ae90 100644 (file)
 
 namespace nntrainer {
 
+TensorDim &TensorDim::operator=(const TensorDim &rhs) {
+  TensorDim tmp(rhs.batch(), rhs.channel(), rhs.height(), rhs.width());
+  this->swap(*this, tmp);
+  return *this;
+}
+
+TensorDim &TensorDim::operator=(TensorDim &&rhs) noexcept {
+  this->swap(*this, rhs);
+  return *this;
+}
+
+void TensorDim::swap(TensorDim &lhs, TensorDim &rhs) noexcept {
+  std::swap(lhs.dim, rhs.dim);
+  std::swap(lhs.len, rhs.len);
+  std::swap(lhs.feature_len, rhs.feature_len);
+}
+
 void TensorDim::resetLen() {
   feature_len = dim[1] * dim[2] * dim[3];
   len = dim[0] * feature_len;
@@ -66,14 +83,6 @@ int TensorDim::setTensorDim(std::string input_shape) {
   return status;
 }
 
-void TensorDim::operator=(const TensorDim &from) {
-  for (int i = 0; i < MAXDIM; ++i) {
-    this->dim[i] = from.dim[i];
-  }
-  len = from.len;
-  feature_len = from.feature_len;
-}
-
 bool TensorDim::operator==(const TensorDim &rhs) const {
   for (int i = 0; i < MAXDIM; ++i) {
     if (this->dim[i] != rhs.dim[i]) {
index 8c63eb3..713d736 100644 (file)
@@ -157,7 +157,7 @@ protected:
   nntrainer::NeuralNetwork NN;
 
 private:
-  void erase_ini() { std::remove(getIniName().c_str()); }
+  void erase_ini() { std::remove((char *)(getIniName().c_str())); }
   int failAt;
   std::string name;
   std::vector<IniSection> sections;
index 167ad4b..d3a4494 100644 (file)
@@ -87,7 +87,7 @@ static int rangeRandom(int min, int max) {
  * @retval true/false false : end of data
  */
 static bool getData(std::ifstream &F, std::vector<float> &outVec,
-                    std::vector<float> &outLabel, int id) {
+                    std::vector<float> &outLabel, uint64_t id) {
   F.clear();
   F.seekg(0, std::ios_base::end);
   uint64_t file_length = F.tellg();
index f83212b..b31b49f 100644 (file)
@@ -759,6 +759,8 @@ TEST(nntrainer_capi_summary, summary_01_p) {
 
   status = ml_train_model_destroy(handle);
   EXPECT_EQ(status, ML_ERROR_NONE);
+
+  free(sum);
 }
 
 /**
index fd66cae..0570610 100644 (file)
@@ -35,6 +35,9 @@ TEST(nntrainer_capi_dataset, create_destroy_02_n) {
   status =
     ml_train_dataset_create_with_file(&dataset, "nofile.txt", NULL, NULL);
   EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER);
+
+  status = ml_train_dataset_destroy(dataset);
+  EXPECT_EQ(status, ML_ERROR_INVALID_PARAMETER);
 }
 
 /**
@@ -78,6 +81,7 @@ TEST(nntrainer_capi_dataset, create_destroy_05_p) {
   status =
     ml_train_dataset_create_with_file(&dataset, "trainingSet.dat", NULL, NULL);
   EXPECT_EQ(status, ML_ERROR_NONE);
+
   status = ml_train_dataset_destroy(dataset);
   EXPECT_EQ(status, ML_ERROR_NONE);