#include "cker/Utils.h"
#include "cker/operation/reference/Conv.h"
#include "cker/operation/optimized/Conv.h"
+#include <iostream>
#include <vector>
namespace nnfw
class Conv
{
public:
- Conv()
- : _modified_filter_data(), _im2col_data(), _im2col_shape(4), _need_im2col(false),
- _prepared(false)
- {
- }
+ Conv() : _modified_filter_data(), _im2col_shape(4), _need_im2col(false), _prepared(false) {}
void prepare(const Shape &filter_shape, const float *filter_data, PaddingType padding_type,
bool &is_replaced_weights, uint32_t dilationWidthFactor,
params.stride_height);
}
- uint8_t *im2col_raw_data = _im2col_data.data();
- optimized::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
- bias_data, output_shape, output_data, _im2col_shape, im2col_raw_data);
+ int im2col_size = _need_im2col ? _im2col_shape.FlatSize() : 1;
+
+ // Use heap if size is larger than 8MB
+ if (im2col_size > 8 * 1024 * 1024)
+ {
+ std::unique_ptr<uint8_t[]> im2col_data = std::make_unique<uint8_t[]>(im2col_size);
+ optimized::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+ bias_data, output_shape, output_data, _im2col_shape, im2col_data.get());
+ }
+ else
+ {
+ uint8_t im2col_data[im2col_size];
+ optimized::Conv(params, input_shape, input_data, filter_shape, filter_data, bias_shape,
+ bias_data, output_shape, output_data, _im2col_shape, im2col_data);
+ }
}
private:
_im2col_shape.SetDim(1, output_shape.Dims(1));
_im2col_shape.SetDim(2, output_shape.Dims(2));
_im2col_shape.SetDim(3, input_shape.Dims(3) * kernel_shape.Dims(1) * kernel_shape.Dims(2));
- _im2col_data.resize(_im2col_shape.FlatSize());
}
}
private:
std::vector<float> _modified_filter_data;
- std::vector<uint8_t> _im2col_data;
Shape _im2col_shape;
bool _need_im2col;
bool _prepared;