From 2f05b03371e5936a478c7ad2946d0cd7c013920c Mon Sep 17 00:00:00 2001
From: Dmytro Mishkin
Date: Wed, 25 Feb 2015 17:00:22 +0200
Subject: [PATCH] Added batch normalization layer with test and examples

---
 .../cifar10/cifar10_full_sigmoid_solver.prototxt  |  28 ++
 .../cifar10_full_sigmoid_solver_bn.prototxt       |  28 ++
 .../cifar10_full_sigmoid_train_test.prototxt      | 212 +++++++++++++
 .../cifar10_full_sigmoid_train_test_bn.prototxt   | 284 +++++++++++++++++
 examples/cifar10/train_full_sigmoid.sh            |   7 +
 examples/cifar10/train_full_sigmoid_bn.sh         |   7 +
 include/caffe/common_layers.hpp                   |  50 ++-
 src/caffe/layers/batch_norm_layer.cpp             | 351 +++++++++++++++++++++
 src/caffe/layers/batch_norm_layer.cu              | 228 +++++++++++++
 src/caffe/test/test_batch_norm_layer.cpp          |  90 ++++++
 10 files changed, 1284 insertions(+), 1 deletion(-)
 create mode 100644 examples/cifar10/cifar10_full_sigmoid_solver.prototxt
 create mode 100644 examples/cifar10/cifar10_full_sigmoid_solver_bn.prototxt
 create mode 100644 examples/cifar10/cifar10_full_sigmoid_train_test.prototxt
 create mode 100644 examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt
 create mode 100755 examples/cifar10/train_full_sigmoid.sh
 create mode 100755 examples/cifar10/train_full_sigmoid_bn.sh
 create mode 100644 src/caffe/layers/batch_norm_layer.cpp
 create mode 100644 src/caffe/layers/batch_norm_layer.cu
 create mode 100644 src/caffe/test/test_batch_norm_layer.cpp

diff --git a/examples/cifar10/cifar10_full_sigmoid_solver.prototxt b/examples/cifar10/cifar10_full_sigmoid_solver.prototxt
new file mode 100644
index 0000000..7dd3ecb
--- /dev/null
+++ b/examples/cifar10/cifar10_full_sigmoid_solver.prototxt
@@ -0,0 +1,28 @@
+# reduce learning rate after 120 epochs (60000 iters) by factor of 10
+# then another factor of 10 after 10 more epochs (5000 iters)
+
+# The train/test net protocol buffer definition
+net: "examples/cifar10/cifar10_full_sigmoid_train_test.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of CIFAR10, we have test batch size 1000 and 10 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 10
+# Carry out testing every 1000 training iterations.
+test_interval: 1000
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.001
+momentum: 0.9
+#weight_decay: 0.004
+# The learning rate policy
+lr_policy: "step"
+gamma: 1
+stepsize: 5000
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 60000
+# snapshot intermediate results
+snapshot: 10000
+snapshot_prefix: "examples/cifar10_full_sigmoid"
+# solver mode: CPU or GPU
+solver_mode: GPU
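A note on the schedule above: Caffe's "step" policy resolves to lr = base_lr * gamma^floor(iter / stepsize), so with gamma: 1 the learning rate never actually decays and the header comment describes the intended schedule rather than the one these exact values produce (gamma: 0.1 would give the tenfold drops). A standalone illustration, not part of the patch:

    #include <cmath>
    #include <cstdio>

    // Caffe's "step" policy: lr = base_lr * gamma^floor(iter / stepsize).
    // With the settings above (gamma = 1) the rate stays at base_lr for all
    // 60000 iterations; gamma = 0.1 would reduce it tenfold every 5000 iters.
    int main() {
      const double base_lr = 0.001, gamma = 1.0;
      const int stepsize = 5000;
      for (int iter = 0; iter <= 60000; iter += 20000) {
        // integer division acts as floor()
        double lr = base_lr * std::pow(gamma, iter / stepsize);
        std::printf("iter %6d  lr %g\n", iter, lr);
      }
      return 0;
    }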
diff --git a/examples/cifar10/cifar10_full_sigmoid_solver_bn.prototxt b/examples/cifar10/cifar10_full_sigmoid_solver_bn.prototxt
new file mode 100644
index 0000000..a57b280
--- /dev/null
+++ b/examples/cifar10/cifar10_full_sigmoid_solver_bn.prototxt
@@ -0,0 +1,28 @@
+# reduce learning rate after 120 epochs (60000 iters) by factor of 10
+# then another factor of 10 after 10 more epochs (5000 iters)
+
+# The train/test net protocol buffer definition
+net: "examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt"
+# test_iter specifies how many forward passes the test should carry out.
+# In the case of CIFAR10, we have test batch size 1000 and 10 test iterations,
+# covering the full 10,000 testing images.
+test_iter: 10
+# Carry out testing every 1000 training iterations.
+test_interval: 1000
+# The base learning rate, momentum and the weight decay of the network.
+base_lr: 0.001
+momentum: 0.9
+#weight_decay: 0.004
+# The learning rate policy
+lr_policy: "step"
+gamma: 1
+stepsize: 5000
+# Display every 100 iterations
+display: 100
+# The maximum number of iterations
+max_iter: 60000
+# snapshot intermediate results
+snapshot: 10000
+snapshot_prefix: "examples/cifar10_full_sigmoid_bn"
+# solver mode: CPU or GPU
+solver_mode: GPU

diff --git a/examples/cifar10/cifar10_full_sigmoid_train_test.prototxt b/examples/cifar10/cifar10_full_sigmoid_train_test.prototxt
new file mode 100644
index 0000000..6f5bf26
--- /dev/null
+++ b/examples/cifar10/cifar10_full_sigmoid_train_test.prototxt
@@ -0,0 +1,212 @@
+name: "CIFAR10_full"
+layer {
+  name: "cifar"
+  type: "Data"
+  top: "data"
+  top: "label"
+  include {
+    phase: TRAIN
+  }
+  transform_param {
+    mean_file: "examples/cifar10/mean.binaryproto"
+  }
+  data_param {
+    source: "examples/cifar10/cifar10_train_lmdb"
+    batch_size: 111
+    backend: LMDB
+  }
+}
+layer {
+  name: "cifar"
+  type: "Data"
+  top: "data"
+  top: "label"
+  include {
+    phase: TEST
+  }
+  transform_param {
+    mean_file: "examples/cifar10/mean.binaryproto"
+  }
+  data_param {
+    source: "examples/cifar10/cifar10_test_lmdb"
+    batch_size: 1000
+    backend: LMDB
+  }
+}
+layer {
+  name: "conv1"
+  type: "Convolution"
+  bottom: "data"
+  top: "conv1"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 32
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.0001
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "pool1"
+  type: "Pooling"
+  bottom: "conv1"
+  top: "pool1"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+  }
+}
+
+layer {
+  name: "Sigmoid1"
+  type: "Sigmoid"
+  bottom: "pool1"
+  top: "Sigmoid1"
+}
+
+layer {
+  name: "conv2"
+  type: "Convolution"
+  bottom: "Sigmoid1"
+  top: "conv2"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 32
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+
+layer {
+  name: "Sigmoid2"
+  type: "Sigmoid"
+  bottom: "conv2"
+  top: "Sigmoid2"
+}
+layer {
+  name: "pool2"
+  type: "Pooling"
+  bottom: "Sigmoid2"
+  top: "pool2"
+  pooling_param {
+    pool: AVE
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "conv3"
+  type: "Convolution"
+  bottom: "pool2"
+  top: "conv3"
+  convolution_param {
+    num_output: 64
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 1
+  }
+}
+
+layer {
+  name: "Sigmoid3"
+  type: "Sigmoid"
+  bottom: "conv3"
+  top: "Sigmoid3"
+}
+
+layer {
+  name: "pool3"
+  type: "Pooling"
+  bottom: "Sigmoid3"
+  top: "pool3"
+  pooling_param {
+    pool: AVE
+    kernel_size: 3
+    stride: 2
+  }
+}
+
+layer {
+  name: "ip1"
+  type: "InnerProduct"
+  bottom: "pool3"
+  top: "ip1"
+  param {
+    lr_mult: 1
+    decay_mult: 250
+  }
+  param {
+    lr_mult: 0.2
+    decay_mult: 0
+  }
+  inner_product_param {
+    num_output: 10
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "accuracy"
+  type: "Accuracy"
+  bottom: "ip1"
+  bottom: "label"
+  top: "accuracy"
+  include {
+    phase: TEST
+  }
+}
+layer {
+  name: "loss"
+  type: "SoftmaxWithLoss"
+  bottom: "ip1"
+  bottom: "label"
+  top: "loss"
+}
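Why this example pairs an all-sigmoid net with batch normalization: the sigmoid derivative is at most 0.25 and collapses quickly away from zero, so stacked sigmoid layers train poorly unless pre-activations are kept near zero mean and unit variance, which is exactly what the BN variant below provides. A standalone illustration, not part of the patch:

    #include <cmath>
    #include <cstdio>

    // sigmoid'(x) = s(x) * (1 - s(x)) peaks at 0.25 at x = 0 and decays
    // fast; normalized pre-activations keep it near its peak, which is
    // what the BN variant of this net exploits.
    int main() {
      for (double x = 0.0; x <= 8.0; x += 2.0) {
        double s = 1.0 / (1.0 + std::exp(-x));
        std::printf("x = %g  sigmoid'(x) = %.6f\n", x, s * (1.0 - s));
      }
      return 0;
    }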
diff --git a/examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt b/examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt
new file mode 100644
index 0000000..85c2dff
--- /dev/null
+++ b/examples/cifar10/cifar10_full_sigmoid_train_test_bn.prototxt
@@ -0,0 +1,284 @@
+name: "CIFAR10_full"
+layer {
+  name: "cifar"
+  type: "Data"
+  top: "data"
+  top: "label"
+  include {
+    phase: TRAIN
+  }
+  transform_param {
+    mean_file: "examples/cifar10/mean.binaryproto"
+  }
+  data_param {
+    source: "examples/cifar10/cifar10_train_lmdb"
+    batch_size: 111
+    backend: LMDB
+  }
+}
+layer {
+  name: "cifar"
+  type: "Data"
+  top: "data"
+  top: "label"
+  include {
+    phase: TEST
+  }
+  transform_param {
+    mean_file: "examples/cifar10/mean.binaryproto"
+  }
+  data_param {
+    source: "examples/cifar10/cifar10_test_lmdb"
+    batch_size: 1000
+    backend: LMDB
+  }
+}
+layer {
+  name: "conv1"
+  type: "Convolution"
+  bottom: "data"
+  top: "conv1"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 32
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.0001
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+layer {
+  name: "pool1"
+  type: "Pooling"
+  bottom: "conv1"
+  top: "pool1"
+  pooling_param {
+    pool: MAX
+    kernel_size: 3
+    stride: 2
+  }
+}
+
+layer {
+  name: "bn1"
+  type: "BatchNorm"
+  bottom: "pool1"
+  top: "bn1"
+  bn_param {
+    scale_filler {
+      type: "constant"
+      value: 1
+    }
+    shift_filler {
+      type: "constant"
+      value: 0.001
+    }
+  }
+  param {
+    lr_mult: 1.00001
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 1.00001
+    decay_mult: 0
+  }
+}
+
+layer {
+  name: "Sigmoid1"
+  type: "Sigmoid"
+  bottom: "bn1"
+  top: "Sigmoid1"
+}
+
+layer {
+  name: "conv2"
+  type: "Convolution"
+  bottom: "Sigmoid1"
+  top: "conv2"
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 2
+  }
+  convolution_param {
+    num_output: 32
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
+
+layer {
+  name: "bn2"
+  type: "BatchNorm"
+  bottom: "conv2"
+  top: "bn2"
+  bn_param {
+    scale_filler {
+      type: "constant"
+      value: 1
+    }
+    shift_filler {
+      type: "constant"
+      value: 0.001
+    }
+  }
+  param {
+    lr_mult: 1.00001
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 1.00001
+    decay_mult: 0
+  }
+}
+layer {
+  name: "Sigmoid2"
+  type: "Sigmoid"
+  bottom: "bn2"
+  top: "Sigmoid2"
+}
+layer {
+  name: "pool2"
+  type: "Pooling"
+  bottom: "Sigmoid2"
+  top: "pool2"
+  pooling_param {
+    pool: AVE
+    kernel_size: 3
+    stride: 2
+  }
+}
+layer {
+  name: "conv3"
+  type: "Convolution"
+  bottom: "pool2"
+  top: "conv3"
+  convolution_param {
+    num_output: 64
+    pad: 2
+    kernel_size: 5
+    stride: 1
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+  param {
+    lr_mult: 1
+  }
+  param {
+    lr_mult: 1
+  }
+}
+
+layer {
+  name: "bn3"
+  type: "BatchNorm"
+  bottom: "conv3"
+  top: "bn3"
+  bn_param {
+    scale_filler {
+      type: "constant"
+      value: 1
+    }
+    shift_filler {
+      type: "constant"
+      value: 0.001
+    }
+  }
+  param {
+    lr_mult: 1.00001
+    decay_mult: 0
+  }
+  param {
+    lr_mult: 1.00001
+    decay_mult: 0
+  }
+}
+layer {
+  name: "Sigmoid3"
+  type: "Sigmoid"
+  bottom: "bn3"
+  top: "Sigmoid3"
+}
+layer {
+  name: "pool3"
+  type: "Pooling"
+  bottom: "Sigmoid3"
+  top: "pool3"
+  pooling_param {
+    pool: AVE
+    kernel_size: 3
+    stride: 2
+  }
+}
+
+layer {
+  name: "ip1"
+  type: "InnerProduct"
+  bottom: "pool3"
+  top: "ip1"
+  param {
+    lr_mult: 1
+    decay_mult: 250
+  }
+  param {
+    lr_mult: 0.2
+    decay_mult: 0
+  }
+  inner_product_param {
+    num_output: 10
+    weight_filler {
+      type: "gaussian"
+      std: 0.01
+    }
+    bias_filler {
+      type: "constant"
+    }
+  }
+}
name: "accuracy" + type: "Accuracy" + bottom: "ip1" + bottom: "label" + top: "accuracy" + include { + phase: TEST + } +} +layer { + name: "loss" + type: "SoftmaxWithLoss" + bottom: "ip1" + bottom: "label" + top: "loss" +} diff --git a/examples/cifar10/train_full_sigmoid.sh b/examples/cifar10/train_full_sigmoid.sh new file mode 100755 index 0000000..9cff06d --- /dev/null +++ b/examples/cifar10/train_full_sigmoid.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env sh + +TOOLS=./build/tools + +$TOOLS/caffe train \ + --solver=examples/cifar10/cifar10_full_sigmoid_solver.prototxt + diff --git a/examples/cifar10/train_full_sigmoid_bn.sh b/examples/cifar10/train_full_sigmoid_bn.sh new file mode 100755 index 0000000..011387c --- /dev/null +++ b/examples/cifar10/train_full_sigmoid_bn.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env sh + +TOOLS=./build/tools + +$TOOLS/caffe train \ + --solver=examples/cifar10/cifar10_full_sigmoid_solver_bn.prototxt + diff --git a/include/caffe/common_layers.hpp b/include/caffe/common_layers.hpp index 21a27d7..09605db 100644 --- a/include/caffe/common_layers.hpp +++ b/include/caffe/common_layers.hpp @@ -79,6 +79,55 @@ class ArgMaxLayer : public Layer { }; /** +* @brief Batch Normalization per-channel with scale & shift linear transform. +* +*/ +template +class BatchNormLayer : public Layer { + public: + explicit BatchNormLayer(const LayerParameter& param) + : Layer(param) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "BN"; } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + // spatial mean & variance + Blob spatial_mean_, spatial_variance_; + // batch mean & variance + Blob batch_mean_, batch_variance_; + // buffer blob + Blob buffer_blob_; + + Blob x_norm_; + // x_sum_multiplier is used to carry out sum using BLAS + Blob spatial_sum_multiplier_, batch_sum_multiplier_; + + // dimension + int N_; + int C_; + int H_; + int W_; + // eps + Dtype var_eps_; +}; + +/** * @brief Index into the input blob along its first axis. * * This layer can be used to select, reorder, and even replicate examples in a @@ -146,7 +195,6 @@ class BatchReindexLayer : public Layer { const Dtype* ridx_data); }; - /** * @brief Takes at least two Blob%s and concatenates them along either the num * or channel dimension, outputting the result. 
diff --git a/src/caffe/layers/batch_norm_layer.cpp b/src/caffe/layers/batch_norm_layer.cpp
new file mode 100644
index 0000000..8dea349
--- /dev/null
+++ b/src/caffe/layers/batch_norm_layer.cpp
@@ -0,0 +1,351 @@
+#include <algorithm>
+#include <vector>
+
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  x_norm_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+
+  // Figure out the dimensions
+  N_ = bottom[0]->num();
+  C_ = bottom[0]->channels();
+  H_ = bottom[0]->height();
+  W_ = bottom[0]->width();
+
+  // mean
+  spatial_mean_.Reshape(N_, C_, 1, 1);
+  batch_mean_.Reshape(1, C_, 1, 1);
+  // variance
+  spatial_variance_.Reshape(N_, C_, 1, 1);
+  batch_variance_.Reshape(1, C_, 1, 1);
+  // buffer blob
+  buffer_blob_.Reshape(N_, C_, H_, W_);
+
+  // fill spatial multiplier
+  spatial_sum_multiplier_.Reshape(1, 1, H_, W_);
+  Dtype* spatial_multipl_data = spatial_sum_multiplier_.mutable_cpu_data();
+  caffe_set(spatial_sum_multiplier_.count(), Dtype(1),
+      spatial_multipl_data);
+  caffe_set(spatial_sum_multiplier_.count(), Dtype(0),
+      spatial_sum_multiplier_.mutable_cpu_diff());
+  // fill batch multiplier
+  batch_sum_multiplier_.Reshape(N_, 1, 1, 1);
+  Dtype* batch_multiplier_data = batch_sum_multiplier_.mutable_cpu_data();
+  caffe_set(batch_sum_multiplier_.count(), Dtype(1),
+      batch_multiplier_data);
+  caffe_set(batch_sum_multiplier_.count(), Dtype(0),
+      batch_sum_multiplier_.mutable_cpu_diff());
+  this->param_propagate_down_.resize(this->blobs_.size(), true);
+}
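The two `*_sum_multiplier_` blobs are all-ones vectors: multiplying a matrix by a ones vector through gemv sums each row, and scaling by 1/K turns the sum into a mean. That is how spatial_mean_ (per n, c over H*W) and batch_mean_ (per c over N) are produced without custom reduction kernels. A standalone sketch of the idiom, with plain loops in place of caffe_cpu_gemv:

    #include <cstdio>

    // Row means of an (M x K) matrix via "multiply by a ones vector":
    // y[m] = (1/K) * sum_k A[m][k] * 1.
    void row_means(const float* A, int M, int K, float* y) {
      for (int m = 0; m < M; ++m) {
        float acc = 0.f;
        for (int k = 0; k < K; ++k) acc += A[m * K + k] * 1.f;  // ones vector
        y[m] = acc / K;
      }
    }

    int main() {
      const float A[2 * 3] = {1, 2, 3, 4, 5, 6};
      float y[2];
      row_means(A, 2, 3, y);
      std::printf("%g %g\n", y[0], y[1]);  // prints: 2 5
      return 0;
    }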
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  CHECK_NE(top[0], bottom[0]) << this->type() << " Layer does not "
+      "allow in-place computation.";
+
+  top[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  x_norm_.Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+
+  // Figure out the dimensions
+  N_ = bottom[0]->num();
+  C_ = bottom[0]->channels();
+  H_ = bottom[0]->height();
+  W_ = bottom[0]->width();
+  var_eps_ = 1e-9;
+
+  // mean
+  spatial_mean_.Reshape(N_, C_, 1, 1);
+  batch_mean_.Reshape(1, C_, 1, 1);
+  // variance
+  spatial_variance_.Reshape(N_, C_, 1, 1);
+  batch_variance_.Reshape(1, C_, 1, 1);
+  // buffer blob
+  buffer_blob_.Reshape(N_, C_, H_, W_);
+
+  // fill spatial multiplier
+  spatial_sum_multiplier_.Reshape(1, 1, H_, W_);
+  Dtype* spatial_multipl_data = spatial_sum_multiplier_.mutable_cpu_data();
+  caffe_set(spatial_sum_multiplier_.count(), Dtype(1),
+      spatial_multipl_data);
+  caffe_set(spatial_sum_multiplier_.count(), Dtype(0),
+      spatial_sum_multiplier_.mutable_cpu_diff());
+
+  // fill batch multiplier
+  batch_sum_multiplier_.Reshape(N_, 1, 1, 1);
+  Dtype* batch_multiplier_data = batch_sum_multiplier_.mutable_cpu_data();
+  caffe_set(batch_sum_multiplier_.count(), Dtype(1),
+      batch_multiplier_data);
+  caffe_set(batch_sum_multiplier_.count(), Dtype(0),
+      batch_sum_multiplier_.mutable_cpu_diff());
+
+  // Check if we need to set up the weights
+  if (this->blobs_.size() > 0) {
+    LOG(INFO) << "Skipping parameter initialization";
+  } else {
+    this->blobs_.resize(2);
+    // initialize the per-channel scale (gamma) to 1
+    this->blobs_[0].reset(new Blob<Dtype>(1, C_, 1, 1));
+    caffe_set(this->blobs_[0]->count(), Dtype(1),
+        this->blobs_[0]->mutable_cpu_data());
+    // initialize the per-channel shift (beta) to 0
+    this->blobs_[1].reset(new Blob<Dtype>(1, C_, 1, 1));
+    caffe_set(this->blobs_[1]->count(), Dtype(0),
+        this->blobs_[1]->mutable_cpu_data());
+  }  // parameter initialization
+  this->param_propagate_down_.resize(this->blobs_.size(), true);
+}
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  const Dtype* const_top_data = top[0]->cpu_data();
+
+  const Dtype* scale_data = this->blobs_[0]->cpu_data();
+  const Dtype* shift_data = this->blobs_[1]->cpu_data();
+
+  // put the squares of bottom into buffer_blob_
+  caffe_powx(bottom[0]->count(), bottom_data, Dtype(2),
+      buffer_blob_.mutable_cpu_data());
+
+  // computes variance using var(X) = E(X^2) - (EX)^2
+  // EX across spatial
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
+      Dtype(1. / (H_ * W_)), bottom_data,
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      spatial_mean_.mutable_cpu_data());
+  // EX across batch
+  caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1. / N_),
+      spatial_mean_.cpu_data(),
+      batch_sum_multiplier_.cpu_data(), Dtype(0),
+      batch_mean_.mutable_cpu_data());
+
+  // E(X^2) across spatial
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
+      Dtype(1. / (H_ * W_)), buffer_blob_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      spatial_variance_.mutable_cpu_data());
+  // E(X^2) across batch
+  caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1. / N_),
+      spatial_variance_.cpu_data(),
+      batch_sum_multiplier_.cpu_data(), Dtype(0),
+      batch_variance_.mutable_cpu_data());
+
+  caffe_powx(batch_mean_.count(), batch_mean_.cpu_data(), Dtype(2),
+      buffer_blob_.mutable_cpu_data());  // (EX)^2
+  caffe_sub(batch_mean_.count(), batch_variance_.cpu_data(),
+      buffer_blob_.cpu_data(),
+      batch_variance_.mutable_cpu_data());  // variance
+
+  // do mean and variance normalization
+  // subtract mean
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_,
+      C_, 1, Dtype(1),
+      batch_sum_multiplier_.cpu_data(),
+      batch_mean_.cpu_data(), Dtype(0),
+      spatial_mean_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
+      H_ * W_, 1, Dtype(-1),
+      spatial_mean_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      buffer_blob_.mutable_cpu_data());
+  caffe_add(buffer_blob_.count(), bottom_data,
+      buffer_blob_.cpu_data(), top_data);
+
+  // normalize variance
+  caffe_add_scalar(batch_variance_.count(), var_eps_,
+      batch_variance_.mutable_cpu_data());
+  caffe_powx(batch_variance_.count(),
+      batch_variance_.cpu_data(), Dtype(0.5),
+      batch_variance_.mutable_cpu_data());
+
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_,
+      C_, 1, Dtype(1),
+      batch_sum_multiplier_.cpu_data(),
+      batch_variance_.cpu_data(), Dtype(0),
+      spatial_variance_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+      N_ * C_, H_ * W_, 1, Dtype(1),
+      spatial_variance_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      buffer_blob_.mutable_cpu_data());
+
+  caffe_div(buffer_blob_.count(), const_top_data,
+      buffer_blob_.cpu_data(), top_data);
+
+  // Saving x_norm
+  caffe_copy(buffer_blob_.count(), const_top_data,
+      x_norm_.mutable_cpu_data());
+  // scale
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.cpu_data(), scale_data, Dtype(0),
+      spatial_variance_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
+      H_ * W_, 1, Dtype(1),
+      spatial_variance_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      buffer_blob_.mutable_cpu_data());
+  caffe_mul(buffer_blob_.count(), top_data,
+      buffer_blob_.cpu_data(), top_data);
+
+  // shift
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.cpu_data(), shift_data, Dtype(0),
+      spatial_mean_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+      N_ * C_, H_ * W_, 1, Dtype(1),
+      spatial_mean_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      buffer_blob_.mutable_cpu_data());
+  caffe_add(buffer_blob_.count(), const_top_data,
+      buffer_blob_.cpu_data(), top_data);
+}
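One caveat about the var(X) = E(X^2) - (EX)^2 shortcut used above: it needs only a single pass over the data, but in float it can cancel catastrophically when the mean is large relative to the spread. Mean-subtracted activations (the common case in this net) are safe; raw inputs far from zero are not. A standalone illustration:

    #include <cstdio>

    // One-pass var(X) = E(X^2) - (EX)^2 vs. the two-pass formula, in float.
    // Squares near 1e8 have ~8 ulp of rounding each, which swamps the true
    // variance of 1.25 for data clustered around 1e4.
    int main() {
      const float x[4] = {10000.f, 10001.f, 10002.f, 10003.f};
      float sum = 0.f, sq = 0.f;
      for (int i = 0; i < 4; ++i) { sum += x[i]; sq += x[i] * x[i]; }
      const float mean = sum / 4.f;
      const float one_pass = sq / 4.f - mean * mean;  // cancellation-prone
      float two_pass = 0.f;
      for (int i = 0; i < 4; ++i) two_pass += (x[i] - mean) * (x[i] - mean);
      two_pass /= 4.f;
      std::printf("one-pass %g  two-pass %g  (exact 1.25)\n",
                  one_pass, two_pass);
      return 0;
    }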
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+
+  Dtype* scale_diff = this->blobs_[0]->mutable_cpu_diff();
+  Dtype* shift_diff = this->blobs_[1]->mutable_cpu_diff();
+  const Dtype* scale_data = this->blobs_[0]->cpu_data();
+
+  // Propagate to layer parameters
+  // gradient w.r.t. scale
+  caffe_mul(buffer_blob_.count(), x_norm_.cpu_data(),
+      top_diff, buffer_blob_.mutable_cpu_data());
+  // EX across spatial
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_,
+      H_ * W_, Dtype(1), buffer_blob_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      spatial_variance_.mutable_cpu_diff());
+  // EX across batch
+  caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
+      spatial_variance_.cpu_diff(),
+      batch_sum_multiplier_.cpu_data(), Dtype(0), scale_diff);
+
+  // gradient w.r.t. shift
+  // EX across spatial
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_,
+      H_ * W_, Dtype(1), top_diff,
+      spatial_sum_multiplier_.cpu_data(),
+      Dtype(0), spatial_mean_.mutable_cpu_diff());
+  // EX across batch
+  caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_,
+      Dtype(1), spatial_mean_.cpu_diff(),
+      batch_sum_multiplier_.cpu_data(),
+      Dtype(0), shift_diff);
+
+  // Propagate down
+  // put scale * top_diff to buffer_blob_
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.cpu_data(), scale_data, Dtype(0),
+      spatial_variance_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
+      H_ * W_, 1, Dtype(1),
+      spatial_variance_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      buffer_blob_.mutable_cpu_data());
+  caffe_mul(buffer_blob_.count(), top_diff, buffer_blob_.cpu_data(),
+      buffer_blob_.mutable_cpu_data());
+
+  // use new top diff for computation
+  caffe_mul(buffer_blob_.count(), x_norm_.cpu_data(),
+      buffer_blob_.cpu_data(), bottom_diff);
+  // EX across spatial
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
+      Dtype(1), bottom_diff,
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      spatial_mean_.mutable_cpu_data());
+  // EX across batch
+  caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
+      spatial_mean_.cpu_data(),
+      batch_sum_multiplier_.cpu_data(), Dtype(0),
+      batch_mean_.mutable_cpu_data());
+
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+      N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.cpu_data(),
+      batch_mean_.cpu_data(), Dtype(0),
+      spatial_mean_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
+      H_ * W_, 1, Dtype(1),
+      spatial_mean_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      bottom_diff);
+
+  caffe_mul(buffer_blob_.count(),
+      x_norm_.cpu_data(), bottom_diff, bottom_diff);
+
+  // EX across spatial
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, N_ * C_,
+      H_ * W_, Dtype(1), buffer_blob_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      spatial_mean_.mutable_cpu_data());
+  // EX across batch
+  caffe_cpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
+      spatial_mean_.cpu_data(),
+      batch_sum_multiplier_.cpu_data(), Dtype(0),
+      batch_mean_.mutable_cpu_data());
+
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+      N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.cpu_data(),
+      batch_mean_.cpu_data(), Dtype(0),
+      spatial_mean_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+      N_ * C_, H_ * W_, 1, Dtype(1),
+      spatial_mean_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(1), bottom_diff);
+
+  caffe_cpu_axpby(buffer_blob_.count(), Dtype(1),
+      buffer_blob_.cpu_data(), Dtype(-1. / (N_ * H_ * W_)),
+      bottom_diff);
+
+  // divide by the broadcast standard deviation saved in the forward pass
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+      N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.cpu_data(),
+      batch_variance_.cpu_data(), Dtype(0),
+      spatial_variance_.mutable_cpu_data());
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans,
+      N_ * C_, H_ * W_, 1, Dtype(1),
+      spatial_variance_.cpu_data(),
+      spatial_sum_multiplier_.cpu_data(), Dtype(0),
+      buffer_blob_.mutable_cpu_data());
+
+  caffe_div(buffer_blob_.count(), bottom_diff,
+      buffer_blob_.cpu_data(), bottom_diff);
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(BatchNormLayer);
+#endif
+
+INSTANTIATE_CLASS(BatchNormLayer);
+REGISTER_LAYER_CLASS(BatchNorm);
+}  // namespace caffe
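The chain of gemv/gemm calls in Backward_cpu folds the sums of the standard batch-norm gradient. Written out, with m = N * H * W elements per channel and x-hat the normalized activations saved in x_norm_:

    \frac{\partial \ell}{\partial \gamma_c} = \sum_{i \in c} \frac{\partial \ell}{\partial y_i}\,\hat{x}_i,
    \qquad
    \frac{\partial \ell}{\partial \beta_c} = \sum_{i \in c} \frac{\partial \ell}{\partial y_i},

    \frac{\partial \ell}{\partial x_i} =
      \frac{1}{\sqrt{\sigma_c^2 + \varepsilon}}
      \left( \gamma_c \frac{\partial \ell}{\partial y_i}
        - \frac{1}{m} \sum_{j \in c} \gamma_c \frac{\partial \ell}{\partial y_j}
        - \frac{\hat{x}_i}{m} \sum_{j \in c} \gamma_c \frac{\partial \ell}{\partial y_j}\,\hat{x}_j
      \right)

In the code, buffer_blob_ holds gamma * dy, the gemv pairs compute the two per-channel sums, the caffe_cpu_axpby supplies the -1/m weighting, and the final caffe_div applies the 1/sqrt(sigma^2 + eps) factor (batch_variance_ already holds the square root from the forward pass).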

diff --git a/src/caffe/layers/batch_norm_layer.cu b/src/caffe/layers/batch_norm_layer.cu
new file mode 100644
index 0000000..e87f8c6
--- /dev/null
+++ b/src/caffe/layers/batch_norm_layer.cu
@@ -0,0 +1,228 @@
+#include <algorithm>
+#include <vector>
+
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  const Dtype* const_top_data = top[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  Dtype* spatial_mean_data = spatial_mean_.mutable_gpu_data();
+  Dtype* buffer_data = buffer_blob_.mutable_gpu_data();
+  const Dtype* const_buffer_data = buffer_blob_.gpu_data();
+
+  // put the squares of bottom into buffer_blob_
+  caffe_gpu_powx(bottom[0]->count(), bottom_data, Dtype(2),
+      buffer_blob_.mutable_gpu_data());
+
+  // computes variance using var(X) = E(X^2) - (EX)^2
+  // EX across spatial
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
+      Dtype(1. / (H_ * W_)),
+      bottom_data, spatial_sum_multiplier_.gpu_data(),
+      Dtype(0), spatial_mean_data);
+  // EX across batch
+  caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1. / N_),
+      spatial_mean_.gpu_data(),
+      batch_sum_multiplier_.gpu_data(), Dtype(0),
+      batch_mean_.mutable_gpu_data());
+
+  // E(X^2) across spatial
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
+      Dtype(1. / (H_ * W_)), buffer_data,
+      spatial_sum_multiplier_.gpu_data(), Dtype(0),
+      spatial_variance_.mutable_gpu_data());
+  // E(X^2) across batch
+  caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1. / N_),
+      spatial_variance_.gpu_data(),
+      batch_sum_multiplier_.gpu_data(), Dtype(0),
+      batch_variance_.mutable_gpu_data());
+
+  caffe_gpu_powx(batch_mean_.count(), batch_mean_.gpu_data(),
+      Dtype(2), buffer_blob_.mutable_gpu_data());  // (EX)^2
+  caffe_gpu_sub(batch_mean_.count(), batch_variance_.gpu_data(),
+      buffer_data, batch_variance_.mutable_gpu_data());  // variance
+
+  // do mean and variance normalization
+  // subtract mean
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.gpu_data(), batch_mean_.gpu_data(), Dtype(0),
+      spatial_mean_data);
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_, H_ * W_,
+      1, -Dtype(1),
+      spatial_mean_.gpu_data(), spatial_sum_multiplier_.gpu_data(), Dtype(0),
+      buffer_blob_.mutable_gpu_data());
+
+  caffe_gpu_add(buffer_blob_.count(), bottom_data, buffer_data, top_data);
+
+  // normalize variance
+  caffe_gpu_add_scalar(batch_variance_.count(), var_eps_,
+      batch_variance_.mutable_gpu_data());
+  caffe_gpu_powx(batch_variance_.count(), batch_variance_.gpu_data(),
+      Dtype(0.5), batch_variance_.mutable_gpu_data());
+
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.gpu_data(), batch_variance_.gpu_data(), Dtype(0),
+      spatial_variance_.mutable_gpu_data());
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
+      H_ * W_, 1, Dtype(1),
+      spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(),
+      Dtype(0), buffer_blob_.mutable_gpu_data());
+
+  caffe_gpu_div(buffer_blob_.count(), top_data, buffer_data, top_data);
+
+  // Saving x_norm
+  caffe_copy(top[0]->count(), const_top_data, x_norm_.mutable_gpu_data());
+
+  // scale
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.gpu_data(), this->blobs_[0]->gpu_data(),
+      Dtype(0), spatial_variance_.mutable_gpu_data());
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
+      H_ * W_, 1, Dtype(1),
+      spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(),
+      Dtype(0), buffer_blob_.mutable_gpu_data());
+
+  caffe_gpu_mul(buffer_blob_.count(), top_data, buffer_data, top_data);
+
+  // shift
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.gpu_data(),
+      this->blobs_[1]->gpu_data(), Dtype(0),
+      spatial_mean_data);
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_, H_ * W_, 1,
+      Dtype(1),
+      spatial_mean_.gpu_data(), spatial_sum_multiplier_.gpu_data(), Dtype(0),
+      buffer_blob_.mutable_gpu_data());
+  caffe_gpu_add(buffer_blob_.count(), top_data, buffer_data, top_data);
+}
+
+template <typename Dtype>
+void BatchNormLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+  const Dtype* const_bottom_diff = bottom[0]->gpu_diff();
+  Dtype* spatial_mean_data = spatial_mean_.mutable_gpu_data();
+  Dtype* buffer_data = buffer_blob_.mutable_gpu_data();
+  const Dtype* const_buffer_data = buffer_blob_.gpu_data();
+
+  // Propagate to layer parameters
+  // gradient w.r.t. scale
+  caffe_gpu_mul(buffer_blob_.count(), x_norm_.gpu_data(),
+      top_diff, buffer_blob_.mutable_gpu_data());
+  // EX across spatial
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_, Dtype(1),
+      buffer_data, spatial_sum_multiplier_.gpu_data(), Dtype(0),
+      spatial_variance_.mutable_gpu_data());
+  // EX across batch
+  caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
+      spatial_variance_.gpu_data(),
+      batch_sum_multiplier_.gpu_data(), Dtype(0),
+      this->blobs_[0]->mutable_gpu_diff());
+
+  // gradient w.r.t. shift
+  // EX across spatial
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_, Dtype(1),
+      top_diff, spatial_sum_multiplier_.gpu_data(),
+      Dtype(0), spatial_mean_data);
+  // EX across batch
+  caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
+      spatial_mean_.gpu_data(),
+      batch_sum_multiplier_.gpu_data(), Dtype(0),
+      this->blobs_[1]->mutable_gpu_diff());
+
+  // Propagate down
+  // scale top diff
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.gpu_data(), this->blobs_[0]->gpu_data(),
+      Dtype(0), spatial_variance_.mutable_gpu_data());
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
+      H_ * W_, 1, Dtype(1),
+      spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(),
+      Dtype(0),
+      buffer_blob_.mutable_gpu_data());
+  caffe_gpu_mul(buffer_blob_.count(), top_diff, buffer_data,
+      buffer_blob_.mutable_gpu_data());
+
+  // use new top diff for computation
+  caffe_gpu_mul(buffer_blob_.count(), x_norm_.gpu_data(),
+      buffer_data, bottom_diff);
+  // EX across spatial
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_,
+      Dtype(1), bottom_diff,
+      spatial_sum_multiplier_.gpu_data(), Dtype(0), spatial_mean_data);
+  // EX across batch
+  caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
+      spatial_mean_.gpu_data(),
+      batch_sum_multiplier_.gpu_data(), Dtype(0),
+      batch_mean_.mutable_gpu_data());
+
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.gpu_data(),
+      batch_mean_.gpu_data(), Dtype(0),
+      spatial_mean_data);
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
+      H_ * W_, 1, Dtype(1), spatial_mean_.gpu_data(),
+      spatial_sum_multiplier_.gpu_data(), Dtype(0),
+      bottom_diff);
+
+  caffe_gpu_mul(buffer_blob_.count(), x_norm_.gpu_data(),
+      bottom_diff, bottom_diff);
+
+  // EX across spatial
+  caffe_gpu_gemv<Dtype>(CblasNoTrans, N_ * C_, H_ * W_, Dtype(1),
+      buffer_data, spatial_sum_multiplier_.gpu_data(),
+      Dtype(0), spatial_mean_data);
+  // EX across batch
+  caffe_gpu_gemv<Dtype>(CblasTrans, N_, C_, Dtype(1),
+      spatial_mean_.gpu_data(),
+      batch_sum_multiplier_.gpu_data(), Dtype(0),
+      batch_mean_.mutable_gpu_data());
+
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_,
+      C_, 1, Dtype(1),
+      batch_sum_multiplier_.gpu_data(),
+      batch_mean_.gpu_data(), Dtype(0),
+      spatial_mean_data);
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
+      H_ * W_, 1, Dtype(1),
+      spatial_mean_.gpu_data(), spatial_sum_multiplier_.gpu_data(),
+      Dtype(1),
+      bottom_diff);
+
+  caffe_gpu_axpby(buffer_blob_.count(), Dtype(1), buffer_data,
+      Dtype(-1. / (N_ * H_ * W_)),
+      bottom_diff);
+
+  // divide by the broadcast standard deviation saved in the forward pass
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_, C_, 1, Dtype(1),
+      batch_sum_multiplier_.gpu_data(), batch_variance_.gpu_data(), Dtype(0),
+      spatial_variance_.mutable_gpu_data());
+  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, N_ * C_,
+      H_ * W_, 1, Dtype(1),
+      spatial_variance_.gpu_data(), spatial_sum_multiplier_.gpu_data(),
+      Dtype(0),
+      buffer_blob_.mutable_gpu_data());
+
+  caffe_gpu_div(buffer_blob_.count(), const_bottom_diff,
+      const_buffer_data, bottom_diff);
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(BatchNormLayer);
+}  // namespace caffe

diff --git a/src/caffe/test/test_batch_norm_layer.cpp b/src/caffe/test/test_batch_norm_layer.cpp
new file mode 100644
index 0000000..704efd5
--- /dev/null
+++ b/src/caffe/test/test_batch_norm_layer.cpp
@@ -0,0 +1,90 @@
+#include <algorithm>
+#include <cstring>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/common_layers.hpp"
+#include "caffe/filler.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+#define BATCH_SIZE 2
+#define INPUT_DATA_SIZE 3
+
+namespace caffe {
+
+template <typename TypeParam>
+class BatchNormLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+ protected:
+  BatchNormLayerTest()
+      : blob_bottom_(new Blob<Dtype>(5, 2, 3, 4)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~BatchNormLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(BatchNormLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(BatchNormLayerTest, TestForward) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+
+  BatchNormLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  // Test per-channel mean and variance of the output
+  int num = this->blob_bottom_->num();
+  int channels = this->blob_bottom_->channels();
+  int height = this->blob_bottom_->height();
+  int width = this->blob_bottom_->width();
+
+  for (int j = 0; j < channels; ++j) {
+    Dtype sum = 0, var = 0;
+    for (int i = 0; i < num; ++i) {
+      for (int k = 0; k < height; ++k) {
+        for (int l = 0; l < width; ++l) {
+          Dtype data = this->blob_top_->data_at(i, j, k, l);
+          sum += data;
+          var += data * data;
+        }
+      }
+    }
+    sum /= height * width * num;
+    var /= height * width * num;
+
+    const Dtype kErrorBound = 0.001;
+    // expect zero mean
+    EXPECT_NEAR(0, sum, kErrorBound);
+    // expect unit variance (mean is ~0, so E(X^2) approximates the variance)
+    EXPECT_NEAR(1, var, kErrorBound);
+  }
+}
+
+TYPED_TEST(BatchNormLayerTest, TestGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+
+  BatchNormLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-4);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+}  // namespace caffe
-- 
2.7.4