From 554f00ca49b72c11e4665c3967ae43e1b5512504 Mon Sep 17 00:00:00 2001
From: "jijoong.moon" <jijoong.moon@samsung.com>
Date: Wed, 17 Jun 2020 21:30:33 +0900
Subject: [PATCH] [ Pooling2D ] backwarding

This PR implements the backwarding (gradient) process of the Pooling2D layer.

. backwarding for max pooling 2D
. backwarding for average pooling 2D
. backwarding for global_max and global_average is NYI.

**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
---
 nntrainer/include/pooling2d_layer.h         | 18 +++---
 nntrainer/src/pooling2d_layer.cpp           | 95 +++++++++++++++++++++++------
 test/unittest/unittest_nntrainer_layers.cpp | 93 ++++++++++++++++++++++++++++
 3 files changed, 179 insertions(+), 27 deletions(-)

diff --git a/nntrainer/include/pooling2d_layer.h b/nntrainer/include/pooling2d_layer.h
index 798fce6..3aaa02c 100644
--- a/nntrainer/include/pooling2d_layer.h
+++ b/nntrainer/include/pooling2d_layer.h
@@ -119,13 +119,6 @@ public:
    */
   int setProperty(std::vector<std::string> values);
 
-  /**
-   * @brief     calculation convolution
-   * @param[in] in input tensor
-   * @param[out] status output of status
-   * @retval    Tensor outoput tensor
-   */
-  Tensor pooling2d(Tensor in, int &status);
 
   /* TO DO : support keras type of padding */
   enum class PaddingType {
@@ -139,7 +132,18 @@ private:
   unsigned int pooling_size[POOLING2D_DIM];
   unsigned int stride[POOLING2D_DIM];
   unsigned int padding[POOLING2D_DIM];
+  std::vector<unsigned int> max_idx;
   PoolingType pooling_type;
+
+  /**
+   * @brief     calculation pooling
+   * @param[in] batch batch index
+   * @param[in] in input tensor
+   * @param[out] status output of status
+   * @retval    Tensor output tensor
+   */
+  Tensor pooling2d(unsigned int batch, Tensor in, int &status);
+
 };
 
 } // namespace nntrainer

diff --git a/nntrainer/src/pooling2d_layer.cpp b/nntrainer/src/pooling2d_layer.cpp
index a596d9c..cc2a1df 100644
--- a/nntrainer/src/pooling2d_layer.cpp
+++ b/nntrainer/src/pooling2d_layer.cpp
@@ -42,13 +42,17 @@ int Pooling2DLayer::initialize(bool last) {
 
   hidden = Tensor(output_dim);
 
+  if (pooling_type == PoolingType::max) {
+    max_idx.resize(output_dim.getDataLen());
+  }
+
   return status;
 }
 
 Tensor Pooling2DLayer::forwarding(Tensor in, int &status) {
   for (unsigned int b = 0; b < in.batch(); ++b) {
     Tensor in_padded = zero_pad(b, in, padding);
-    Tensor result = pooling2d(in_padded, status);
+    Tensor result = pooling2d(b, in_padded, status);
     memcpy(hidden.getAddress(b * hidden.getDim().getFeatureLen()),
            result.getData(), result.getDim().getDataLen() * sizeof(float));
   }
@@ -59,9 +63,52 @@ Tensor Pooling2DLayer::forwarding(Tensor in, Tensor output, int &status) {
   return forwarding(in, status);
 }
 
-Tensor Pooling2DLayer::backwarding(Tensor in, int iteration) {
-  // NYI
-  return in;
+Tensor Pooling2DLayer::backwarding(Tensor derivative, int iteration) {
+  unsigned int batch = input_dim.batch();
+  unsigned int channel = input_dim.channel();
+  unsigned int height = input_dim.height();
+  unsigned int width = input_dim.width();
+  unsigned int p_height = pooling_size[0];
+  unsigned int p_width = pooling_size[1];
+  unsigned int p_size = p_height * p_width;
+
+  unsigned int J, K;
+  Tensor result = Tensor(input_dim);
+  float *out = result.getData();
+  switch (pooling_type) {
+  case PoolingType::max: {
+    for (unsigned int i = 0; i < derivative.getDim().getDataLen(); ++i) {
+      out[max_idx[i]] += derivative.getData()[i];
+    }
+  } break;
+  case PoolingType::average: {
+    for (unsigned int b = 0; b < batch; ++b) {
+      for (unsigned int i = 0; i < channel; ++i) {
+        J = 0;
+        for (unsigned int j = 0; j <= height - p_height; j += stride[0]) {
+          K = 0;
+          for (unsigned int k = 0; k <= width - p_width; k += stride[1]) {
+            float del = derivative.getValue(b, i, J, K) / (p_size);
+            for (unsigned int pi = 0; pi < p_height; ++pi) {
+              for (unsigned int pj = 0; pj < p_width; ++pj) {
+                result.setValue(b, i, j + pi, k + pj,
+                                result.getValue(b, i, j + pi, k + pj) + del);
+              }
+            }
+            K++;
+          }
+          J++;
+        }
+      }
+    }
+  } break;
+  case PoolingType::global_max:
+  case PoolingType::global_average:
+  default:
+    ml_loge("Error: Unknown Pooling Type");
+    break;
+  }
+  return result;
 }
 
 void Pooling2DLayer::copy(std::shared_ptr<Layer> l) {
@@ -135,55 +182,63 @@ int Pooling2DLayer::setProperty(std::vector<std::string> values) {
   return status;
 }
 
-Tensor Pooling2DLayer::pooling2d(Tensor in, int &status) {
+Tensor Pooling2DLayer::pooling2d(unsigned int batch, Tensor in, int &status) {
   unsigned int channel = in.channel();
   unsigned int height = in.height();
   unsigned int width = in.width();
   unsigned int p_height = pooling_size[0];
   unsigned int p_width = pooling_size[1];
+  unsigned int base_idx = batch * output_dim.getFeatureLen();
   Tensor output(output_dim.channel(), output_dim.height(), output_dim.width());
 
-  unsigned int I, J;
+  unsigned int J, K;
   switch (pooling_type) {
   case PoolingType::max: {
     for (unsigned int i = 0; i < channel; ++i) {
-      I = 0;
+      J = 0;
       for (unsigned int j = 0; j <= height - p_height; j += stride[0]) {
-        J = 0;
+        K = 0;
         for (unsigned int k = 0; k <= width - p_width; k += stride[1]) {
           float max = std::numeric_limits<float>::min();
           for (unsigned int pi = 0; pi < p_height; ++pi) {
             for (unsigned int pj = 0; pj < p_width; ++pj) {
               float val = in.getValue(0, i, j + pi, k + pj);
-              if (max < val)
+              if (max < val) {
+                max_idx[base_idx +
+                        i * output_dim.height() * output_dim.width() +
+                        J * output_dim.width() + K] =
+                  batch * input_dim.getFeatureLen() + i * height * width +
+                  (j + pi) * width + (k + pj);
                 max = val;
+              }
             }
           }
-          output.setValue(0, i, I, J, max);
-          J++;
+          output.setValue(0, i, J, K, max);
+          K++;
         }
-      I++;
+        J++;
       }
     }
   } break;
   case PoolingType::average: {
+    unsigned int p_size = p_height * p_width;
     for (unsigned int i = 0; i < channel; ++i) {
-      I = 0;
+      J = 0;
       for (unsigned int j = 0; j <= height - p_height; j += stride[0]) {
-        J = 0;
+        K = 0;
         for (unsigned int k = 0; k <= width - p_width; k += stride[1]) {
           float sum = 0.0;
           for (unsigned int pi = 0; pi < p_height; ++pi) {
             for (unsigned int pj = 0; pj < p_width; ++pj) {
               sum += in.getValue(0, i, j + pi, k + pj);
-            }
-          }
-          sum = sum / (p_height + p_width);
-          output.setValue(0, i, I, J, sum);
-          J++;
+            }
+          }
+          sum = sum / (p_size);
+          output.setValue(0, i, J, K, sum);
+          K++;
         }
-      I++;
+        J++;
       }
     }
   } break;
diff --git a/test/unittest/unittest_nntrainer_layers.cpp b/test/unittest/unittest_nntrainer_layers.cpp
index d7608a4..e3abc43 100644
--- a/test/unittest/unittest_nntrainer_layers.cpp
+++ b/test/unittest/unittest_nntrainer_layers.cpp
@@ -695,6 +695,99 @@ TEST(nntrainer_Pooling2D, forwarding_02_p) {
 }
 
 /**
+ * @brief Pooling 2D Layer
+ */
+TEST(nntrainer_Pooling2D, backwarding_01_p) {
+  int status = ML_ERROR_NONE;
+  nntrainer::Pooling2DLayer layer;
+  std::vector<std::string> input_str;
+  nntrainer::TensorDim previous_dim;
+  previous_dim.setTensorDim("1:2:5:5");
+  layer.setInputDimension(previous_dim);
+
+  input_str.push_back("pooling_size= 2,2");
+  input_str.push_back("stride=1, 1");
+  input_str.push_back("padding=0,0");
+  input_str.push_back("pooling = max");
+
+  status = layer.setProperty(input_str);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+  status = layer.initialize(false);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+
+  nntrainer::Tensor in(1, 2, 4, 4);
+  nntrainer::Tensor out, result(1, 2, 5, 5);
+  std::ifstream file("test_1_goldenConv2DResult.out");
+  result.read(file);
+  file.close();
+  out = layer.forwarding(result, status);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+
+  for (unsigned int i = 0; i < in.getDim().getDataLen(); ++i) {
+    in.getData()[i] = 1.0;
+  }
+
+  out = layer.backwarding(in, 0);
+
+  float golden[50] = {0, 0, 0, 1, 0, 0, 1, 2, 0, 1, 0, 1, 4, 0, 0, 0, 1,
+                      2, 2, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 4, 0,
+                      1, 0, 3, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 2, 1, 0};
+
+  float *out_ptr = out.getData();
+
+  for (int i = 0; i < 50; ++i) {
+    EXPECT_FLOAT_EQ(out_ptr[i], golden[i]);
+  }
+}
+
+/**
+ * @brief Pooling 2D Layer
+ */
+TEST(nntrainer_Pooling2D, backwarding_02_p) {
+  int status = ML_ERROR_NONE;
+  nntrainer::Pooling2DLayer layer;
+  std::vector<std::string> input_str;
+  nntrainer::TensorDim previous_dim;
+  previous_dim.setTensorDim("1:2:5:5");
+  layer.setInputDimension(previous_dim);
+
+  input_str.push_back("pooling_size= 2,2");
+  input_str.push_back("stride=1, 1");
+  input_str.push_back("padding=0,0");
+  input_str.push_back("pooling = average");
+
+  status = layer.setProperty(input_str);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+  status = layer.initialize(false);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+
+  nntrainer::Tensor in(1, 2, 4, 4);
+  nntrainer::Tensor out, result(1, 2, 5, 5);
+  std::ifstream file("test_1_goldenConv2DResult.out");
+  result.read(file);
+  file.close();
+  out = layer.forwarding(result, status);
+  EXPECT_EQ(status, ML_ERROR_NONE);
+
+  for (unsigned int i = 0; i < in.getDim().getDataLen(); ++i) {
+    in.getData()[i] = 1.0;
+  }
+
+  out = layer.backwarding(in, 0);
+
+  float golden[50] = {0.25, 0.5, 0.5, 0.5, 0.25, 0.5,  1.0, 1.0, 1.0, 0.5,
+                      0.5,  1.0, 1.0, 1.0, 0.5,  0.5,  1.0, 1.0, 1.0, 0.5,
+                      0.25, 0.5, 0.5, 0.5, 0.25, 0.25, 0.5, 0.5, 0.5, 0.25,
+                      0.5,  1.0, 1.0, 1.0, 0.5,  0.5,  1.0, 1.0, 1.0, 0.5,
+                      0.5,  1.0, 1.0, 1.0, 0.5,  0.25, 0.5, 0.5, 0.5, 0.25};
+  float *out_ptr = out.getData();
+
+  for (int i = 0; i < 50; ++i) {
+    EXPECT_FLOAT_EQ(out_ptr[i], golden[i]);
+  }
+}
+
+/**
  * @brief Flatten Layer
  */
 TEST(nntrainer_Flatten, forwarding_01_p) {
-- 
2.7.4
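
Reviewer note: the key to max-pooling backwarding is the argmax bookkeeping in
pooling2d(). For every output element, the forward pass records in max_idx the
flat index of the input element that won the max. A minimal single-channel
standalone sketch of that idea, using plain std::vector and hypothetical names
(maxPoolForward is illustrative, not the nntrainer Tensor API):

  #include <limits>
  #include <vector>

  // Single-channel max pooling that records, for each output element, the
  // flat input index of the maximum (the argmax). lowest() rather than min()
  // is used as the initial "max" so all-negative windows are handled.
  void maxPoolForward(const std::vector<float> &in, unsigned int height,
                      unsigned int width, unsigned int p_h, unsigned int p_w,
                      unsigned int s_h, unsigned int s_w,
                      std::vector<float> &out,
                      std::vector<unsigned int> &max_idx) {
    for (unsigned int j = 0; j + p_h <= height; j += s_h) {
      for (unsigned int k = 0; k + p_w <= width; k += s_w) {
        float max = std::numeric_limits<float>::lowest();
        unsigned int arg = 0;
        for (unsigned int pi = 0; pi < p_h; ++pi) {
          for (unsigned int pj = 0; pj < p_w; ++pj) {
            unsigned int idx = (j + pi) * width + (k + pj);
            if (in[idx] > max) {
              max = in[idx];
              arg = idx;
            }
          }
        }
        out.push_back(max);
        max_idx.push_back(arg); // consumed by the backward pass
      }
    }
  }

The patch stores the same information flattened across batch and channel,
offsetting by base_idx = batch * output_dim.getFeatureLen() plus the channel
and position offsets. One detail worth flagging: the patch initializes max
with std::numeric_limits<float>::min(), which is the smallest *positive*
float, so an all-negative window would never update max_idx;
std::numeric_limits<float>::lowest() (used in the sketch) is the safe
initializer.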
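
Given those indices, backwarding for max pooling reduces to a scatter-add:
each output gradient flows in full to its recorded argmax, and gradients from
overlapping windows accumulate, exactly as the patch's loop over
derivative.getDim().getDataLen() does. A sketch under the same assumptions
(plain vectors, hypothetical helper name):

  #include <cstddef>
  #include <vector>

  // Scatter-add backward for max pooling: each output gradient goes entirely
  // to the input element that won the forward max; overlapping windows
  // accumulate via +=, as in the patch.
  std::vector<float> maxPoolBackward(const std::vector<float> &derivative,
                                     const std::vector<unsigned int> &max_idx,
                                     std::size_t input_len) {
    std::vector<float> grad(input_len, 0.0f); // dL/dInput, zero elsewhere
    for (std::size_t o = 0; o < derivative.size(); ++o)
      grad[max_idx[o]] += derivative[o];
    return grad;
  }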
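
For average pooling no indices are needed: each output gradient is split
evenly over its window, so every input cell receives
derivative / (p_height * p_width) from each window that covers it. A
single-channel sketch (again with hypothetical, non-nntrainer names):

  #include <vector>

  // Backward pass of single-channel average pooling: every input cell in a
  // window receives an equal 1/(p_h*p_w) share of that window's output
  // gradient; overlapping windows accumulate.
  std::vector<float> avgPoolBackward(const std::vector<float> &derivative,
                                     unsigned int height, unsigned int width,
                                     unsigned int p_h, unsigned int p_w,
                                     unsigned int s_h, unsigned int s_w) {
    std::vector<float> grad(height * width, 0.0f);
    unsigned int out_w = (width - p_w) / s_w + 1; // output row length
    unsigned int J = 0;
    for (unsigned int j = 0; j + p_h <= height; j += s_h, ++J) {
      unsigned int K = 0;
      for (unsigned int k = 0; k + p_w <= width; k += s_w, ++K) {
        float del = derivative[J * out_w + K] / static_cast<float>(p_h * p_w);
        for (unsigned int pi = 0; pi < p_h; ++pi)
          for (unsigned int pj = 0; pj < p_w; ++pj)
            grad[(j + pi) * width + (k + pj)] += del;
      }
    }
    return grad;
  }

This also explains the golden values in backwarding_02_p: with a 5x5 input,
2x2 pooling, stride 1, and an all-ones derivative, each input cell gets 0.25
per covering window, and corner, edge, and interior cells are covered by 1, 2,
and 4 windows respectively, giving the 0.25 / 0.5 / 1.0 pattern in the array.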