Using CBLAS for Tensor Calculation
authorjijoong.moon <jijoong.moon@samsung.com>
Wed, 12 Feb 2020 05:13:54 +0000 (14:13 +0900)
committer문지중/On-Device Lab(SR)/Principal Engineer/삼성전자 <jijoong.moon@samsung.com>
Thu, 13 Feb 2020 04:19:04 +0000 (13:19 +0900)
Implement Tensor Calculation for Tensor Calculation
Can use with "-DUSE_BLAS"

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
CMakeLists.txt
include/tensor.h
src/tensor.cpp

index 6b85ed2..a94e533 100644 (file)
@@ -9,9 +9,10 @@ set (PKGCONFIG_INSTALL_DIR "/usr/lib/pkgconfig/" )
 set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -Wall -Werror -g -pthread")
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -Werror -g -std=c++11 -pthread")
 
-
 find_package(PkgConfig REQUIRED)
 
+option(USE_BLAS "Use BLAS library" ON)
+
 set(INIPARSER ${PROJECT_SOURCE_DIR}/external/iniparser/src)
 
 include_directories( ${include_directories}
@@ -27,10 +28,19 @@ set(SRCS
        ${INIPARSER}/dictionary.c
        )
 
-set(NNTRAINER_HEADERS include/neuralnet.h include/tensor.h include/layers.h)      
+set(NNTRAINER_HEADERS include/neuralnet.h include/tensor.h include/layers.h)
 
-add_library( ${PROJECT_NAME} SHARED ${SRCS} )
+if(USE_BLAS)
+  ADD_DEFINITIONS(-DUSE_BLAS)
+  pkg_check_modules(BLAS cblas)
+  link_libraries(${BLAS_LIBRARIES})
+  include_directories({BLAS_INCLUDE_DIRS})
+  set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -lcblas")
+  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -lcblas")
+  message("[-DUSE_BLAS] is enabled")
+endif(USE_BLAS)
 
+add_library( ${PROJECT_NAME} SHARED ${SRCS} )
 
 configure_file("package.pc.in" "${PROJECT_BINARY_DIR}/${CMAKE_PROJECT_NAME}.pc" @ONLY)
 install (FILES "${PROJECT_BINARY_DIR}/${CMAKE_PROJECT_NAME}.pc" DESTINATION "${PKGCONFIG_INSTALL_DIR}")
index 6d3a901..247bf0e 100644 (file)
 #ifndef __TENSOR_H__
 #define __TENSOR_H__
 
+#ifdef USE_BLAS
+extern "C" {
+#include <cblas.h>
+}
+#endif
+
 #include <cmath>
 #include <fstream>
 #include <iostream>
+#include <memory>
 #include <vector>
 
 /**
@@ -38,7 +45,7 @@ class Tensor {
   /**
    * @brief     Constructor of Tensor
    */
-  Tensor(){};
+  Tensor() : height(0), width(0), batch(0), dim(0), len(0){};
 
   /**
    * @brief     Constructor of Tensor with batch size one
@@ -57,15 +64,23 @@ class Tensor {
 
   /**
    * @brief   Constructor of Tensor
-   * @param[in] data data for the Tensor with batch size one
+   * @param[in] d data for the Tensor with batch size one
    */
-  Tensor(std::vector<std::vector<float>> const &data);
+  Tensor(std::vector<std::vector<float>> const &d);
 
   /**
    * @brief     Constructor of Tensor
-   * @param[in] data data for the Tensor
+   * @param[in] d data for the Tensor
    */
-  Tensor(std::vector<std::vector<std::vector<float>>> const &data);
+  Tensor(std::vector<std::vector<std::vector<float>>> const &d);
+
+  /**
+   * @brief     return value at specific location
+   * @param[in] batch batch location
+   * @param[in] h height location
+   * @param[in] w width location
+   */
+  float getValue(int batch, int h, int w);
 
   /**
    * @brief     Multiply value element by element
@@ -218,13 +233,20 @@ class Tensor {
    */
   void read(std::ifstream &file);
 
+  /**
+   * @brief     return argument index which value is max
+   * @retval    int argument index
+   */
+  int argmax();
+
  private:
   /**< handle the data as a std::vector type */
-  std::vector<std::vector<std::vector<float>>> data;
+  std::vector<float> data;
   int height;
   int width;
   int batch;
   int dim;
+  int len;
 };
 
 /**
index 3150fb2..9af6800 100644 (file)
@@ -24,6 +24,7 @@
 #include "include/tensor.h"
 #include <assert.h>
 #include <stdio.h>
+#include <cstring>
 #include <sstream>
 
 Tensor::Tensor(int height, int width) {
@@ -31,7 +32,9 @@ Tensor::Tensor(int height, int width) {
   this->width = width;
   this->batch = 1;
   this->dim = 2;
-  this->data.push_back(std::vector<std::vector<float>>(height, std::vector<float>(width)));
+  this->len = height * width * batch;
+  this->data = std::vector<float>(len);
+  setZero();
 }
 
 Tensor::Tensor(int batch, int height, int width) {
@@ -39,71 +42,87 @@ Tensor::Tensor(int batch, int height, int width) {
   this->width = width;
   this->batch = batch;
   this->dim = 3;
+  this->len = height * width * batch;
+  this->data = std::vector<float>(len);
+  setZero();
+}
 
-  for (int i = 0; i < batch; i++) {
-    this->data.push_back(std::vector<std::vector<float>>(height, std::vector<float>(width)));
-  }
+float Tensor::getValue(int batch, int h, int w) {
+  return this->data[batch * height * width + h * width + w];
 }
 
-Tensor::Tensor(std::vector<std::vector<float>> const &data) {
-  assert(data.size() != 0);
-  this->height = data.size();
-  this->width = data[0].size();
+void Tensor::setValue(int batch, int h, int w, float value) {
+  this->data[batch * height * width + h * width + w] = value;
+}
+
+Tensor::Tensor(std::vector<std::vector<float>> const &d) {
+  assert(d.size() != 0);
+  this->height = d.size();
+  this->width = d[0].size();
   this->batch = 1;
   this->dim = 2;
-  this->data.push_back(data);
+  this->len = height * width * batch;
+  this->data = std::vector<float>(len);
+
+  for (int j = 0; j < height; ++j)
+    for (int k = 0; k < width; ++k)
+      this->setValue(0, j, k, d[j][k]);
 }
 
-Tensor::Tensor(std::vector<std::vector<std::vector<float>>> const &data) {
-  assert(data.size() != 0 && data[0].size() != 0);
-  this->batch = data.size();
-  this->height = data[0].size();
-  this->width = data[0][0].size();
+Tensor::Tensor(std::vector<std::vector<std::vector<float>>> const &d) {
+  assert(d.size() != 0 && d[0].size() != 0);
+  this->batch = d.size();
+  this->height = d[0].size();
+  this->width = d[0][0].size();
   this->dim = 3;
-  this->data = data;
+  this->len = this->batch * this->height * this->width;
+  this->data = std::vector<float>(len);
+
+  for (int i = 0; i < this->batch; ++i)
+    for (int j = 0; j < this->height; ++j)
+      for (int k = 0; k < this->width; ++k)
+        this->setValue(i, j, k, d[i][j][k]);
 }
 
 Tensor Tensor::multiply(float const &value) {
   Tensor result(batch, height, width);
-  int i, j, k;
-
-  for (k = 0; k < batch; k++) {
-    for (i = 0; i < height; i++) {
-      for (j = 0; j < width; j++) {
-        result.data[k][i][j] = data[k][i][j] * value;
-      }
-    }
+#ifdef USE_BLAS
+  memset(result.data.data(), 0, sizeof(float) * result.len);
+  cblas_saxpy(this->len, value, this->data.data(), 1, result.data.data(), 1);
+#else
+  for (int k = 0; k < len; ++k) {
+    result.data[k] = data[k] * value;
   }
-
+#endif
   return result;
 }
 
 Tensor Tensor::divide(float const &value) {
   Tensor result(batch, height, width);
-  int i, j, k;
-
-  for (k = 0; k < batch; k++) {
-    for (i = 0; i < height; i++) {
-      for (j = 0; j < width; j++) {
-        result.data[k][i][j] = data[k][i][j] / value;
-      }
-    }
+#ifdef USE_BLAS
+  memset(result.data.data(), 0, sizeof(float) * result.len);
+  cblas_saxpy(this->len, 1.0 / value, this->data.data(), 1, result.data.data(), 1);
+#else
+  for (int k = 0; k < len; ++k) {
+    result.data[k] = data[k] / value;
   }
-
+#endif
   return result;
 }
 
 Tensor Tensor::add(float const &value) {
   Tensor result(batch, height, width);
-  int i, j, k;
-
-  for (k = 0; k < batch; k++) {
-    for (i = 0; i < height; i++) {
-      for (j = 0; j < width; j++) {
-        result.data[k][i][j] = data[k][i][j] + value;
-      }
-    }
+#ifdef USE_BLAS
+  cblas_scopy(this->len, this->data.data(), 1, result.data.data(), 1);
+  Tensor tmp(batch, height, width);
+  for (int i = 0; i < tmp.len; ++i)
+    tmp.data[i] = 1.0;
+  cblas_saxpy(this->len, value, tmp.data.data(), 1, result.data.data(), 1);
+#else
+  for (int k = 0; k < batch; ++k) {
+    result.data[k] = data[k] + value;
   }
+#endif
 
   return result;
 }
@@ -112,49 +131,67 @@ Tensor Tensor::add(Tensor const &m) const {
   assert(height == m.height && width == m.width);
 
   Tensor result(batch, height, width);
+#ifdef USE_BLAS
+  cblas_scopy(this->len, this->data.data(), 1, result.data.data(), 1);
+  int size = this->width * this->height;
+  if (m.batch == 1) {
+    for (int k = 0; k < batch; ++k) {
+      cblas_saxpy(size, 1.0, m.data.data(), 1, &(result.data.data()[k * size]), 1);
+    }
+  } else {
+    cblas_saxpy(this->len, 1.0, m.data.data(), 1, result.data.data(), 1);
+  }
+#else
   int i, j, k;
   if (m.batch == 1) {
-    for (k = 0; k < batch; k++) {
-      for (i = 0; i < height; i++) {
-        for (j = 0; j < width; j++) {
-          result.data[k][i][j] = data[k][i][j] + m.data[0][i][j];
-        }
+    for (k = 0; k < batch; ++k) {
+      for (i = 0; i < m.len; ++i) {
+        j = k * m.len;
+        result.data[j + i] = data[j + i] + m.data[i];
       }
     }
   } else {
-    for (k = 0; k < batch; k++) {
-      for (i = 0; i < height; i++) {
-        for (j = 0; j < width; j++) {
-          result.data[k][i][j] = data[k][i][j] + m.data[k][i][j];
-        }
-      }
+    for (k = 0; k < len; ++k) {
+      result.data[k] = data[k] + m.data[k];
     }
   }
+#endif
+
   return result;
 }
 
 Tensor Tensor::subtract(Tensor const &m) const {
   assert(height == m.height && width == m.width);
   Tensor result(batch, height, width);
-  int i, j, k;
+
+#ifdef USE_BLAS
+  cblas_scopy(this->len, this->data.data(), 1, result.data.data(), 1);
+  int size = this->width * this->height;
+  float alpha = -1.0;
 
   if (m.batch == 1) {
-    for (k = 0; k < batch; k++) {
-      for (i = 0; i < height; i++) {
-        for (j = 0; j < width; j++) {
-          result.data[k][i][j] = data[k][i][j] - m.data[0][i][j];
-        }
-      }
+    for (int k = 0; k < batch; ++k) {
+      cblas_saxpy(size, alpha, m.data.data(), 1, &(result.data.data()[k * size]), 1);
     }
   } else {
-    for (k = 0; k < batch; k++) {
-      for (i = 0; i < height; i++) {
-        for (j = 0; j < width; j++) {
-          result.data[k][i][j] = data[k][i][j] - m.data[k][i][j];
-        }
+    assert(batch == m.batch);
+    cblas_saxpy(this->len, alpha, m.data.data(), 1, result.data.data(), 1);
+  }
+#else
+  int i, j, k;
+  if (m.batch == 1) {
+    for (k = 0; k < batch; ++k) {
+      for (i = 0; i < m.len; ++i) {
+        j = k * m.len;
+        result.data[j + i] = data[j + i] - m.data[i];
       }
     }
+  } else {
+    for (k = 0; k < len; ++k) {
+      result.data[k] = data[k] - m.data[k];
+    }
   }
+#endif
   return result;
 }
 
@@ -162,24 +199,30 @@ Tensor Tensor::multiply(Tensor const &m) const {
   assert(height == m.height && width == m.width);
   Tensor result(batch, height, width);
 
-  int i, j, k;
-
+  int end = this->len / 4;
+  int e = width * height / 4;
+  int i;
   if (m.batch == 1) {
-    for (k = 0; k < batch; k++) {
-      for (i = 0; i < height; i++) {
-        for (j = 0; j < width; j++) {
-          result.data[k][i][j] = data[k][i][j] * m.data[0][i][j];
-        }
+    for (int k = 0; k < batch; ++k) {
+      int b = k * width * height;
+      for (i = 0; i < e * 4; i += 4) {
+        result.data[b + i + 0] = this->data[b + i + 0] * m.data[i + 0];
+        result.data[b + i + 1] = this->data[b + i + 1] * m.data[i + 1];
+        result.data[b + i + 2] = this->data[b + i + 2] * m.data[i + 2];
+        result.data[b + i + 3] = this->data[b + i + 3] * m.data[i + 3];
       }
+      for (int j = i; j < width * height; j++)
+        result.data[b + j] = this->data[b + j] * m.data[j];
     }
   } else {
-    for (k = 0; k < batch; k++) {
-      for (i = 0; i < height; i++) {
-        for (j = 0; j < width; j++) {
-          result.data[k][i][j] = data[k][i][j] * m.data[k][i][j];
-        }
-      }
+    for (i = 0; i < end * 4; i += 4) {
+      result.data[i + 0] = this->data[i + 0] * m.data[i + 0];
+      result.data[i + 1] = this->data[i + 1] * m.data[i + 1];
+      result.data[i + 2] = this->data[i + 2] * m.data[i + 2];
+      result.data[i + 3] = this->data[i + 3] * m.data[i + 3];
     }
+    for (int j = i; j < len; ++j)
+      result.data[j] = this->data[j] * m.data[j];
   }
 
   return result;
@@ -189,24 +232,31 @@ Tensor Tensor::divide(Tensor const &m) const {
   assert(height == m.height && width == m.width);
   Tensor result(batch, height, width);
 
-  int i, j, k;
+  int end = this->len / 4;
+  int e = width * height / 4;
+  int i;
 
   if (m.batch == 1) {
-    for (k = 0; k < batch; k++) {
-      for (i = 0; i < height; i++) {
-        for (j = 0; j < width; j++) {
-          result.data[k][i][j] = data[k][i][j] / m.data[0][i][j];
-        }
+    for (int k = 0; k < batch; ++k) {
+      int b = k * width * height;
+      for (i = 0; i < e * 4; i += 4) {
+        result.data[b + i + 0] = this->data[b + i + 0] / m.data[i + 0];
+        result.data[b + i + 1] = this->data[b + i + 1] / m.data[i + 1];
+        result.data[b + i + 2] = this->data[b + i + 2] / m.data[i + 2];
+        result.data[b + i + 3] = this->data[b + i + 3] / m.data[i + 3];
       }
+      for (int j = i - 1; j < width * height; ++j)
+        result.data[b + j] = this->data[b + j] / m.data[j];
     }
   } else {
-    for (k = 0; k < batch; k++) {
-      for (i = 0; i < height; i++) {
-        for (j = 0; j < width; j++) {
-          result.data[k][i][j] = data[k][i][j] / m.data[k][i][j];
-        }
-      }
+    for (i = 0; i < end * 4; i += 4) {
+      result.data[i + 0] = this->data[i + 0] / m.data[i + 0];
+      result.data[i + 1] = this->data[i + 1] / m.data[i + 1];
+      result.data[i + 2] = this->data[i + 2] / m.data[i + 2];
+      result.data[i + 3] = this->data[i + 3] / m.data[i + 3];
     }
+    for (int j = i - 1; j < len; ++j)
+      result.data[j] = this->data[j] / m.data[j];
   }
 
   return result;
@@ -217,17 +267,21 @@ Tensor Tensor::divide(Tensor const &m) const {
  * Therefore the result has M(batch, 1, 1) dimension.
  */
 Tensor Tensor::sum() const {
-  int i, j, k;
+  int k;
   Tensor ret(batch, 1, 1);
-
-  for (k = 0; k < batch; k++) {
-    ret.data[k][0][0] = 0.0;
-    for (i = 0; i < height; i++) {
-      for (j = 0; j < width; j++) {
-        ret.data[k][0][0] += data[k][i][j];
-      }
+#ifdef USE_BLAS
+  for (k = 0; k < batch; ++k)
+    ret.data[k] = cblas_sasum(width * height, &(data.data()[k * width * height]), 1);
+#else
+  int i;
+  for (k = 0; k < batch; ++k) {
+    int id = k * width * height;
+    ret.data[id] = 0.0;
+    for (i = 0; i < height * width; ++i) {
+      ret.data[id] += data[id + i];
     }
   }
+#endif
 
   return ret;
 }
@@ -238,19 +292,41 @@ Tensor Tensor::sum() const {
  */
 Tensor Tensor::dot(Tensor const &m) const {
   assert(width == m.height);
-  int i, j, h, k, mwidth = m.width;
-  float w = 0;
-
+  int mwidth = m.width;
   Tensor result(batch, height, mwidth);
+
+#ifdef USE_BLAS
+  float alpha_dgemm = 1.0;
+  float beta_dgemm = 1.0;
   if (m.batch == 1) {
-    for (k = 0; k < batch; k++) {
-      for (i = 0; i < height; i++) {
-        for (j = 0; j < mwidth; j++) {
-          for (h = 0; h < width; h++) {
-            w += data[k][i][h] * m.data[0][h][j];
+    for (int k = 0; k < batch; k++) {
+      int i = k * width * height;
+      int ii = k * height * mwidth;
+      cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, height, mwidth, width, alpha_dgemm, &(data.data()[i]),
+                  width, m.data.data(), mwidth, beta_dgemm, &(result.data.data()[ii]), mwidth);
+    }
+  } else {
+    for (int k = 0; k < batch; k++) {
+      int i = k * width * height;
+      int j = k * m.width * m.height;
+      int ii = k * height * mwidth;
+
+      cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, height, mwidth, width, alpha_dgemm, &(data.data()[i]),
+                  width, &(m.data.data()[j]), mwidth, beta_dgemm, &(result.data.data()[ii]), mwidth);
+    }
+  }
+#else
+  float w = 0.0;
+  int i, j, k, h;
+  if (m.batch == 1) {
+    for (k = 0; k < batch; ++k) {
+      for (i = 0; i < height; ++i) {
+        for (j = 0; j < mwidth; ++j) {
+          for (h = 0; h < width; ++h) {
+            w += data[k * height * width + i * width + h] * m.data[h * mwidth + j];
           }
-          result.data[k][i][j] = w;
-          w = 0;
+          result.data[k * height * mwidth + i * mwidth + j] = w;
+          w = 0.0;
         }
       }
     }
@@ -259,14 +335,15 @@ Tensor Tensor::dot(Tensor const &m) const {
       for (i = 0; i < height; i++) {
         for (j = 0; j < mwidth; j++) {
           for (h = 0; h < width; h++) {
-            w += data[k][i][h] * m.data[k][h][j];
+            w += data[k * height * width + i * width + h] * m.data[k * width * mwidth + h * mwidth + j];
           }
-          result.data[k][i][j] = w;
-          w = 0;
+          result.data[k * height * mwidth + i * mwidth + j] = w;
+          w = 0.0;
         }
       }
     }
   }
+#endif
 
   return result;
 }
@@ -274,10 +351,11 @@ Tensor Tensor::dot(Tensor const &m) const {
 Tensor Tensor::transpose() const {
   Tensor result(batch, width, height);
   int i, j, k;
-  for (k = 0; k < batch; k++) {
-    for (i = 0; i < width; i++) {
-      for (j = 0; j < height; j++) {
-        result.data[k][i][j] = data[k][j][i];
+  for (k = 0; k < batch; ++k) {
+    int b = k * width * height;
+    for (i = 0; i < width; ++i) {
+      for (j = 0; j < height; ++j) {
+        result.data[b + i * height + j] = data[b + j * width + i];
       }
     }
   }
@@ -286,15 +364,11 @@ Tensor Tensor::transpose() const {
 
 Tensor Tensor::applyFunction(float (*function)(float)) const {
   Tensor result(batch, height, width);
-  int i, j, k;
+  int i;
+
+  for (i = 0; i < this->len; ++i)
+    result.data[i] = (*function)(data[i]);
 
-  for (k = 0; k < batch; k++) {
-    for (i = 0; i < height; i++) {
-      for (j = 0; j < width; j++) {
-        result.data[k][i][j] = (*function)(data[k][i][j]);
-      }
-    }
-  }
   return result;
 }
 
@@ -304,7 +378,7 @@ void Tensor::print(std::ostream &out) const {
   for (k = 0; k < batch; k++) {
     for (i = 0; i < height; i++) {
       for (j = 0; j < width; j++) {
-        out << data[k][i][j] << " ";
+        out << data[k * width * height + i * width + j] << " ";
       }
       out << std::endl;
     }
@@ -318,18 +392,18 @@ std::ostream &operator<<(std::ostream &out, Tensor const &m) {
 }
 
 Tensor &Tensor::copy(const Tensor &from) {
-  if (this != &from && from.data.size() != 0) {
+  if (this != &from && from.len != 0) {
     height = from.height;
     width = from.width;
     batch = from.batch;
-    for (int k = 0; k < batch; k++) {
-      for (int i = 0; i < height; i++) {
-        for (int j = 0; j < width; j++) {
-          data[k][i][j] = from.data[k][i][j];
-        }
-      }
-    }
+#ifdef USE_BLAS
+    cblas_scopy(this->len, from.data.data(), 1, this->data.data(), 1);
+#else
+    for (int i = 0; i < len; ++i)
+      data[i] = from.data[i];
+#endif
   }
+
   return *this;
 }
 
@@ -338,32 +412,21 @@ Tensor &Tensor::copy(const Tensor &from) {
  */
 std::vector<float> Tensor::Mat2Vec() {
   std::vector<float> ret;
-  for (int k = 0; k < batch; k++)
-    for (int i = 0; i < height; i++)
-      for (int j = 0; j < width; j++)
-        ret.push_back(data[k][i][j]);
+
+  for (int i = 0; i < this->len; i++)
+    ret.push_back(data[i]);
 
   return ret;
 }
 
 void Tensor::save(std::ofstream &file) {
-  for (int k = 0; k < batch; k++) {
-    for (int i = 0; i < height; i++) {
-      for (int j = 0; j < width; j++) {
-        file.write((char *)&data[k][i][j], sizeof(float));
-      }
-    }
-  }
+  for (int i = 0; i < this->len; i++)
+    file.write((char *)&data[i], sizeof(float));
 }
 
 void Tensor::read(std::ifstream &file) {
-  for (int k = 0; k < batch; k++) {
-    for (int i = 0; i < height; i++) {
-      for (int j = 0; j < width; j++) {
-        file.read((char *)&data[k][i][j], sizeof(float));
-      }
-    }
-  }
+  for (int i = 0; i < this->len; i++)
+    file.read((char *)&data[i], sizeof(float));
 }
 
 /**
@@ -377,25 +440,17 @@ Tensor Tensor::average() const {
   Tensor result(1, height, width);
   for (int i = 0; i < height; i++) {
     for (int j = 0; j < width; j++) {
-      result.data[0][i][j] = 0.0;
+      result.data[i * width + j] = 0.0;
       for (int k = 0; k < batch; k++) {
-        result.data[0][i][j] += data[k][i][j];
+        result.data[i * width + j] += data[k * width * height + i * width + j];
       }
-      result.data[0][i][j] = result.data[0][i][j] / (float)batch;
+      result.data[i * width + j] = result.data[i * width + j] / (float)batch;
     }
   }
   return result;
 }
 
-void Tensor::setZero() {
-  for (int k = 0; k < batch; k++) {
-    for (int i = 0; i < height; i++) {
-      for (int j = 0; j < width; j++) {
-        this->data[k][i][j] = 0.0;
-      }
-    }
-  }
-}
+void Tensor::setZero() { memset(this->data.data(), 0, sizeof(float) * this->len); }
 
 Tensor Tensor::softmax() const {
   Tensor result(batch, height, width);
@@ -404,21 +459,35 @@ Tensor Tensor::softmax() const {
   divisor.setZero();
 
   for (int k = 0; k < batch; k++) {
+    int index = k * height;
     for (int i = 0; i < height; i++) {
       for (int j = 0; j < width; j++) {
-        divisor.data[k][i][0] += exp(this->data[k][i][j]);
+        divisor.data[index + i] += exp(this->data[k * height * width + i * width + j]);
       }
     }
   }
 
   for (int k = 0; k < batch; k++) {
+    int index = k * height;
     for (int i = 0; i < height; i++) {
       for (int j = 0; j < width; j++) {
-        result.data[k][i][j] = exp(this->data[k][i][j]) / divisor.data[k][i][0];
+        int id = k * height * width + i * width + j;
+        result.data[id] = exp(this->data[id]) / divisor.data[index + i];
       }
     }
   }
+
   return result;
 }
 
-void Tensor::setValue(int batch, int height, int width, float value) { this->data[batch][height][width] = value; }
+int Tensor::argmax() {
+  int index = 0;
+  float maximum = 0.0;
+  for (int i = 0; i < len; i++) {
+    if (this->data[i] > maximum) {
+      maximum = this->data[i];
+      index = i;
+    }
+  }
+  return index;
+}