[onert] Reduce memory usage in Conv qasymm uint8 operation (#4059)
authorHyeongseok Oh <hseok82.oh@samsung.com>
Wed, 2 Sep 2020 05:49:15 +0000 (14:49 +0900)
committerGitHub <noreply@github.com>
Wed, 2 Sep 2020 05:49:15 +0000 (14:49 +0900)
Move the im2col buffer allocation from configure to run, because the im2col data cannot be reused

Signed-off-by: Hyeongseok Oh <hseok82.oh@samsung.com>
compute/cker/include/cker/operation/Conv.h

index f28004d..214f2e6 100644 (file)
@@ -23,6 +23,7 @@
 #include "cker/Utils.h"
 #include "cker/operation/reference/Conv.h"
 #include "cker/operation/optimized/Conv.h"
+#include <iostream>
 #include <vector>
 
 namespace nnfw
@@ -54,11 +55,7 @@ inline void TransposeFloatTensor(const float *input_data, const nnfw::cker::Shap
 class Conv
 {
 public:
-  Conv()
-      : _modified_filter_data(), _im2col_data(), _im2col_shape(4), _need_im2col(false),
-        _prepared(false)
-  {
-  }
+  Conv() : _modified_filter_data(), _im2col_shape(4), _need_im2col(false), _prepared(false) {}
 
   void prepare(const Shape &filter_shape, const float *filter_data, PaddingType padding_type,
                bool &is_replaced_weights, uint32_t dilationWidthFactor,
@@ -121,9 +118,21 @@ public:
                        params.stride_height);
     }
 
-    uint8_t *im2col_raw_data = _im2col_data.data();
-    optimized::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
-                    bias_data, output_shape, output_data, _im2col_shape, im2col_raw_data);
+    int im2col_size = _need_im2col ? _im2col_shape.FlatSize() : 1;
+
+    // Use heap if size is larger than 8MB
+    if (im2col_size > 8 * 1024 * 1024)
+    {
+      std::unique_ptr<uint8_t[]> im2col_data = std::make_unique<uint8_t[]>(im2col_size);
+      optimized::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+                      bias_data, output_shape, output_data, _im2col_shape, im2col_data.get());
+    }
+    else
+    {
+      uint8_t im2col_data[im2col_size];
+      optimized::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+                      bias_data, output_shape, output_data, _im2col_shape, im2col_data);
+    }
   }
 
 private:
@@ -155,13 +164,11 @@ private:
       _im2col_shape.SetDim(1, output_shape.Dims(1));
       _im2col_shape.SetDim(2, output_shape.Dims(2));
       _im2col_shape.SetDim(3, input_shape.Dims(3) * kernel_shape.Dims(1) * kernel_shape.Dims(2));
-      _im2col_data.resize(_im2col_shape.FlatSize());
     }
   }
 
 private:
   std::vector<float> _modified_filter_data;
-  std::vector<uint8_t> _im2col_data;
   Shape _im2col_shape;
   bool _need_im2col;
   bool _prepared;