[tensor] Add beta for tensor sum
authorParichay Kapoor <pk.kapoor@samsung.com>
Wed, 27 Oct 2021 11:36:01 +0000 (20:36 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 28 Oct 2021 08:09:24 +0000 (17:09 +0900)
Add beta for tensor to enable addition to existing tensor while summing.

Signed-off-by: Parichay Kapoor <pk.kapoor@samsung.com>
nntrainer/tensor/tensor.cpp
nntrainer/tensor/tensor.h

index a90b866..8e07cd9 100644 (file)
@@ -347,6 +347,7 @@ Tensor &Tensor::multiply_strided(Tensor const &m, Tensor &output) const {
   if (size() != m.size() || size() != output.size())
     throw std::invalid_argument(
       "Strided multiplication does not support broadcasting");
+
   /** @todo optimize this with a tensor iterator */
   for (unsigned int b = 0; b < batch(); ++b) {
     for (unsigned int c = 0; c < channel(); ++c) {
@@ -721,11 +722,12 @@ Tensor Tensor::sum_by_batch() const {
 /**
  * @brief Calculate sum according to the axis.
  */
-Tensor Tensor::sum(unsigned int axis, float alpha) const {
+Tensor Tensor::sum(unsigned int axis, float alpha, float beta) const {
   Tensor ret;
-  return sum(axis, ret, alpha);
+  return sum(axis, ret, alpha, beta);
 }
-Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha) const {
+Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha,
+                    float beta) const {
   const float *data = getData();
 
   if (axis >= 4)
@@ -745,7 +747,7 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha) const {
     Tensor ones(1, 1, 1, batch);
     ones.setValue(alpha);
     sgemv(CblasRowMajor, CblasTrans, batch, feat_len, 1, data, feat_len,
-          ones.getData(), 1, 0.0, ret.getData(), 1);
+          ones.getData(), 1, beta, ret.getData(), 1);
   } break;
   case 1: {
     CREATE_IF_EMPTY_DIMS(ret, dim.batch(), 1, dim.height(), dim.width());
@@ -756,7 +758,7 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha) const {
     float *rdata = ret.getData();
     for (unsigned int k = 0; k < dim.batch(); ++k) {
       sgemv(CblasRowMajor, CblasTrans, channel, feat_len, 1,
-            &data[k * dim.getFeatureLen()], feat_len, ones.getData(), 1, 0.0,
+            &data[k * dim.getFeatureLen()], feat_len, ones.getData(), 1, beta,
             &rdata[k * feat_len], 1);
     }
   } break;
@@ -773,7 +775,7 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha) const {
           k * dim.getFeatureLen() + c * dim.width() * dim.height();
         unsigned int ridx = k * ret.dim.getFeatureLen() + c * dim.width();
         sgemv(CblasRowMajor, CblasTrans, height, width, 1, &data[idx], width,
-              ones.getData(), 1, 0.0, &rdata[ridx], 1);
+              ones.getData(), 1, beta, &rdata[ridx], 1);
       }
     }
   } break;
@@ -783,8 +785,8 @@ Tensor &Tensor::sum(unsigned int axis, Tensor &ret, float alpha) const {
     unsigned int n = dim.width();
     Tensor ones(1, 1, 1, n);
     ones.setValue(alpha);
-    sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n, ones.getData(), 1, 0.0,
-          ret.getData(), 1);
+    sgemv(CblasRowMajor, CblasNoTrans, m, n, 1, data, n, ones.getData(), 1,
+          beta, ret.getData(), 1);
   } break;
   default:
     throw std::out_of_range("Error: Dimension cannot exceed 3");
@@ -1110,7 +1112,7 @@ std::ostream &operator<<(std::ostream &out, Tensor const &m) {
 }
 
 float *Tensor::getAddress(unsigned int i) {
-  if (i > this->dim.getDataLen()) {
+  if (i > getIndex(batch(), channel(), height(), width())) {
     ml_loge("Error: Index out of bounds");
     return nullptr;
   }
@@ -1119,7 +1121,7 @@ float *Tensor::getAddress(unsigned int i) {
 }
 
 const float *Tensor::getAddress(unsigned int i) const {
-  if (i > this->dim.getDataLen()) {
+  if (i > getIndex(batch(), channel(), height(), width())) {
     ml_loge("Error: Index out of bounds");
     return nullptr;
   }
index 9391e5c..6dd2ca0 100644 (file)
@@ -612,7 +612,7 @@ public:
    * @param[in] alpha Scale the sum by this value
    * @retval    Calculated Tensor
    */
-  Tensor sum(unsigned int axis, float alpha = 1.0) const;
+  Tensor sum(unsigned int axis, float alpha = 1.0, float beta = 0.0) const;
 
   /**
    * @brief     sum all the Tensor elements according to the axis
@@ -625,7 +625,8 @@ public:
    * @param[in] alpha Scale the sum by this value
    * @retval    Calculated Tensor
    */
-  Tensor &sum(unsigned int axis, Tensor &output, float alpha = 1.0) const;
+  Tensor &sum(unsigned int axis, Tensor &output, float alpha = 1.0,
+              float beta = 0.0) const;
 
   /**
    * @brief sum all the Tensor by multiple axes