From 0a0f9be2fe60a66472c36838f38940e05827d259 Mon Sep 17 00:00:00 2001 From: Jeff Donahue Date: Sat, 24 May 2014 19:02:51 -0700 Subject: [PATCH] merge caffe_set definitions; define for int as well --- src/caffe/layers/pooling_layer.cpp | 4 +--- src/caffe/util/math_functions.cpp | 19 ++++++------------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/src/caffe/layers/pooling_layer.cpp b/src/caffe/layers/pooling_layer.cpp index 928c8c7..5d6921f 100644 --- a/src/caffe/layers/pooling_layer.cpp +++ b/src/caffe/layers/pooling_layer.cpp @@ -84,9 +84,7 @@ Dtype PoolingLayer::Forward_cpu(const vector*>& bottom, caffe_set(top_count, Dtype(-1), top_mask); } else { mask = max_idx_->mutable_cpu_data(); - for (int i = 0; i < top_count; ++i) { - mask[i] = -1; - } + caffe_set(top_count, -1, mask); } caffe_set(top_count, Dtype(-FLT_MAX), top_data); // The main loop diff --git a/src/caffe/util/math_functions.cpp b/src/caffe/util/math_functions.cpp index 2196b44..67274ef 100644 --- a/src/caffe/util/math_functions.cpp +++ b/src/caffe/util/math_functions.cpp @@ -120,10 +120,10 @@ void caffe_gpu_axpy(const int N, const double alpha, const double* X, CUBLAS_CHECK(cublasDaxpy(Caffe::cublas_handle(), N, &alpha, X, 1, Y, 1)); } -template <> -void caffe_set(const int N, const float alpha, float* Y) { +template +void caffe_set(const int N, const Dtype alpha, Dtype* Y) { if (alpha == 0) { - memset(Y, 0, sizeof(float) * N); + memset(Y, 0, sizeof(Dtype) * N); return; } for (int i = 0; i < N; ++i) { @@ -131,16 +131,9 @@ void caffe_set(const int N, const float alpha, float* Y) { } } -template <> -void caffe_set(const int N, const double alpha, double* Y) { - if (alpha == 0) { - memset(Y, 0, sizeof(double) * N); - return; - } - for (int i = 0; i < N; ++i) { - Y[i] = alpha; - } -} +template void caffe_set(const int N, const int alpha, int* Y); +template void caffe_set(const int N, const float alpha, float* Y); +template void caffe_set(const int N, const double alpha, double* Y); template <> void caffe_add_scalar(const int N, const float alpha, float* Y) { -- 2.7.4