};
/**
+ * @brief Computes @f$ y = log_{\gamma}(\alpha x + \beta) @f$,
+ * as specified by the scale @f$ \alpha @f$, shift @f$ \beta @f$,
+ * and base @f$ \gamma @f$.
+ */
+template <typename Dtype>
+class LogLayer : public NeuronLayer<Dtype> {
+ public:
+ /**
+ * @param param provides LogParameter log_param,
+ * with LogLayer options:
+ * - scale (\b optional, default 1) the scale @f$ \alpha @f$
+ * - shift (\b optional, default 0) the shift @f$ \beta @f$
+ * - base (\b optional, default -1 for a value of @f$ e \approx 2.718 @f$)
+ * the base @f$ \gamma @f$
+ */
+ explicit LogLayer(const LayerParameter& param)
+ : NeuronLayer<Dtype>(param) {}
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ virtual inline const char* type() const { return "Log"; }
+
+ protected:
+ /**
+ * @param bottom input Blob vector (length 1)
+ * -# @f$ (N \times C \times H \times W) @f$
+ * the inputs @f$ x @f$
+ * @param top output Blob vector (length 1)
+ * -# @f$ (N \times C \times H \times W) @f$
+ * the computed outputs @f$
+ * y = log_{\gamma}(\alpha x + \beta)
+ * @f$
+ */
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+
+ /**
+ * @brief Computes the error gradient w.r.t. the log inputs.
+ *
+ * @param top output Blob vector (length 1), providing the error gradient with
+ * respect to the outputs
+ * -# @f$ (N \times C \times H \times W) @f$
+ * containing error gradients @f$ \frac{\partial E}{\partial y} @f$
+ * with respect to computed outputs @f$ y @f$
+ * @param propagate_down see Layer::Backward.
+ * @param bottom input Blob vector (length 1)
+ * -# @f$ (N \times C \times H \times W) @f$
+ * the inputs @f$ x @f$; Backward fills their diff with
+ * gradients @f$
+ * \frac{\partial E}{\partial x} =
+ *            \frac{\partial E}{\partial y}
+ *                \frac{\alpha}{(\alpha x + \beta) \ln(\gamma)}
+ * @f$ if propagate_down[0]
+ */
+ virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
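+  // base_scale_ = 1 / log(base) (exactly 1 when base == -1, i.e. base e);
+  // input_scale_ and input_shift_ hold alpha and beta;
+  // backward_num_scale_ = input_scale_ / log(base).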
+ Dtype base_scale_;
+ Dtype input_scale_, input_shift_;
+ Dtype backward_num_scale_;
+};
+
+/**
* @brief Computes @f$ y = (\alpha x + \beta) ^ \gamma @f$,
* as specified by the scale @f$ \alpha @f$, shift @f$ \beta @f$,
* and power @f$ \gamma @f$.
void caffe_exp(const int n, const Dtype* a, Dtype* y);
template <typename Dtype>
+void caffe_log(const int n, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
void caffe_abs(const int n, const Dtype* a, Dtype* y);
template <typename Dtype>
void caffe_gpu_exp(const int n, const Dtype* a, Dtype* y);
template <typename Dtype>
+void caffe_gpu_log(const int n, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
void caffe_gpu_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
// caffe_gpu_rng_uniform with two arguments generates integers in the range
DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));
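+// Without MKL, this macro generates loop-based vsLn/vdLn fallbacks, which the
+// new caffe_log<float>/caffe_log<double> wrappers call.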
+DEFINE_VSL_UNARY_FUNC(Ln, y[i] = log(a[i]));
DEFINE_VSL_UNARY_FUNC(Abs, y[i] = fabs(a[i]));
// A simple way to define the vsl unary functions with singular parameter b.
--- /dev/null
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/neuron_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void LogLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ NeuronLayer<Dtype>::LayerSetUp(bottom, top);
+ const Dtype base = this->layer_param_.log_param().base();
+ if (base != Dtype(-1)) {
+ CHECK_GT(base, 0) << "base must be strictly positive.";
+ }
+ // If base == -1, interpret the base as e and set log_base = 1 exactly.
+ // Otherwise, calculate its log explicitly.
+ const Dtype log_base = (base == Dtype(-1)) ? Dtype(1) : log(base);
+ CHECK(!isnan(log_base))
+ << "NaN result: log(base) = log(" << base << ") = " << log_base;
+ CHECK(!isinf(log_base))
+ << "Inf result: log(base) = log(" << base << ") = " << log_base;
+ base_scale_ = Dtype(1) / log_base;
+ CHECK(!isnan(base_scale_))
+ << "NaN result: 1/log(base) = 1/log(" << base << ") = " << base_scale_;
+ CHECK(!isinf(base_scale_))
+ << "Inf result: 1/log(base) = 1/log(" << base << ") = " << base_scale_;
+ input_scale_ = this->layer_param_.log_param().scale();
+ input_shift_ = this->layer_param_.log_param().shift();
+ backward_num_scale_ = input_scale_ / log_base;
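+  // For example (hypothetical values, not defaults): with base = 2 and
+  // scale = 3, base_scale_ = 1 / ln(2) ~= 1.4427 and
+  // backward_num_scale_ = 3 / ln(2) ~= 4.3281.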
+}
+
+template <typename Dtype>
+void LogLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const int count = bottom[0]->count();
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = top[0]->mutable_cpu_data();
+ if (input_scale_ == Dtype(1) && input_shift_ == Dtype(0)) {
+ caffe_log(count, bottom_data, top_data);
+ } else {
+ caffe_copy(count, bottom_data, top_data);
+ if (input_scale_ != Dtype(1)) {
+ caffe_scal(count, input_scale_, top_data);
+ }
+ if (input_shift_ != Dtype(0)) {
+ caffe_add_scalar(count, input_shift_, top_data);
+ }
+ caffe_log(count, top_data, top_data);
+ }
+ if (base_scale_ != Dtype(1)) {
+ caffe_scal(count, base_scale_, top_data);
+ }
+}
+
+template <typename Dtype>
+void LogLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+ if (!propagate_down[0]) { return; }
+ const int count = bottom[0]->count();
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ const Dtype* top_diff = top[0]->cpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
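+  // Compute dE/dx = dE/dy * alpha / ((alpha * x + beta) * ln(base)):
+  // form (alpha * x + beta) in bottom_diff, invert it elementwise with
+  // caffe_powx and exponent -1, scale by alpha / ln(base)
+  // (backward_num_scale_), and finally multiply by the top diff.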
+ caffe_copy(count, bottom_data, bottom_diff);
+ if (input_scale_ != Dtype(1)) {
+ caffe_scal(count, input_scale_, bottom_diff);
+ }
+ if (input_shift_ != Dtype(0)) {
+ caffe_add_scalar(count, input_shift_, bottom_diff);
+ }
+ caffe_powx(count, bottom_diff, Dtype(-1), bottom_diff);
+ if (backward_num_scale_ != Dtype(1)) {
+ caffe_scal(count, backward_num_scale_, bottom_diff);
+ }
+ caffe_mul(count, top_diff, bottom_diff, bottom_diff);
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(LogLayer);
+#else
+
+template <typename Dtype>
+void LogLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top) {
+ const int count = bottom[0]->count();
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ Dtype* top_data = top[0]->mutable_gpu_data();
+ if (input_scale_ == Dtype(1) && input_shift_ == Dtype(0)) {
+ caffe_gpu_log(count, bottom_data, top_data);
+ } else {
+ caffe_copy(count, bottom_data, top_data);
+ if (input_scale_ != Dtype(1)) {
+ caffe_gpu_scal(count, input_scale_, top_data);
+ }
+ if (input_shift_ != Dtype(0)) {
+ caffe_gpu_add_scalar(count, input_shift_, top_data);
+ }
+ caffe_gpu_log(count, top_data, top_data);
+ }
+ if (base_scale_ != Dtype(1)) {
+ caffe_gpu_scal(count, base_scale_, top_data);
+ }
+}
+
+template <typename Dtype>
+void LogLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+ if (!propagate_down[0]) { return; }
+ const int count = bottom[0]->count();
+ const Dtype* bottom_data = bottom[0]->gpu_data();
+ const Dtype* top_diff = top[0]->gpu_diff();
+ Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+ caffe_copy(count, bottom_data, bottom_diff);
+ if (input_scale_ != Dtype(1)) {
+ caffe_gpu_scal(count, input_scale_, bottom_diff);
+ }
+ if (input_shift_ != Dtype(0)) {
+ caffe_gpu_add_scalar(count, input_shift_, bottom_diff);
+ }
+ caffe_gpu_powx(count, bottom_diff, Dtype(-1), bottom_diff);
+ if (backward_num_scale_ != Dtype(1)) {
+ caffe_gpu_scal(count, backward_num_scale_, bottom_diff);
+ }
+ caffe_gpu_mul(count, top_diff, bottom_diff, bottom_diff);
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(LogLayer);
+
+#endif
+
+INSTANTIATE_CLASS(LogLayer);
+REGISTER_LAYER_CLASS(Log);
+
+} // namespace caffe
// NOTE
// Update the next available ID when you add a new LayerParameter field.
//
-// LayerParameter next available layer-specific ID: 134 (last added: reshape_param)
+// LayerParameter next available layer-specific ID: 135 (last added: log_param)
message LayerParameter {
optional string name = 1; // the layer name
optional string type = 2; // the layer type
optional ImageDataParameter image_data_param = 115;
optional InfogainLossParameter infogain_loss_param = 116;
optional InnerProductParameter inner_product_param = 117;
+ optional LogParameter log_param = 134;
optional LRNParameter lrn_param = 118;
optional MemoryDataParameter memory_data_param = 119;
optional MVNParameter mvn_param = 120;
optional int32 axis = 5 [default = 1];
}
+// Message that stores parameters used by LogLayer
+message LogParameter {
+  // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0.
+  // If base is left at its default of -1, base e is used instead,
+  // so y = ln(shift + scale * x) = log_e(shift + scale * x).
+ optional float base = 1 [default = -1.0];
+ optional float scale = 2 [default = 1.0];
+ optional float shift = 3 [default = 0.0];
+}
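+
+// Illustration only (not part of the schema): a net prototxt might use the
+// parameters above roughly as follows, with hypothetical blob names:
+//   layer {
+//     name: "log"
+//     type: "Log"
+//     bottom: "in"
+//     top: "out"
+//     log_param { base: 2 scale: 3 shift: 1 }
+//   }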
+
+// Message that stores parameters used by LRNLayer
message LRNParameter {
optional uint32 local_size = 1 [default = 5];
optional float alpha = 2 [default = 1.];
+ slope_data[c] * std::min(bottom_data[i], (Dtype)(0)));
}
}
+
+ void LogBottomInit() {
+ FillerParameter filler_param;
+ GaussianFiller<Dtype> filler(filler_param);
+ filler.Fill(this->blob_bottom_);
+ Dtype* bottom_data = this->blob_bottom_->mutable_cpu_data();
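+    // Exponentiate the Gaussian samples so every input is strictly positive,
+    // keeping log(shift + scale * x) finite for the parameters tested.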
+ caffe_exp(this->blob_bottom_->count(), bottom_data, bottom_data);
+ }
+
+ void TestLogForward(const float base, const float scale, const float shift) {
+ LogBottomInit();
+ LayerParameter layer_param;
+ layer_param.mutable_log_param()->set_base(base);
+ layer_param.mutable_log_param()->set_scale(scale);
+ layer_param.mutable_log_param()->set_shift(shift);
+ LogLayer<Dtype> layer(layer_param);
+ layer.SetUp(blob_bottom_vec_, blob_top_vec_);
+ layer.Forward(blob_bottom_vec_, blob_top_vec_);
+ const Dtype kDelta = 2e-4;
+ const Dtype* bottom_data = blob_bottom_->cpu_data();
+ const Dtype* top_data = blob_top_->cpu_data();
+ for (int i = 0; i < blob_bottom_->count(); ++i) {
+ const Dtype bottom_val = bottom_data[i];
+ const Dtype top_val = top_data[i];
+ if (base == -1) {
+ EXPECT_NEAR(top_val, log(shift + scale * bottom_val), kDelta);
+ } else {
+ EXPECT_NEAR(top_val, log(shift + scale * bottom_val) / log(base),
+ kDelta);
+ }
+ }
+ }
+
+ void TestLogGradient(const float base, const float scale, const float shift) {
+ LogBottomInit();
+ LayerParameter layer_param;
+ layer_param.mutable_log_param()->set_base(base);
+ layer_param.mutable_log_param()->set_scale(scale);
+ layer_param.mutable_log_param()->set_shift(shift);
+ LogLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-2);
+ checker.CheckGradientEltwise(&layer, blob_bottom_vec_, blob_top_vec_);
+ }
};
TYPED_TEST_CASE(NeuronLayerTest, TestDtypesAndDevices);
this->TestExpGradient(kBase, kScale, kShift);
}
+TYPED_TEST(NeuronLayerTest, TestLogLayer) {
+ typedef typename TypeParam::Dtype Dtype;
+ // Test default base of "-1" -- should actually set base := e.
+ const Dtype kBase = -1;
+ const Dtype kScale = 1;
+ const Dtype kShift = 0;
+ this->TestLogForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ // Test default base of "-1" -- should actually set base := e.
+ const Dtype kBase = -1;
+ const Dtype kScale = 1;
+ const Dtype kShift = 0;
+ this->TestLogGradient(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogLayerBase2) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kBase = 2;
+ const Dtype kScale = 1;
+ const Dtype kShift = 0;
+ this->TestLogForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogGradientBase2) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kBase = 2;
+ const Dtype kScale = 1;
+ const Dtype kShift = 0;
+ this->TestLogGradient(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogLayerBase2Shift1) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kBase = 2;
+ const Dtype kScale = 1;
+ const Dtype kShift = 1;
+ this->TestLogForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogGradientBase2Shift1) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kBase = 2;
+ const Dtype kScale = 1;
+ const Dtype kShift = 1;
+ this->TestLogGradient(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogLayerBase2Scale3) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kBase = 2;
+ const Dtype kScale = 3;
+ const Dtype kShift = 0;
+ this->TestLogForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogGradientBase2Scale3) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kBase = 2;
+ const Dtype kScale = 3;
+ const Dtype kShift = 0;
+ this->TestLogGradient(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogLayerBase2Shift1Scale3) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kBase = 2;
+ const Dtype kScale = 3;
+ const Dtype kShift = 1;
+ this->TestLogForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogGradientBase2Shift1Scale3) {
+ typedef typename TypeParam::Dtype Dtype;
+ const Dtype kBase = 2;
+ const Dtype kScale = 3;
+ const Dtype kShift = 1;
+ this->TestLogGradient(kBase, kScale, kShift);
+}
+
TYPED_TEST(NeuronLayerTest, TestDropoutHalf) {
const float kDropoutRatio = 0.5;
this->TestDropoutForward(kDropoutRatio);
}
template <>
+void caffe_log<float>(const int n, const float* a, float* y) {
+ vsLn(n, a, y);
+}
+
+template <>
+void caffe_log<double>(const int n, const double* a, double* y) {
+ vdLn(n, a, y);
+}
+
+template <>
void caffe_abs<float>(const int n, const float* a, float* y) {
vsAbs(n, a, y);
}
}
template <typename Dtype>
+__global__ void log_kernel(const int n, const Dtype* a, Dtype* y) {
+ CUDA_KERNEL_LOOP(index, n) {
+ y[index] = log(a[index]);
+ }
+}
+
+template <>
+void caffe_gpu_log<float>(const int N, const float* a, float* y) {
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ log_kernel<float><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+ N, a, y);
+}
+
+template <>
+void caffe_gpu_log<double>(const int N, const double* a, double* y) {
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ log_kernel<double><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+ N, a, y);
+}
+
+template <typename Dtype>
__global__ void powx_kernel(const int n, const Dtype* a,
const Dtype alpha, Dtype* y) {
CUDA_KERNEL_LOOP(index, n) {