1. Porting Caffe onto ARM Compute Library.
2. The release version is 0.2.0.
author     Yao Honggui <honggui@openailan.com>   Fri, 2 Jun 2017 13:29:07 +0000 (21:29 +0800)
committer  Yao Honggui <honggui@openailan.com>   Fri, 2 Jun 2017 13:29:07 +0000 (21:29 +0800)

72 files changed:
Makefile
Makefile.config.acl [new file with mode: 0644]
README.md
cmake/Dependencies.cmake
cmake/Modules/FindACL.cmake [new file with mode: 0644]
docs/caffeOnACL_release_notes_0_2_0.docx [new file with mode: 0644]
docs/caffeOnACL_user_guide_0_2_0.docx [new file with mode: 0644]
examples/cpp_classification/classification_profiling.cpp [new file with mode: 0644]
examples/cpp_classification/classification_profiling_gpu.cpp [new file with mode: 0644]
include/caffe/acl_layer.hpp [new file with mode: 0644]
include/caffe/caffe.hpp
include/caffe/common.hpp
include/caffe/layer.hpp
include/caffe/layers/acl_absval_layer.hpp [new file with mode: 0644]
include/caffe/layers/acl_base_activation_layer.hpp [new file with mode: 0644]
include/caffe/layers/acl_bnll_layer.hpp [new file with mode: 0644]
include/caffe/layers/acl_conv_layer.hpp [new file with mode: 0644]
include/caffe/layers/acl_inner_product_layer.hpp [new file with mode: 0644]
include/caffe/layers/acl_lrn_layer.hpp [new file with mode: 0644]
include/caffe/layers/acl_pooling_layer.hpp [new file with mode: 0644]
include/caffe/layers/acl_relu_layer.hpp [new file with mode: 0644]
include/caffe/layers/acl_sigmoid_layer.hpp [new file with mode: 0644]
include/caffe/layers/acl_softmax_layer.hpp [new file with mode: 0644]
include/caffe/layers/acl_tanh_layer.hpp [new file with mode: 0644]
include/caffe/util/device_alternate.hpp
include/caffe/util/hdf5.hpp
src/caffe/acl_layer.cpp [new file with mode: 0644]
src/caffe/common.cpp
src/caffe/layer.cpp
src/caffe/layer_factory.cpp
src/caffe/layers/absval_layer.cpp
src/caffe/layers/acl_absval_layer.cpp [new file with mode: 0644]
src/caffe/layers/acl_base_activation_layer.cpp [new file with mode: 0644]
src/caffe/layers/acl_bnll_layer.cpp [new file with mode: 0644]
src/caffe/layers/acl_conv_layer.cpp [new file with mode: 0644]
src/caffe/layers/acl_inner_product_layer.cpp [new file with mode: 0644]
src/caffe/layers/acl_lrn_layer.cpp [new file with mode: 0644]
src/caffe/layers/acl_pooling_layer.cpp [new file with mode: 0644]
src/caffe/layers/acl_relu_layer.cpp [new file with mode: 0644]
src/caffe/layers/acl_sigmoid_layer.cpp [new file with mode: 0644]
src/caffe/layers/acl_softmax_layer.cpp [new file with mode: 0644]
src/caffe/layers/acl_tanh_layer.cpp [new file with mode: 0644]
src/caffe/layers/bnll_layer.cpp
src/caffe/layers/hdf5_data_layer.cpp
src/caffe/layers/hdf5_data_layer.cu
src/caffe/layers/hdf5_output_layer.cpp
src/caffe/layers/hdf5_output_layer.cu
src/caffe/layers/inner_product_layer.cpp
src/caffe/net.cpp
src/caffe/solvers/sgd_solver.cpp
src/caffe/syncedmem.cpp
src/caffe/test/test_hdf5_output_layer.cpp
src/caffe/test/test_hdf5data_layer.cpp
src/caffe/util/hdf5.cpp
src/caffe/util/math_functions.cpp
unit_tests/Makefile [new file with mode: 0644]
unit_tests/pmu.c [new file with mode: 0644]
unit_tests/pmu.h [new file with mode: 0644]
unit_tests/prof_convolution_layer.cpp [new file with mode: 0644]
unit_tests/sgemm.cpp [new file with mode: 0644]
unit_tests/test.cpp [new file with mode: 0644]
unit_tests/test_caffe_main.cpp [new file with mode: 0644]
unit_tests/test_common.cpp [new file with mode: 0644]
unit_tests/test_convolution_layer.cpp [new file with mode: 0644]
unit_tests/test_fail.cpp [new file with mode: 0644]
unit_tests/test_inner_product_layer.cpp [new file with mode: 0644]
unit_tests/test_lrn_layer.cpp [new file with mode: 0644]
unit_tests/test_neuron_layer.cpp [new file with mode: 0644]
unit_tests/test_pooling_layer.cpp [new file with mode: 0644]
unit_tests/test_softmax_layer.cpp [new file with mode: 0644]
unit_tests/testbed.c [new file with mode: 0644]
unit_tests/testbed.h [new file with mode: 0644]

index 77900b6..2d5d305 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -7,6 +7,12 @@ $(error $(CONFIG_FILE) not found. See $(CONFIG_FILE).example.)
 endif
 include $(CONFIG_FILE)
 
+ifeq ($(CPU_ONLY),1)
+       USE_CUDA := 0
+endif
+ifeq ($(USE_ACL),1)
+       USE_CUDA := 0
+endif
 BUILD_DIR_LINK := $(BUILD_DIR)
 ifeq ($(RELEASE_BUILD_DIR),)
        RELEASE_BUILD_DIR := .$(BUILD_DIR)_release
@@ -172,13 +178,13 @@ endif
 CUDA_LIB_DIR += $(CUDA_DIR)/lib
 
 INCLUDE_DIRS += $(BUILD_INCLUDE_DIR) ./src ./include
-ifneq ($(CPU_ONLY), 1)
+ifeq ($(USE_CUDA), 1)
        INCLUDE_DIRS += $(CUDA_INCLUDE_DIR)
        LIBRARY_DIRS += $(CUDA_LIB_DIR)
        LIBRARIES := cudart cublas curand
 endif
 
-LIBRARIES += glog gflags protobuf boost_system boost_filesystem m hdf5_hl hdf5
+LIBRARIES += glog gflags protobuf boost_system boost_filesystem m 
 
 # handle IO dependencies
 USE_LEVELDB ?= 1
@@ -271,7 +277,7 @@ endif
 # libstdc++ for NVCC compatibility on OS X >= 10.9 with CUDA < 7.0
 ifeq ($(OSX), 1)
        CXX := /usr/bin/clang++
-       ifneq ($(CPU_ONLY), 1)
+       ifeq ($(USE_CUDA), 1)
                CUDA_VERSION := $(shell $(CUDA_DIR)/bin/nvcc -V | grep -o 'release [0-9.]*' | tr -d '[a-z ]')
                ifeq ($(shell echo | awk '{exit $(CUDA_VERSION) < 7.0;}'), 1)
                        CXXFLAGS += -stdlib=libstdc++
@@ -296,6 +302,10 @@ ifeq ($(OSX), 1)
        ORIGIN := @loader_path
        VERSIONFLAGS += -Wl,-install_name,@rpath/$(DYNAMIC_VERSIONED_NAME_SHORT) -Wl,-rpath,$(ORIGIN)/../../build/lib
 else
+       ifeq (${USE_OPENMP}, 1)
+               CXXFLAGS += -fopenmp
+               LINKFLAGS += -fopenmp
+       endif
        ORIGIN := \$$ORIGIN
 endif
 
@@ -334,6 +344,27 @@ ifeq ($(USE_NCCL), 1)
        COMMON_FLAGS += -DUSE_NCCL
 endif
 
+# ACL acceleration configuration
+ifeq ($(USE_ACL), 1)
+        LIBRARY_DIRS += $(ACL_LIBS_DIR)
+       LIBRARIES += $(ACL_LIBS)
+       INCLUDE_DIRS +=$(ACL_INCS)
+       COMMON_FLAGS += -DUSE_ACL -std=c++11
+endif
+
+# USE_PROFILING -- collect profiling information; controlled by LOGACL
+# LAYER_PERF_STAT -- Haitao's per-layer net profiling information
+ifeq ($(USE_PROFILING), 1)
+       COMMON_FLAGS += -DUSE_PROFILING -DLAYER_PERF_STAT
+endif
+#HDF5
+ifeq ($(USE_HDF5), 1)
+        LIBRARY_DIRS += $(HDF5_LIBRARY_DIRS)
+       LIBRARIES += $(HDF5_LIBRARIES)
+       INCLUDE_DIRS +=$(HDF5_INCLUDE_DIRS)
+       COMMON_FLAGS += -DUSE_HDF5
+endif
+
 # configure IO libraries
 ifeq ($(USE_OPENCV), 1)
        COMMON_FLAGS += -DUSE_OPENCV
@@ -358,6 +389,15 @@ ifeq ($(CPU_ONLY), 1)
        COMMON_FLAGS += -DCPU_ONLY
 endif
 
+ifeq ($(USE_ACL), 1)
+       OBJS := $(PROTO_OBJS) $(CXX_OBJS)
+       TEST_OBJS := $(TEST_CXX_OBJS)
+       TEST_BINS := $(TEST_CXX_BINS)
+       ALL_WARNS := $(ALL_CXX_WARNS)
+       TEST_FILTER := --gtest_filter="-*GPU*"
+       COMMON_FLAGS += -DCPU_ONLY
+endif
+
 # Python layer support
 ifeq ($(WITH_PYTHON_LAYER), 1)
        COMMON_FLAGS += -DWITH_PYTHON_LAYER
@@ -365,7 +405,8 @@ ifeq ($(WITH_PYTHON_LAYER), 1)
 endif
 
 # BLAS configuration (default = ATLAS)
-BLAS ?= atlas
+#BLAS ?= atlas
+BLAS ?= open
 ifeq ($(BLAS), mkl)
        # MKL
        LIBRARIES += mkl_rt
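
The net effect of the Makefile changes above: CPU_ONLY=1 or USE_ACL=1 forces USE_CUDA to 0, and USE_ACL additionally defines both USE_ACL and CPU_ONLY (plus -std=c++11) for the compiler. A minimal sketch of how source code sees this flag interplay (illustrative only, not part of the patch):

    // Illustrative: the compile-time paths implied by the flags set above.
    #if defined(USE_ACL)
      // ARM Compute Library path. CPU_ONLY is also defined, so every CUDA
      // branch is compiled out; NEON/OpenCL kernels take over.
    #elif !defined(CPU_ONLY)
      // CUDA path: reachable only when neither CPU_ONLY nor USE_ACL is set.
    #else
      // Plain CPU path.
    #endif
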
diff --git a/Makefile.config.acl b/Makefile.config.acl
new file mode 100644 (file)
index 0000000..b30759f
--- /dev/null
@@ -0,0 +1,140 @@
+## Refer to http://caffe.berkeleyvision.org/installation.html
+# Contributions simplifying and improving our build system are welcome!
+
+# cuDNN acceleration switch (uncomment to build with cuDNN).
+# USE_CUDNN := 1
+
+# CPU-only switch (uncomment to build without GPU support).
+CPU_ONLY := 1
+
+USE_PROFILING := 0
+
+USE_ACL :=1
+ACL_ROOT :=/home/firefly/ComputeLibrary
+ACL_INCS :=$(ACL_ROOT)/include 
+ACL_INCS +=$(ACL_ROOT)
+ACL_LIBS_DIR :=$(ACL_ROOT)/build
+ACL_LIBS_DIR +=$(ACL_ROOT)/build/arm_compute
+ACL_LIBS :=arm_compute OpenCL
+
+# uncomment to disable IO dependencies and corresponding data layers
+# USE_OPENCV := 0
+# USE_LEVELDB := 0
+# USE_LMDB := 0
+
+# uncomment to allow MDB_NOLOCK when reading LMDB files (only if necessary)
+#      You should not set this flag if you will be reading LMDBs with any
+#      possibility of simultaneous read and write
+# ALLOW_LMDB_NOLOCK := 1
+
+# Uncomment if you're using OpenCV 3
+# OPENCV_VERSION := 3
+
+# To customize your choice of compiler, uncomment and set the following.
+# N.B. the default for Linux is g++ and the default for OSX is clang++
+# CUSTOM_CXX := g++
+#CUSTOM_CXX := aarch64-linux-gnu-g++
+#os :=linux 
+#arch :=arm64-v8a
+
+# CUDA directory contains bin/ and lib/ directories that we need.
+CUDA_DIR := /usr/local/cuda
+# On Ubuntu 14.04, if cuda tools are installed via
+# "sudo apt-get install nvidia-cuda-toolkit" then use this instead:
+# CUDA_DIR := /usr
+
+# CUDA architecture setting: going with all of them.
+# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility.
+# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility.
+CUDA_ARCH := -gencode arch=compute_20,code=sm_20 \
+               -gencode arch=compute_20,code=sm_21 \
+               -gencode arch=compute_30,code=sm_30 \
+               -gencode arch=compute_35,code=sm_35 \
+               -gencode arch=compute_50,code=sm_50 \
+               -gencode arch=compute_52,code=sm_52 \
+               -gencode arch=compute_60,code=sm_60 \
+               -gencode arch=compute_61,code=sm_61 \
+               -gencode arch=compute_61,code=compute_61
+
+# BLAS choice:
+# atlas for ATLAS (default)
+# mkl for MKL
+# open for OpenBlas
+#BLAS := atlas
+BLAS := open
+# Custom (MKL/ATLAS/OpenBLAS) include and lib directories.
+# Leave commented to accept the defaults for your choice of BLAS
+# (which should work)!
+# BLAS_INCLUDE := /path/to/your/blas
+# BLAS_LIB := /path/to/your/blas
+
+# Homebrew puts openblas in a directory that is not on the standard search path
+# BLAS_INCLUDE := $(shell brew --prefix openblas)/include
+# BLAS_LIB := $(shell brew --prefix openblas)/lib
+
+# This is required only if you will compile the matlab interface.
+# MATLAB directory should contain the mex binary in /bin.
+# MATLAB_DIR := /usr/local
+# MATLAB_DIR := /Applications/MATLAB_R2012b.app
+
+# NOTE: this is required only if you will compile the python interface.
+# We need to be able to find Python.h and numpy/arrayobject.h.
+PYTHON_INCLUDE := /usr/include/python2.7 \
+               /usr/lib/python2.7/dist-packages/numpy/core/include
+# Anaconda Python distribution is quite popular. Include path:
+# Verify anaconda location, sometimes it's in root.
+# ANACONDA_HOME := $(HOME)/anaconda
+# PYTHON_INCLUDE := $(ANACONDA_HOME)/include \
+               # $(ANACONDA_HOME)/include/python2.7 \
+               # $(ANACONDA_HOME)/lib/python2.7/site-packages/numpy/core/include
+
+# Uncomment to use Python 3 (default is Python 2)
+# PYTHON_LIBRARIES := boost_python3 python3.5m
+# PYTHON_INCLUDE := /usr/include/python3.5m \
+#                 /usr/lib/python3.5/dist-packages/numpy/core/include
+
+# We need to be able to find libpythonX.X.so or .dylib.
+PYTHON_LIB := /usr/lib
+# PYTHON_LIB := $(ANACONDA_HOME)/lib
+
+# Homebrew installs numpy in a non standard path (keg only)
+# PYTHON_INCLUDE += $(dir $(shell python -c 'import numpy.core; print(numpy.core.__file__)'))/include
+# PYTHON_LIB += $(shell brew --prefix numpy)/lib
+
+# Uncomment to support layers written in Python (will link against Python libs)
+# WITH_PYTHON_LAYER := 1
+
+# Whatever else you find you need goes here.
+INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include
+LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib
+
+# If Homebrew is installed at a non standard location (for example your home directory) and you use it for general dependencies
+# INCLUDE_DIRS += $(shell brew --prefix)/include
+# LIBRARY_DIRS += $(shell brew --prefix)/lib
+
+# NCCL acceleration switch (uncomment to build with NCCL)
+# https://github.com/NVIDIA/nccl (last tested version: v1.2.3-1+cuda8.0)
+# USE_NCCL := 1
+
+# Uncomment to use `pkg-config` to specify OpenCV library paths.
+# (Usually not necessary -- OpenCV libraries are normally installed in one of the above $LIBRARY_DIRS.)
+# USE_PKG_CONFIG := 1
+
+# N.B. both build and distribute dirs are cleared on `make clean`
+BUILD_DIR := build
+DISTRIBUTE_DIR := distribute
+
+#HDF5
+USE_HDF5 := 1
+HDF5_INCLUDE_DIRS := /usr/include/hdf5/serial
+HDF5_LIBRARY_DIRS := /usr/lib/aarch64-linux-gnu/hdf5/serial
+HDF5_LIBRARIES :=hdf5_hl hdf5
+
+# Uncomment for debugging. Does not work on OSX due to https://github.com/BVLC/caffe/issues/171
+# DEBUG := 1
+
+# The ID of the GPU that 'make runtest' will use to run unit tests.
+TEST_GPUID := 0
+
+# enable pretty build (comment to see full commands)
+Q ?= @
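
Note that the Makefile hunk earlier also removed hdf5_hl/hdf5 from the unconditional LIBRARIES list; with USE_HDF5 now an explicit switch, HDF5-dependent code is expected to compile out when the flag is off. A hedged sketch of the guard pattern this implies (the exact guard placement in the sources is not shown in this commit view):

    // Sketch only: code that needs HDF5 builds under the -DUSE_HDF5 define
    // added by the build changes above.
    #ifdef USE_HDF5
    #include "caffe/util/hdf5.hpp"
    // ... HDF5 data/output layer code ...
    #endif  // USE_HDF5
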
index 44b9e62..a20099a 100644 (file)
--- a/README.md
+++ b/README.md
@@ -1,3 +1,14 @@
+# CaffeOnACL
+[![License](https://img.shields.io/badge/license-BSD-blue.svg)](LICENSE)
+
+CaffeOnACL is a project that uses the ARM Compute Library (NEON + GPU) to speed up Caffe, together with utilities to debug, profile, and tune application performance.
+
+Check out the following documents for details:
+- [release notes](https://github.com/OAID/caffeOnACL/tree/master/docs/caffeOnACL_release_notes_0_2_0.docx)
+- [user guide](https://github.com/OAID/caffeOnACL/tree/master/docs/caffeOnACL_user_guide_0_2_0.docx)
+
+
+
 # Caffe
 
 [![Build Status](https://travis-ci.org/BVLC/caffe.svg?branch=master)](https://travis-ci.org/BVLC/caffe)
index 02c8152..888443d 100644 (file)
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -66,6 +66,21 @@ if(USE_LEVELDB)
   list(APPEND Caffe_DEFINITIONS PUBLIC -DUSE_LEVELDB)
 endif()
 
+# ---[ ACL
+if(USE_ACL)
+  find_package(ACL REQUIRED)
+  list(APPEND Caffe_INCLUDE_DIRS PUBLIC ${ACL_INCLUDE})
+  list(APPEND Caffe_INCLUDE_DIRS PUBLIC ${ACL_INCLUDE}/include)
+  set(__list ${ACL_LIBRARIES})
+  separate_arguments(__list)
+    list(REMOVE_DUPLICATES __list)
+    foreach(i ${__list})
+        list(APPEND Caffe_LINKER_LIBS PUBLIC ${i})
+    endforeach()
+  list(APPEND Caffe_COMPILE_OPTIONS PRIVATE -std=c++11)
+  list(APPEND Caffe_DEFINITIONS PUBLIC -DUSE_ACL)
+endif()
+
 # ---[ Snappy
 if(USE_LEVELDB)
   find_package(Snappy REQUIRED)
diff --git a/cmake/Modules/FindACL.cmake b/cmake/Modules/FindACL.cmake
new file mode 100644 (file)
index 0000000..d7df5aa
--- /dev/null
@@ -0,0 +1,37 @@
+set(ACL_INC_PATHS
+    /usr/include
+    /usr/local/include
+    /usr/local/acl
+    $ENV{ACL_DIR}/include
+    )
+
+set(ACL_LIB_PATHS
+    /lib
+    /lib64
+    /usr/lib
+    /usr/lib64
+    /usr/local/lib
+    /usr/local/lib64
+    /usr/local/acl/lib
+    /usr/local/acl/lib64
+    $ENV{ACL_DIR}/lib
+    )
+
+find_path(ACL_INCLUDE NAMES arm_compute PATHS ${ACL_INC_PATHS})
+find_library(ACL_LIBRARIES NAMES arm_compute-static PATHS ${ACL_LIB_PATHS})
+find_library(ACL_CORE_LIBRARIES NAMES arm_compute_core-static PATHS ${ACL_LIB_PATHS})
+SET(ACL_LIBRARIES "${ACL_CORE_LIBRARIES} ${ACL_LIBRARIES}")
+
+if(ACL_INCS)
+  SET(ACL_INCLUDE "${ACL_INCS}")
+  SET(ACL_LIBRARIES "${ACL_LIBS}")
+  SET(ACL_FOUND 1)
+else  ()
+  include(FindPackageHandleStandardArgs)
+  find_package_handle_standard_args(ACL DEFAULT_MSG ACL_INCLUDE ACL_LIBRARIES)
+endif ()
+
+if (ACL_FOUND)
+  message(STATUS "Found ACL    (include: ${ACL_INCLUDE}, library: ${ACL_LIBRARIES})")
+  mark_as_advanced(ACL_INCLUDE ACL_LIBRARIES)
+endif ()
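
Once FindACL.cmake locates the library (honoring $ENV{ACL_DIR} or the explicit ACL_INCS/ACL_LIBS overrides), a small standalone program can sanity-check the include paths and linkage. This sketch assumes the 2017-era arm_compute API already used elsewhere in this patch (Tensor, allocator()->allocate()); it is not part of the commit:

    // Minimal ACL link test (assumed 17.05-era arm_compute API).
    #include "arm_compute/core/TensorInfo.h"
    #include "arm_compute/core/TensorShape.h"
    #include "arm_compute/core/Types.h"
    #include "arm_compute/runtime/Tensor.h"

    int main() {
      arm_compute::Tensor t;
      // Describe a 16x16 FP32 tensor, then allocate its backing memory.
      t.allocator()->init(arm_compute::TensorInfo(
          arm_compute::TensorShape(16U, 16U), 1, arm_compute::DataType::F32));
      t.allocator()->allocate();
      return t.info()->total_size() == 0;  // exit 0 when the buffer was sized
    }
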
diff --git a/docs/caffeOnACL_release_notes_0_2_0.docx b/docs/caffeOnACL_release_notes_0_2_0.docx
new file mode 100644 (file)
index 0000000..0c1c014
Binary files /dev/null and b/docs/caffeOnACL_release_notes_0_2_0.docx differ
diff --git a/docs/caffeOnACL_user_guide_0_2_0.docx b/docs/caffeOnACL_user_guide_0_2_0.docx
new file mode 100644 (file)
index 0000000..a7ebf61
Binary files /dev/null and b/docs/caffeOnACL_user_guide_0_2_0.docx differ
diff --git a/examples/cpp_classification/classification_profiling.cpp b/examples/cpp_classification/classification_profiling.cpp
new file mode 100644 (file)
index 0000000..f5d5eae
--- /dev/null
@@ -0,0 +1,546 @@
+#include <caffe/caffe.hpp>
+#ifdef USE_OPENCV
+#include <opencv2/core/core.hpp>
+#include <opencv2/highgui/highgui.hpp>
+#include <opencv2/imgproc/imgproc.hpp>
+#endif  // USE_OPENCV
+#include <algorithm>
+#include <iosfwd>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#ifdef USE_PROFILING
+
+#include <iostream>
+
+#include <time.h>
+
+#define REPEAT_TEST
+
+unsigned long get_cur_time(void)
+{
+   struct timespec tm;
+
+   clock_gettime(CLOCK_MONOTONIC_COARSE, &tm);
+
+   return (tm.tv_sec*1000+tm.tv_nsec/1000000);
+}
+
+#endif //USE_PROFILING
+
+#ifdef USE_OPENCV
+using namespace caffe;  // NOLINT(build/namespaces)
+using std::string;
+
+/* Pair (label, confidence) representing a prediction. */
+typedef std::pair<string, float> Prediction;
+
+class Classifier {
+ public:
+  Classifier(const string& model_file,
+             const string& trained_file,
+             const string& mean_file,
+             const string& label_file);
+
+  std::vector<Prediction> Classify(const cv::Mat& img, int N = 5);
+
+#ifdef USE_PROFILING
+
+#ifdef LAYER_PERF_STAT
+  void  dump_perf_stat(void);
+  void  dump_single_layer_io(int idx, Layer<float> * p_layer);
+  void  dump_single_layer_perf(int idx, Layer<float> * p_layer,uint64_t total_net_time);
+#ifdef REPEAT_TEST
+  void collect_layer_stat(vector<vector<perf_stat> * > & all_stat);
+  void dump_all_stat(vector <vector<perf_stat>*>& all_stat);
+  void reset_layer_stat();
+#endif
+#endif
+
+#endif //USE_PROFILING
+
+ private:
+  void SetMean(const string& mean_file);
+
+  std::vector<float> Predict(const cv::Mat& img);
+
+  void WrapInputLayer(std::vector<cv::Mat>* input_channels);
+
+  void Preprocess(const cv::Mat& img,
+                  std::vector<cv::Mat>* input_channels);
+
+ private:
+  shared_ptr<Net<float> > net_;
+  cv::Size input_geometry_;
+  int num_channels_;
+  cv::Mat mean_;
+  std::vector<string> labels_;
+};
+
+Classifier::Classifier(const string& model_file,
+                       const string& trained_file,
+                       const string& mean_file,
+                       const string& label_file) {
+#ifdef CPU_ONLY
+  Caffe::set_mode(Caffe::CPU);
+#else
+  Caffe::set_mode(Caffe::GPU);
+#endif
+
+  /* Load the network. */
+  net_.reset(new Net<float>(model_file, TEST));
+  net_->CopyTrainedLayersFrom(trained_file);
+
+  CHECK_EQ(net_->num_inputs(), 1) << "Network should have exactly one input.";
+  CHECK_EQ(net_->num_outputs(), 1) << "Network should have exactly one output.";
+
+  Blob<float>* input_layer = net_->input_blobs()[0];
+  num_channels_ = input_layer->channels();
+  CHECK(num_channels_ == 3 || num_channels_ == 1)
+    << "Input layer should have 1 or 3 channels.";
+  input_geometry_ = cv::Size(input_layer->width(), input_layer->height());
+
+  /* Load the binaryproto mean file. */
+  SetMean(mean_file);
+
+  /* Load labels. */
+  std::ifstream labels(label_file.c_str());
+  CHECK(labels) << "Unable to open labels file " << label_file;
+  string line;
+  while (std::getline(labels, line))
+    labels_.push_back(string(line));
+
+  Blob<float>* output_layer = net_->output_blobs()[0];
+  CHECK_EQ(labels_.size(), output_layer->channels())
+    << "Number of labels is different from the output layer dimension.";
+}
+
+static bool PairCompare(const std::pair<float, int>& lhs,
+                        const std::pair<float, int>& rhs) {
+  return lhs.first > rhs.first;
+}
+
+/* Return the indices of the top N values of vector v. */
+static std::vector<int> Argmax(const std::vector<float>& v, int N) {
+  std::vector<std::pair<float, int> > pairs;
+  for (size_t i = 0; i < v.size(); ++i)
+    pairs.push_back(std::make_pair(v[i], i));
+  std::partial_sort(pairs.begin(), pairs.begin() + N, pairs.end(), PairCompare);
+
+  std::vector<int> result;
+  for (int i = 0; i < N; ++i)
+    result.push_back(pairs[i].second);
+  return result;
+}
+
+/* Return the top N predictions. */
+std::vector<Prediction> Classifier::Classify(const cv::Mat& img, int N) {
+  std::vector<float> output = Predict(img);
+
+  N = std::min<int>(labels_.size(), N);
+  std::vector<int> maxN = Argmax(output, N);
+  std::vector<Prediction> predictions;
+  for (int i = 0; i < N; ++i) {
+    int idx = maxN[i];
+    predictions.push_back(std::make_pair(labels_[idx], output[idx]));
+  }
+
+  return predictions;
+}
+
+/* Load the mean file in binaryproto format. */
+void Classifier::SetMean(const string& mean_file) {
+  BlobProto blob_proto;
+  ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
+
+  /* Convert from BlobProto to Blob<float> */
+  Blob<float> mean_blob;
+  mean_blob.FromProto(blob_proto);
+  CHECK_EQ(mean_blob.channels(), num_channels_)
+    << "Number of channels of mean file doesn't match input layer.";
+
+  /* The format of the mean file is planar 32-bit float BGR or grayscale. */
+  std::vector<cv::Mat> channels;
+  float* data = mean_blob.mutable_cpu_data();
+  for (int i = 0; i < num_channels_; ++i) {
+    /* Extract an individual channel. */
+    cv::Mat channel(mean_blob.height(), mean_blob.width(), CV_32FC1, data);
+    channels.push_back(channel);
+    data += mean_blob.height() * mean_blob.width();
+  }
+
+  /* Merge the separate channels into a single image. */
+  cv::Mat mean;
+  cv::merge(channels, mean);
+
+  /* Compute the global mean pixel value and create a mean image
+   * filled with this value. */
+  cv::Scalar channel_mean = cv::mean(mean);
+  mean_ = cv::Mat(input_geometry_, mean.type(), channel_mean);
+}
+
+std::vector<float> Classifier::Predict(const cv::Mat& img) {
+  Blob<float>* input_layer = net_->input_blobs()[0];
+  input_layer->Reshape(1, num_channels_,
+                       input_geometry_.height, input_geometry_.width);
+  /* Forward dimension change to all layers. */
+  net_->Reshape();
+
+  std::vector<cv::Mat> input_channels;
+  WrapInputLayer(&input_channels);
+
+  Preprocess(img, &input_channels);
+
+#ifdef USE_PROFILING
+  unsigned long tstart=get_cur_time();
+#endif //USE_PROFILING
+
+  net_->Forward();
+
+#ifdef USE_PROFILING
+
+  unsigned long tend=get_cur_time();
+
+  std::cout<<"used time: "<<tend-tstart<<std::endl;
+
+#ifdef LAYER_PERF_STAT
+  dump_perf_stat(); 
+#ifdef REPEAT_TEST
+
+   reset_layer_stat();
+
+   vector<vector<perf_stat>* >  all_stat;
+   int rep_number=10;
+
+   for(int i=0;i<rep_number;i++)
+   {
+      net_->Forward();
+      collect_layer_stat(all_stat);
+      reset_layer_stat();
+   }
+
+   //dump stats
+   dump_all_stat(all_stat);
+
+   for(int i=0;i<all_stat.size();i++)
+         delete all_stat[i];
+   
+#endif //REPEAT_TEST
+#endif //LAYER_PERF_STAT
+#endif //USE_PROFILING
+
+  /* Copy the output layer to a std::vector */
+  Blob<float>* output_layer = net_->output_blobs()[0];
+  const float* begin = output_layer->cpu_data();
+  const float* end = begin + output_layer->channels();
+  return std::vector<float>(begin, end);
+}
+
+#ifdef USE_PROFILING
+
+#ifdef LAYER_PERF_STAT
+
+#ifdef REPEAT_TEST
+void Classifier::collect_layer_stat(vector<vector<perf_stat>*>& all_stat)
+{
+   vector<perf_stat > * p_stat;
+   perf_stat * p_time_stat;
+   const vector<shared_ptr<Layer<float> > >& layers=net_->layers();
+
+   
+   p_stat=new vector<perf_stat>;
+
+   for (int i =0;i< layers.size(); i++) {
+        p_time_stat=layers[i]->get_time_stat();
+        p_stat->push_back(*p_time_stat);
+
+   }
+
+   all_stat.push_back(p_stat);
+}
+
+void Classifier::reset_layer_stat(void)
+{
+   const vector<shared_ptr<Layer<float> > >& layers=net_->layers();
+   perf_stat * p_time_stat;
+
+   for (int i =0;i< layers.size(); i++) {
+        p_time_stat=layers[i]->get_time_stat();
+
+        p_time_stat->count=0;
+        p_time_stat->total=0;
+        p_time_stat->used=p_time_stat->start=p_time_stat->end=0;
+   }
+}
+
+void Classifier::dump_all_stat(vector<vector<perf_stat>*>& all_stat)
+{
+
+   struct new_perf_stat {
+        perf_stat stat;
+        int       idx;
+   };
+    
+   vector<new_perf_stat > layer_stat;
+   perf_stat * p_stat;
+
+   uint64_t total_time=0;
+
+   layer_stat.resize(all_stat[0]->size());
+
+   for(int i=0;i<all_stat.size();i++)
+   {
+      for(int j=0;j<layer_stat.size();j++)
+       {
+          p_stat=&layer_stat[j].stat;
+
+          p_stat->total+=(*all_stat[i])[j].total;
+          p_stat->count+=(*all_stat[i])[j].count;
+          total_time+=(*all_stat[i])[j].total;
+       }
+   }
+
+   total_time=total_time/all_stat.size();
+
+   std::cout<<std::endl<<"----------------------------------"<<std::endl;
+   std::cout<<"STATS for "<<all_stat.size()<<" reptitions: ..."<<std::endl;
+   std::cout<<"Total time: "<<total_time<<" per forward"<<std::endl;
+   std::cout<<"Each layer stats: ..."<<std::endl;
+
+
+   for(int i=layer_stat.size()-1;i>=0;i--)
+   {
+      p_stat=&layer_stat[i].stat;
+
+      layer_stat[i].idx=i;
+
+     std::cout<<"  "<<i<<": used time: "<<p_stat->total/all_stat.size();
+     std::cout<<" ratio: "<<((float)p_stat->total)/all_stat.size()/total_time*100;
+     std::cout<<" enter count: "<<p_stat->count/all_stat.size()<<std::endl;
+   }
+
+   std::cout<<std::endl;
+
+   std::cout<<"time cost top 10 layers are: ..."<<std::endl;
+
+   std::sort(layer_stat.begin(),layer_stat.end(),[](const new_perf_stat& a, const new_perf_stat& b)
+       {
+          return a.stat.total > b.stat.total;
+       });
+
+   uint64_t  top_total_time=0;
+
+   for(int i=0; i<10; i++)
+   {
+      p_stat=&layer_stat[i].stat;
+
+     std::cout<<"  "<<layer_stat[i].idx<<": used time: "<<p_stat->total/all_stat.size();
+     std::cout<<" ratio: "<<((float)p_stat->total)/all_stat.size()/total_time*100;
+     std::cout<<" enter count: "<<p_stat->count/all_stat.size()<<std::endl;
+     top_total_time+=p_stat->total;
+   }
+
+   std::cout<<"Top cost layers occupied: "<<(float)top_total_time/all_stat.size()/total_time*100<<std::endl;
+
+   std::cout<<std::endl;
+}
+
+#endif
+
+void Classifier::dump_single_layer_io(int idx, Layer<float> * p_layer)
+{
+   const LayerParameter& layer_param=p_layer->layer_param();
+
+   std::cout<<std::endl<<"LAYER IDX: "<<idx<<" name: "<<layer_param.name();
+   std::cout<<" type: "<<layer_param.type()<<std::endl;
+
+   const vector<Blob<float>*> *p_bottom_vec=p_layer->saved_bottom;
+
+   for(int i=0;i<layer_param.bottom_size(); i++)
+   {
+      std::cout<<"bottom "<<layer_param.bottom(i)<<": ";
+
+      Blob<float> * p_blob=(*p_bottom_vec)[i];
+
+      for(int j=0;j<p_blob->num_axes();j++)
+      {
+          std::cout<<p_blob->shape(j)<<" ";
+      }
+      std::cout<<std::endl;
+   }
+
+   const vector<Blob<float>*> *p_top_vec=p_layer->saved_top;
+   for(int i=0;i<layer_param.top_size(); i++)
+   {
+      std::cout<<"top "<<layer_param.top(i)<<": ";
+      Blob<float> * p_blob=(*p_top_vec)[i];
+
+      for(int j=0;j<p_blob->num_axes();j++)
+      {
+          std::cout<<p_blob->shape(j)<<" ";
+      }
+      std::cout<<std::endl;
+   }
+}
+
+void Classifier::dump_single_layer_perf(int idx, Layer<float> * p_layer, uint64_t total_net_time)
+{
+   const LayerParameter& layer_param=p_layer->layer_param();
+   perf_stat * p_time_stat;
+
+   p_time_stat=p_layer->get_time_stat();
+
+   std::cout<<std::endl<<"LAYER IDX: "<<idx<<" name: "<<layer_param.name();
+   std::cout<<" type: "<<layer_param.type();
+   std::cout<<"  ratio: "<<(float)p_time_stat->total/total_net_time*100<<std::endl;
+
+
+   std::cout<<"time stat:  total: "<<p_time_stat->total<<" count: "<<p_time_stat->count;
+   if(p_time_stat->count)
+    {
+       std::cout<<" average: "<<((float)p_time_stat->total)/p_time_stat->count;
+    }
+
+   std::cout<<" start: "<<p_time_stat->start<<" end: "<<p_time_stat->end;
+   std::cout<<std::endl;
+
+
+} 
+
+void Classifier::dump_perf_stat(void)
+{
+   uint64_t total_net_time=0;
+
+   const vector<shared_ptr<Layer<float> > >& layers=net_->layers();
+
+   std::cout<<"Input/output shape for each layer ... total: "<<layers.size()<<std::endl;
+
+   for (int i = layers.size() - 1; i >= 0; --i) {
+     dump_single_layer_io(i,layers[i].get());
+   }
+
+
+   for (int i = layers.size() - 1; i >= 0; --i) {
+
+     perf_stat * p_time_stat;
+
+     p_time_stat=layers[i]->get_time_stat();
+
+     total_net_time+=p_time_stat->total;
+
+   }
+   
+   std::cout<<"Time for each layer ... sum of all layers is : ";
+   std::cout<<total_net_time<<std::endl;
+
+   for (int i = layers.size() - 1; i >= 0; --i) {
+
+     dump_single_layer_perf(i,layers[i].get(),total_net_time);
+   }
+
+}
+
+#endif
+
+#endif //USE_PROFILING
+
+/* Wrap the input layer of the network in separate cv::Mat objects
+ * (one per channel). This way we save one memcpy operation and we
+ * don't need to rely on cudaMemcpy2D. The last preprocessing
+ * operation will write the separate channels directly to the input
+ * layer. */
+void Classifier::WrapInputLayer(std::vector<cv::Mat>* input_channels) {
+  Blob<float>* input_layer = net_->input_blobs()[0];
+
+  int width = input_layer->width();
+  int height = input_layer->height();
+  float* input_data = input_layer->mutable_cpu_data();
+  for (int i = 0; i < input_layer->channels(); ++i) {
+    cv::Mat channel(height, width, CV_32FC1, input_data);
+    input_channels->push_back(channel);
+    input_data += width * height;
+  }
+}
+
+void Classifier::Preprocess(const cv::Mat& img,
+                            std::vector<cv::Mat>* input_channels) {
+  /* Convert the input image to the input image format of the network. */
+  cv::Mat sample;
+  if (img.channels() == 3 && num_channels_ == 1)
+    cv::cvtColor(img, sample, cv::COLOR_BGR2GRAY);
+  else if (img.channels() == 4 && num_channels_ == 1)
+    cv::cvtColor(img, sample, cv::COLOR_BGRA2GRAY);
+  else if (img.channels() == 4 && num_channels_ == 3)
+    cv::cvtColor(img, sample, cv::COLOR_BGRA2BGR);
+  else if (img.channels() == 1 && num_channels_ == 3)
+    cv::cvtColor(img, sample, cv::COLOR_GRAY2BGR);
+  else
+    sample = img;
+
+  cv::Mat sample_resized;
+  if (sample.size() != input_geometry_)
+    cv::resize(sample, sample_resized, input_geometry_);
+  else
+    sample_resized = sample;
+
+  cv::Mat sample_float;
+  if (num_channels_ == 3)
+    sample_resized.convertTo(sample_float, CV_32FC3);
+  else
+    sample_resized.convertTo(sample_float, CV_32FC1);
+
+  cv::Mat sample_normalized;
+  cv::subtract(sample_float, mean_, sample_normalized);
+
+  /* This operation will write the separate BGR planes directly to the
+   * input layer of the network because it is wrapped by the cv::Mat
+   * objects in input_channels. */
+  cv::split(sample_normalized, *input_channels);
+
+  CHECK(reinterpret_cast<float*>(input_channels->at(0).data)
+        == net_->input_blobs()[0]->cpu_data())
+    << "Input channels are not wrapping the input layer of the network.";
+}
+
+int main(int argc, char** argv) {
+  if (argc != 6) {
+    std::cerr << "Usage: " << argv[0]
+              << " deploy.prototxt network.caffemodel"
+              << " mean.binaryproto labels.txt img.jpg" << std::endl;
+    return 1;
+  }
+
+  ::google::InitGoogleLogging(argv[0]);
+
+  string model_file   = argv[1];
+  string trained_file = argv[2];
+  string mean_file    = argv[3];
+  string label_file   = argv[4];
+  Classifier classifier(model_file, trained_file, mean_file, label_file);
+
+  string file = argv[5];
+
+  std::cout << "---------- Prediction for "
+            << file << " ----------" << std::endl;
+
+  cv::Mat img = cv::imread(file, -1);
+  CHECK(!img.empty()) << "Unable to decode image " << file;
+  std::vector<Prediction> predictions = classifier.Classify(img);
+
+  /* Print the top N predictions. */
+  for (size_t i = 0; i < predictions.size(); ++i) {
+    Prediction p = predictions[i];
+    std::cout << std::fixed << std::setprecision(4) << p.second << " - \""
+              << p.first << "\"" << std::endl;
+  }
+}
+#else
+int main(int argc, char** argv) {
+  LOG(FATAL) << "This example requires OpenCV; compile with USE_OPENCV.";
+}
+#endif  // USE_OPENCV
diff --git a/examples/cpp_classification/classification_profiling_gpu.cpp b/examples/cpp_classification/classification_profiling_gpu.cpp
new file mode 100644 (file)
index 0000000..3c5e04a
--- /dev/null
@@ -0,0 +1,546 @@
+#include <caffe/caffe.hpp>
+#ifdef USE_OPENCV
+#include <opencv2/core/core.hpp>
+#include <opencv2/highgui/highgui.hpp>
+#include <opencv2/imgproc/imgproc.hpp>
+#endif  // USE_OPENCV
+#include <algorithm>
+#include <iosfwd>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#ifdef USE_PROFILING
+
+#include <iostream>
+
+#include <time.h>
+
+#define REPEAT_TEST
+
+unsigned long get_cur_time(void)
+{
+   struct timespec tm;
+
+   clock_gettime(CLOCK_MONOTONIC_COARSE, &tm);
+
+   return (tm.tv_sec*1000+tm.tv_nsec/1000000);
+}
+
+#endif //USE_PROFILING
+
+#ifdef USE_OPENCV
+using namespace caffe;  // NOLINT(build/namespaces)
+using std::string;
+
+/* Pair (label, confidence) representing a prediction. */
+typedef std::pair<string, float> Prediction;
+
+class Classifier {
+ public:
+  Classifier(const string& model_file,
+             const string& trained_file,
+             const string& mean_file,
+             const string& label_file);
+
+  std::vector<Prediction> Classify(const cv::Mat& img, int N = 5);
+
+#ifdef USE_PROFILING
+
+#ifdef LAYER_PERF_STAT
+  void  dump_perf_stat(void);
+  void  dump_single_layer_io(int idx, Layer<float> * p_layer);
+  void  dump_single_layer_perf(int idx, Layer<float> * p_layer,uint64_t total_net_time);
+#ifdef REPEAT_TEST
+  void collect_layer_stat(vector<vector<perf_stat> * > & all_stat);
+  void dump_all_stat(vector <vector<perf_stat>*>& all_stat);
+  void reset_layer_stat();
+#endif
+#endif
+
+#endif //USE_PROFILING
+
+ private:
+  void SetMean(const string& mean_file);
+
+  std::vector<float> Predict(const cv::Mat& img);
+
+  void WrapInputLayer(std::vector<cv::Mat>* input_channels);
+
+  void Preprocess(const cv::Mat& img,
+                  std::vector<cv::Mat>* input_channels);
+
+ private:
+  shared_ptr<Net<float> > net_;
+  cv::Size input_geometry_;
+  int num_channels_;
+  cv::Mat mean_;
+  std::vector<string> labels_;
+};
+
+Classifier::Classifier(const string& model_file,
+                       const string& trained_file,
+                       const string& mean_file,
+                       const string& label_file) {
+//#ifdef CPU_ONLY
+//  Caffe::set_mode(Caffe::CPU);
+//#else
+  Caffe::set_mode(Caffe::GPU); // For the ARM GPU: the build is CPU_ONLY, so GPU mode here routes the ACL layers to the Mali GPU
+//#endif
+
+  /* Load the network. */
+  net_.reset(new Net<float>(model_file, TEST));
+  net_->CopyTrainedLayersFrom(trained_file);
+
+  CHECK_EQ(net_->num_inputs(), 1) << "Network should have exactly one input.";
+  CHECK_EQ(net_->num_outputs(), 1) << "Network should have exactly one output.";
+
+  Blob<float>* input_layer = net_->input_blobs()[0];
+  num_channels_ = input_layer->channels();
+  CHECK(num_channels_ == 3 || num_channels_ == 1)
+    << "Input layer should have 1 or 3 channels.";
+  input_geometry_ = cv::Size(input_layer->width(), input_layer->height());
+
+  /* Load the binaryproto mean file. */
+  SetMean(mean_file);
+
+  /* Load labels. */
+  std::ifstream labels(label_file.c_str());
+  CHECK(labels) << "Unable to open labels file " << label_file;
+  string line;
+  while (std::getline(labels, line))
+    labels_.push_back(string(line));
+
+  Blob<float>* output_layer = net_->output_blobs()[0];
+  CHECK_EQ(labels_.size(), output_layer->channels())
+    << "Number of labels is different from the output layer dimension.";
+}
+
+static bool PairCompare(const std::pair<float, int>& lhs,
+                        const std::pair<float, int>& rhs) {
+  return lhs.first > rhs.first;
+}
+
+/* Return the indices of the top N values of vector v. */
+static std::vector<int> Argmax(const std::vector<float>& v, int N) {
+  std::vector<std::pair<float, int> > pairs;
+  for (size_t i = 0; i < v.size(); ++i)
+    pairs.push_back(std::make_pair(v[i], i));
+  std::partial_sort(pairs.begin(), pairs.begin() + N, pairs.end(), PairCompare);
+
+  std::vector<int> result;
+  for (int i = 0; i < N; ++i)
+    result.push_back(pairs[i].second);
+  return result;
+}
+
+/* Return the top N predictions. */
+std::vector<Prediction> Classifier::Classify(const cv::Mat& img, int N) {
+  std::vector<float> output = Predict(img);
+
+  N = std::min<int>(labels_.size(), N);
+  std::vector<int> maxN = Argmax(output, N);
+  std::vector<Prediction> predictions;
+  for (int i = 0; i < N; ++i) {
+    int idx = maxN[i];
+    predictions.push_back(std::make_pair(labels_[idx], output[idx]));
+  }
+
+  return predictions;
+}
+
+/* Load the mean file in binaryproto format. */
+void Classifier::SetMean(const string& mean_file) {
+  BlobProto blob_proto;
+  ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
+
+  /* Convert from BlobProto to Blob<float> */
+  Blob<float> mean_blob;
+  mean_blob.FromProto(blob_proto);
+  CHECK_EQ(mean_blob.channels(), num_channels_)
+    << "Number of channels of mean file doesn't match input layer.";
+
+  /* The format of the mean file is planar 32-bit float BGR or grayscale. */
+  std::vector<cv::Mat> channels;
+  float* data = mean_blob.mutable_cpu_data();
+  for (int i = 0; i < num_channels_; ++i) {
+    /* Extract an individual channel. */
+    cv::Mat channel(mean_blob.height(), mean_blob.width(), CV_32FC1, data);
+    channels.push_back(channel);
+    data += mean_blob.height() * mean_blob.width();
+  }
+
+  /* Merge the separate channels into a single image. */
+  cv::Mat mean;
+  cv::merge(channels, mean);
+
+  /* Compute the global mean pixel value and create a mean image
+   * filled with this value. */
+  cv::Scalar channel_mean = cv::mean(mean);
+  mean_ = cv::Mat(input_geometry_, mean.type(), channel_mean);
+}
+
+std::vector<float> Classifier::Predict(const cv::Mat& img) {
+  Blob<float>* input_layer = net_->input_blobs()[0];
+  input_layer->Reshape(1, num_channels_,
+                       input_geometry_.height, input_geometry_.width);
+  /* Forward dimension change to all layers. */
+  net_->Reshape();
+
+  std::vector<cv::Mat> input_channels;
+  WrapInputLayer(&input_channels);
+
+  Preprocess(img, &input_channels);
+
+#ifdef USE_PROFILING
+  unsigned long tstart=get_cur_time();
+#endif //USE_PROFILING
+
+  net_->Forward();
+
+#ifdef USE_PROFILING
+
+  unsigned long tend=get_cur_time();
+
+  std::cout<<"used time: "<<tend-tstart<<std::endl;
+
+#ifdef LAYER_PERF_STAT
+  dump_perf_stat(); 
+#ifdef REPEAT_TEST
+
+   reset_layer_stat();
+
+   vector<vector<perf_stat>* >  all_stat;
+   int rep_number=10;
+
+   for(int i=0;i<rep_number;i++)
+   {
+      net_->Forward();
+      collect_layer_stat(all_stat);
+      reset_layer_stat();
+   }
+
+   //dump stats
+   dump_all_stat(all_stat);
+
+   for(int i=0;i<all_stat.size();i++)
+         delete all_stat[i];
+   
+#endif //REPEAT_TEST
+#endif //LAYER_PERF_STAT
+#endif //USE_PROFILING
+
+  /* Copy the output layer to a std::vector */
+  Blob<float>* output_layer = net_->output_blobs()[0];
+  const float* begin = output_layer->cpu_data();
+  const float* end = begin + output_layer->channels();
+  return std::vector<float>(begin, end);
+}
+
+#ifdef USE_PROFILING
+
+#ifdef LAYER_PERF_STAT
+
+#ifdef REPEAT_TEST
+void Classifier::collect_layer_stat(vector<vector<perf_stat>*>& all_stat)
+{
+   vector<perf_stat > * p_stat;
+   perf_stat * p_time_stat;
+   const vector<shared_ptr<Layer<float> > >& layers=net_->layers();
+
+   
+   p_stat=new vector<perf_stat>;
+
+   for (int i =0;i< layers.size(); i++) {
+        p_time_stat=layers[i]->get_time_stat();
+        p_stat->push_back(*p_time_stat);
+
+   }
+
+   all_stat.push_back(p_stat);
+}
+
+void Classifier::reset_layer_stat(void)
+{
+   const vector<shared_ptr<Layer<float> > >& layers=net_->layers();
+   perf_stat * p_time_stat;
+
+   for (int i =0;i< layers.size(); i++) {
+        p_time_stat=layers[i]->get_time_stat();
+
+        p_time_stat->count=0;
+        p_time_stat->total=0;
+        p_time_stat->used=p_time_stat->start=p_time_stat->end=0;
+   }
+}
+
+void Classifier::dump_all_stat(vector<vector<perf_stat>*>& all_stat)
+{
+
+   struct new_perf_stat {
+        perf_stat stat;
+        int       idx;
+   };
+    
+   vector<new_perf_stat > layer_stat;
+   perf_stat * p_stat;
+
+   uint64_t total_time=0;
+
+   layer_stat.resize(all_stat[0]->size());
+
+   for(int i=0;i<all_stat.size();i++)
+   {
+      for(int j=0;j<layer_stat.size();j++)
+       {
+          p_stat=&layer_stat[j].stat;
+
+          p_stat->total+=(*all_stat[i])[j].total;
+          p_stat->count+=(*all_stat[i])[j].count;
+          total_time+=(*all_stat[i])[j].total;
+       }
+   }
+
+   total_time=total_time/all_stat.size();
+
+   std::cout<<std::endl<<"----------------------------------"<<std::endl;
+   std::cout<<"STATS for "<<all_stat.size()<<" reptitions: ..."<<std::endl;
+   std::cout<<"Total time: "<<total_time<<" per forward"<<std::endl;
+   std::cout<<"Each layer stats: ..."<<std::endl;
+
+
+   for(int i=layer_stat.size()-1;i>=0;i--)
+   {
+      p_stat=&layer_stat[i].stat;
+
+      layer_stat[i].idx=i;
+
+     std::cout<<"  "<<i<<": used time: "<<p_stat->total/all_stat.size();
+     std::cout<<" ratio: "<<((float)p_stat->total)/all_stat.size()/total_time*100;
+     std::cout<<" enter count: "<<p_stat->count/all_stat.size()<<std::endl;
+   }
+
+   std::cout<<std::endl;
+
+   std::cout<<"time cost top 10 layers are: ..."<<std::endl;
+
+   std::sort(layer_stat.begin(),layer_stat.end(),[](const new_perf_stat& a, const new_perf_stat& b)
+       {
+          return a.stat.total > b.stat.total;
+       });
+
+   uint64_t  top_total_time=0;
+
+   for(int i=0; i<10; i++)
+   {
+      p_stat=&layer_stat[i].stat;
+
+     std::cout<<"  "<<layer_stat[i].idx<<": used time: "<<p_stat->total/all_stat.size();
+     std::cout<<" ratio: "<<((float)p_stat->total)/all_stat.size()/total_time*100;
+     std::cout<<" enter count: "<<p_stat->count/all_stat.size()<<std::endl;
+     top_total_time+=p_stat->total;
+   }
+
+   std::cout<<"Top cost layers occupied: "<<(float)top_total_time/all_stat.size()/total_time*100<<std::endl;
+
+   std::cout<<std::endl;
+}
+
+#endif
+
+void Classifier::dump_single_layer_io(int idx, Layer<float> * p_layer)
+{
+   const LayerParameter& layer_param=p_layer->layer_param();
+
+   std::cout<<std::endl<<"LAYER IDX: "<<idx<<" name: "<<layer_param.name();
+   std::cout<<" type: "<<layer_param.type()<<std::endl;
+
+   const vector<Blob<float>*> *p_bottom_vec=p_layer->saved_bottom;
+
+   for(int i=0;i<layer_param.bottom_size(); i++)
+   {
+      std::cout<<"bottom "<<layer_param.bottom(i)<<": ";
+
+      Blob<float> * p_blob=(*p_bottom_vec)[i];
+
+      for(int j=0;j<p_blob->num_axes();j++)
+      {
+          std::cout<<p_blob->shape(j)<<" ";
+      }
+      std::cout<<std::endl;
+   }
+
+   const vector<Blob<float>*> *p_top_vec=p_layer->saved_top;
+   for(int i=0;i<layer_param.top_size(); i++)
+   {
+      std::cout<<"top "<<layer_param.top(i)<<": ";
+      Blob<float> * p_blob=(*p_top_vec)[i];
+
+      for(int j=0;j<p_blob->num_axes();j++)
+      {
+          std::cout<<p_blob->shape(j)<<" ";
+      }
+      std::cout<<std::endl;
+   }
+}
+
+void Classifier::dump_single_layer_perf(int idx, Layer<float> * p_layer, uint64_t total_net_time)
+{
+   const LayerParameter& layer_param=p_layer->layer_param();
+   perf_stat * p_time_stat;
+
+   p_time_stat=p_layer->get_time_stat();
+
+   std::cout<<std::endl<<"LAYER IDX: "<<idx<<" name: "<<layer_param.name();
+   std::cout<<" type: "<<layer_param.type();
+   std::cout<<"  ratio: "<<(float)p_time_stat->total/total_net_time*100<<std::endl;
+
+
+   std::cout<<"time stat:  total: "<<p_time_stat->total<<" count: "<<p_time_stat->count;
+   if(p_time_stat->count)
+    {
+       std::cout<<" average: "<<((float)p_time_stat->total)/p_time_stat->count;
+    }
+
+   std::cout<<" start: "<<p_time_stat->start<<" end: "<<p_time_stat->end;
+   std::cout<<std::endl;
+
+
+} 
+
+void Classifier::dump_perf_stat(void)
+{
+   uint64_t total_net_time=0;
+
+   const vector<shared_ptr<Layer<float> > >& layers=net_->layers();
+
+   std::cout<<"Input/output shape for each layer ... total: "<<layers.size()<<std::endl;
+
+   for (int i = layers.size() - 1; i >= 0; --i) {
+     dump_single_layer_io(i,layers[i].get());
+   }
+
+
+   for (int i = layers.size() - 1; i >= 0; --i) {
+
+     perf_stat * p_time_stat;
+
+     p_time_stat=layers[i]->get_time_stat();
+
+     total_net_time+=p_time_stat->total;
+
+   }
+   
+   std::cout<<"Time for each layer ... sum of all layers is : ";
+   std::cout<<total_net_time<<std::endl;
+
+   for (int i = layers.size() - 1; i >= 0; --i) {
+
+     dump_single_layer_perf(i,layers[i].get(),total_net_time);
+   }
+
+}
+
+#endif
+
+#endif //USE_PROFILING
+
+/* Wrap the input layer of the network in separate cv::Mat objects
+ * (one per channel). This way we save one memcpy operation and we
+ * don't need to rely on cudaMemcpy2D. The last preprocessing
+ * operation will write the separate channels directly to the input
+ * layer. */
+void Classifier::WrapInputLayer(std::vector<cv::Mat>* input_channels) {
+  Blob<float>* input_layer = net_->input_blobs()[0];
+
+  int width = input_layer->width();
+  int height = input_layer->height();
+  float* input_data = input_layer->mutable_cpu_data();
+  for (int i = 0; i < input_layer->channels(); ++i) {
+    cv::Mat channel(height, width, CV_32FC1, input_data);
+    input_channels->push_back(channel);
+    input_data += width * height;
+  }
+}
+
+void Classifier::Preprocess(const cv::Mat& img,
+                            std::vector<cv::Mat>* input_channels) {
+  /* Convert the input image to the input image format of the network. */
+  cv::Mat sample;
+  if (img.channels() == 3 && num_channels_ == 1)
+    cv::cvtColor(img, sample, cv::COLOR_BGR2GRAY);
+  else if (img.channels() == 4 && num_channels_ == 1)
+    cv::cvtColor(img, sample, cv::COLOR_BGRA2GRAY);
+  else if (img.channels() == 4 && num_channels_ == 3)
+    cv::cvtColor(img, sample, cv::COLOR_BGRA2BGR);
+  else if (img.channels() == 1 && num_channels_ == 3)
+    cv::cvtColor(img, sample, cv::COLOR_GRAY2BGR);
+  else
+    sample = img;
+
+  cv::Mat sample_resized;
+  if (sample.size() != input_geometry_)
+    cv::resize(sample, sample_resized, input_geometry_);
+  else
+    sample_resized = sample;
+
+  cv::Mat sample_float;
+  if (num_channels_ == 3)
+    sample_resized.convertTo(sample_float, CV_32FC3);
+  else
+    sample_resized.convertTo(sample_float, CV_32FC1);
+
+  cv::Mat sample_normalized;
+  cv::subtract(sample_float, mean_, sample_normalized);
+
+  /* This operation will write the separate BGR planes directly to the
+   * input layer of the network because it is wrapped by the cv::Mat
+   * objects in input_channels. */
+  cv::split(sample_normalized, *input_channels);
+
+  CHECK(reinterpret_cast<float*>(input_channels->at(0).data)
+        == net_->input_blobs()[0]->cpu_data())
+    << "Input channels are not wrapping the input layer of the network.";
+}
+
+int main(int argc, char** argv) {
+  if (argc != 6) {
+    std::cerr << "Usage: " << argv[0]
+              << " deploy.prototxt network.caffemodel"
+              << " mean.binaryproto labels.txt img.jpg" << std::endl;
+    return 1;
+  }
+
+  ::google::InitGoogleLogging(argv[0]);
+
+  string model_file   = argv[1];
+  string trained_file = argv[2];
+  string mean_file    = argv[3];
+  string label_file   = argv[4];
+  Classifier classifier(model_file, trained_file, mean_file, label_file);
+
+  string file = argv[5];
+
+  std::cout << "---------- Prediction for "
+            << file << " ----------" << std::endl;
+
+  cv::Mat img = cv::imread(file, -1);
+  CHECK(!img.empty()) << "Unable to decode image " << file;
+  std::vector<Prediction> predictions = classifier.Classify(img);
+
+  /* Print the top N predictions. */
+  for (size_t i = 0; i < predictions.size(); ++i) {
+    Prediction p = predictions[i];
+    std::cout << std::fixed << std::setprecision(4) << p.second << " - \""
+              << p.first << "\"" << std::endl;
+  }
+}
+#else
+int main(int argc, char** argv) {
+  LOG(FATAL) << "This example requires OpenCV; compile with USE_OPENCV.";
+}
+#endif  // USE_OPENCV
diff --git a/include/caffe/acl_layer.hpp b/include/caffe/acl_layer.hpp
new file mode 100644 (file)
index 0000000..db9fee5
--- /dev/null
@@ -0,0 +1,246 @@
+#ifndef CAFFE_ACL_LAYER_HPP_
+#define CAFFE_ACL_LAYER_HPP_
+
+#ifdef USE_ACL
+#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h"
+#include "arm_compute/runtime/CL/functions/CLConvolutionLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEActivationLayer.h"
+#include "arm_compute/runtime/CL/functions/CLActivationLayer.h"
+#include "arm_compute/runtime/NEON/functions/NENormalizationLayer.h"
+#include "arm_compute/runtime/CL/functions/CLNormalizationLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEPoolingLayer.h"
+#include "arm_compute/runtime/CL/functions/CLPoolingLayer.h"
+#include "arm_compute/runtime/NEON/functions/NESoftmaxLayer.h"
+#include "arm_compute/runtime/CL/functions/CLSoftmaxLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h"
+#include "arm_compute/runtime/CL/functions/CLFullyConnectedLayer.h"
+#include "arm_compute/runtime/CL/CLTensor.h"
+#include "arm_compute/runtime/Tensor.h"
+#include "arm_compute/runtime/CL/CLScheduler.h"
+using namespace arm_compute;
+#define FLAGS_ENABLE_ACL_ABSVAL    0x00000001
+#define FLAGS_ENABLE_ACL_BNLL      0x00000002
+#define FLAGS_ENABLE_ACL_CONV      0x00000004
+#define FLAGS_ENABLE_ACL_FC        0x00000008
+#define FLAGS_ENABLE_ACL_LRN       0x00000010
+#define FLAGS_ENABLE_ACL_POOLING   0x00000020
+#define FLAGS_ENABLE_ACL_RELU      0x00000040
+#define FLAGS_ENABLE_ACL_SIGMOID   0x00000080
+#define FLAGS_ENABLE_ACL_SOFTMAX   0x00000100
+#define FLAGS_ENABLE_ACL_TANH      0x00000200
+extern unsigned int bypass_acl_class_layer;
+#endif
+#ifdef USE_PROFILING
+#include "layer.hpp"
+
+#define MASK_LOG_APP_TIME 0x00000001
+#define MASK_LOG_ALLOCATE 0x00000002
+#define MASK_LOG_RUN      0x00000004
+#define MASK_LOG_CONFIG   0x00000008
+#define MASK_LOG_COPY     0x00000010
+#define MASK_LOG_ABSVAL   0x00000020
+#define MASK_LOG_BNLL     0x00000040
+#define MASK_LOG_CONV     0x00000080
+#define MASK_LOG_FC       0x00000100
+#define MASK_LOG_LRN      0x00000200
+#define MASK_LOG_POOLING  0x00000400
+#define MASK_LOG_RELU     0x00000800
+#define MASK_LOG_SIGMOID  0x00001000
+#define MASK_LOG_SOFTMAX  0x00002000
+#define MASK_LOG_TANH     0x00004000
+#define APP_TIME_INFO     MASK_LOG_APP_TIME,"time:       \t"
+#define ACL_ALLOCATE_INFO MASK_LOG_ALLOCATE,"allocate:   \t\t"
+#define ACL_RUN_INFO      MASK_LOG_RUN,     "run:        \t\t\t"
+#define ACL_CONFIG_INFO   MASK_LOG_CONFIG,  "configure:  \t\t\t\t"
+#define ACL_COPY_INFO     MASK_LOG_COPY,    "tensor_copy:\t\t\t\t\t"
+#define ACL_ABSVAL_INFO   MASK_LOG_ABSVAL,  "ACL_ABSVAL :\t\t\t\t\t\t"
+#define ACL_BNLL_INFO     MASK_LOG_BNLL,    "ACL_BNLL   :\t\t\t\t\t\t\t"
+#define ACL_CONV_INFO     MASK_LOG_CONV,    "ACL_CONV   :\t\t\t\t\t\t\t\t"
+#define ACL_FC_INFO       MASK_LOG_FC,      "ACL_FC     :\t\t\t\t\t\t\t\t\t"
+#define ACL_LRN_INFO      MASK_LOG_LRN,     "ACL_LRN    :\t\t\t\t\t\t\t\t\t\t"
+#define ACL_POOLING_INFO  MASK_LOG_POOLING, "ACL_POOLING:\t\t\t\t\t\t\t\t\t\t\t"
+#define ACL_RELU_INFO     MASK_LOG_RELU,    "ACL_RELU   :\t\t\t\t\t\t\t\t\t\t\t\t"
+#define ACL_SIGMOID_INFO  MASK_LOG_SIGMOID, "ACL_SIGMOID:\t\t\t\t\t\t\t\t\t\t\t\t\t"
+#define ACL_SOFTMAX_INFO  MASK_LOG_SOFTMAX, "ACL_SOFTMAX:\t\t\t\t\t\t\t\t\t\t\t\t\t\t"
+#define ACL_TANH_INFO     MASK_LOG_TANH,    "ACL_TANH   :\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t"
+extern unsigned int acl_log_flags;
+#endif //USE_PROFILING
+namespace caffe {
+#ifdef USE_ACL
+enum TensorType{
+    tensor_input,
+    tensor_output,
+    tensor_weights,
+    tensor_biases,
+};
+template <typename ACLTensor>
+class BaseTensor:public ACLTensor{
+public:
+    BaseTensor(bool share)
+       :share_(share),type_(tensor_input),allocate_(false){
+    }
+    virtual void bindmem(void *mem,bool share){
+        mem_=mem;
+        share_=share;
+    }
+    virtual void settensortype(TensorType type){
+        type_=type;
+    };
+    virtual void map(bool blocking = true){}
+    virtual void unmap(){}
+    virtual void commit();
+    int tensor_copy(void * mem, bool toTensor=true);
+protected:
+    void* mem_;
+    bool share_;
+    TensorType type_;
+    bool allocate_;
+};
+class GPUTensor:public BaseTensor<CLTensor>{
+public:
+    explicit GPUTensor(bool share)
+       :BaseTensor(share){}
+    virtual void map(bool blocking = true){
+        if (!allocate_){
+            CLTensor::allocator()->allocate();
+            allocate_=true;
+        }
+        CLTensor::map(blocking);
+     }
+     virtual void unmap(){
+        CLTensor::unmap();
+     }
+};
+class CPUTensor:public BaseTensor<Tensor>{
+public:
+    explicit CPUTensor(bool share)
+        :BaseTensor(share){}
+    virtual void map(bool blocking = true){
+        if (!allocate_){
+            Tensor::allocator()->allocate();
+            allocate_=true;
+        }
+    }
+    virtual void unmap(){
+    }
+};
+template <typename ACLLayer,typename ACLTensor>
+class ACLXPUBaseLayer{
+public:
+    virtual void commit(){
+        if (input) {
+            input->settensortype(tensor_input);
+            input->commit();
+        }
+        if (output){
+            output->settensortype(tensor_output);
+            output->commit();
+        }
+        if (weights){
+            weights->settensortype(tensor_weights);
+            weights->commit();
+        }
+        if (biases){
+            biases->settensortype(tensor_biases);
+            biases->commit();
+        }
+    }
+    virtual void run(bool gpu){
+        commit();
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_RUN_INFO);
+#endif //USE_PROFILING
+        layer->run();
+        if (gpu) {
+            // Make sure all the OpenCL jobs are done executing:
+            CLScheduler::get().sync();
+        }
+    }
+    virtual bool reshape(TensorShape &shape,TensorType type);
+    explicit ACLXPUBaseLayer(){
+        layer=nullptr;
+        input=nullptr;
+        output=nullptr;
+        weights=nullptr;
+        biases=nullptr;
+#ifdef USE_CONV_CACHE
+        for(int i = 0; i < 16; ++i){
+           cache.layer[i] = nullptr;
+           cache.input[i] = nullptr;
+           cache.output[i] = nullptr;
+           cache.weights[i] = nullptr;
+           cache.biases[i] = nullptr;
+        }
+#endif //USE_CONV_CACHE    
+    }
+    virtual void freelayer(){
+#ifndef USE_CONV_CACHE
+        if (layer) delete layer;
+        if (input) delete input;
+        if (output) delete output;
+        if (weights) delete weights;
+        if (biases) delete biases;
+#endif //USE_CONV_CACHE    
+        layer=nullptr;
+        input=nullptr;
+        output=nullptr;
+        weights=nullptr;
+        biases=nullptr;
+    }
+    virtual ~ACLXPUBaseLayer(){
+        freelayer();
+    }
+    ACLLayer *layer;
+    ACLTensor *input;
+    ACLTensor *output;
+    ACLTensor *weights;
+    ACLTensor *biases;
+#ifdef USE_CONV_CACHE
+    struct{
+        ACLLayer *layer[16];
+        ACLTensor *input[16];
+        ACLTensor *output[16];
+        ACLTensor *weights[16];
+        ACLTensor *biases[16];
+    }cache;
+#endif //USE_CONV_CACHE    
+};
+template <typename GPULayer, typename CPULayer>
+class ACLBaseLayer {
+public:
+    explicit ACLBaseLayer();
+    virtual void gpu_run();
+    virtual void cpu_run();
+    virtual ~ACLBaseLayer();
+    virtual GPULayer * new_gpulayer();
+    virtual CPULayer * new_cpulayer();
+    ACLXPUBaseLayer<GPULayer,GPUTensor>& gpu(){
+        return gpu_;
+    }
+    ACLXPUBaseLayer<CPULayer,CPUTensor>& cpu(){
+        return cpu_;
+    }
+    bool checkreshape(TensorShape shape,bool gpu=false, TensorType type=tensor_input);
+    template <typename ACLTensor> bool tensor_mem(ACLTensor *tensor,void *mem,bool share=false);
+    template <typename ACLTensor> bool tensor_mem(void *mem,ACLTensor *tensor,bool share=false);
+    template <typename ACLTensor> ACLTensor * new_tensor(TensorShape shape,void *mem=nullptr,bool share=false);
+protected:
+    ACLXPUBaseLayer<GPULayer,GPUTensor> gpu_;
+    ACLXPUBaseLayer<CPULayer,CPUTensor> cpu_;
+    bool init_layer_;
+    bool force_bypass_acl_path_;
+
+};
+
+#endif
+}
+#define INSTANTIATE_ACLBASECLASS(GPULayer,CPULayer) \
+  template class ACLBaseLayer<GPULayer,CPULayer>; 
+
+#define INSTANTIATE_ACLBASE_FUNCTION(GPULayer,CPULayer,ACLTensor) \
+    template bool ACLBaseLayer<GPULayer,CPULayer>::tensor_mem<ACLTensor>(ACLTensor *tensor,void *mem,bool share); \
+    template bool ACLBaseLayer<GPULayer,CPULayer>::tensor_mem(void *mem,ACLTensor *tensor,bool share); \
+    template ACLTensor * ACLBaseLayer<GPULayer,CPULayer>::new_tensor(TensorShape shape,void *mem,bool share); \
+
+
+#endif
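Note: every ACL-wrapped layer pair is stamped out through the two macros above. For
example, acl_layer.cpp (later in this commit) instantiates the convolution pair as:

    INSTANTIATE_ACLBASECLASS(CLConvolutionLayer,NEConvolutionLayer);
      INSTANTIATE_ACLBASE_FUNCTION(CLConvolutionLayer,NEConvolutionLayer,GPUTensor);
      INSTANTIATE_ACLBASE_FUNCTION(CLConvolutionLayer,NEConvolutionLayer,CPUTensor);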
diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp
index 0688209..6ea2bb5 100644 (file)
@@ -4,6 +4,12 @@
 #ifndef CAFFE_CAFFE_HPP_
 #define CAFFE_CAFFE_HPP_
 
+#ifdef USE_ACL
+#ifndef CPU_ONLY
+#define CPU_ONLY
+#endif
+#endif
+
 #include "caffe/blob.hpp"
 #include "caffe/common.hpp"
 #include "caffe/filler.hpp"
diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp
index 4904d1d..34f2b8c 100644 (file)
@@ -145,7 +145,13 @@ class Caffe {
   // into the program since that may cause allocation of pinned memory being
   // freed in a non-pinned way, which may cause problems - I haven't verified
   // it personally but better to note it here in the header file.
-  inline static void set_mode(Brew mode) { Get().mode_ = mode; }
+#ifdef USE_ACL
+  inline static bool arm_gpu_mode() { return Get().use_mali_gpu_; }
+  inline static void set_mode(Brew mode) { Get().mode_ = CPU; set_arm_gpu_mode(mode == GPU); }
+  inline static void set_arm_gpu_mode(bool use_mali_gpu) { Get().use_mali_gpu_ = use_mali_gpu; }
+#else
+  inline static void set_mode(Brew mode) { Get().mode_ = mode; }
+#endif
   // Sets the random seed of both boost and curand
   static void set_random_seed(const unsigned int seed);
   // Sets the device. Since we have cublas and curand stuff, set device also
@@ -175,7 +181,9 @@ class Caffe {
   shared_ptr<RNG> random_generator_;
 
   Brew mode_;
-
+#ifdef USE_ACL  
+  bool use_mali_gpu_;
+#endif
   // Parallel training
   int solver_count_;
   int solver_rank_;
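Note: with USE_ACL the Brew values keep their names but change meaning: GPU selects
the Mali GPU through OpenCL, while mode_ itself is pinned to CPU so that stock Caffe
code keeps taking its CPU paths. A minimal sketch of the resulting semantics:

    Caffe::set_mode(Caffe::GPU);          // stores mode_ = CPU and use_mali_gpu_ = true
    CHECK_EQ(Caffe::mode(), Caffe::CPU);  // legacy mode checks still see CPU
    CHECK(Caffe::arm_gpu_mode());         // ACL layers branch on this flag instead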
diff --git a/include/caffe/layer.hpp b/include/caffe/layer.hpp
index 30dbfd5..49b1e69 100644 (file)
 #include "caffe/layer_factory.hpp"
 #include "caffe/proto/caffe.pb.h"
 #include "caffe/util/math_functions.hpp"
-
+#ifdef USE_PROFILING
+#include <sys/time.h>
+#define NANO_SEC_CONV 1000000  /* microseconds per second; timings below are computed in us */
+extern unsigned int acl_log_flags;
+#endif //USE_PROFILING
 /**
  Forward declare boost::thread instead of including boost/thread.hpp
 to avoid boost/NVCC issues (#1009, #1010) on OSX.
 */
 namespace boost { class mutex; }
 
 namespace caffe {
+#ifdef USE_PROFILING
+class logtime_util
+{
+  public:
+    logtime_util(int mask_, const char* information_){
+      mask = mask_;
+      if(acl_log_flags & mask){
+        strncpy(information, information_, 255);
+        information[255] = '\0';  /* strncpy leaves no terminator when the source fills the buffer */
+        gettimeofday(&tv[0], NULL);
+      }
+    }
+    ~logtime_util(){
+      if(acl_log_flags & mask){
+        long time[2];
+        gettimeofday(&tv[1], NULL);
+        time[0] = tv[0].tv_sec * NANO_SEC_CONV + tv[0].tv_usec;
+        time[1] = tv[1].tv_sec * NANO_SEC_CONV + tv[1].tv_usec;
+        printf("%s %.6lf\n", information, (((double)time[1] - time[0]) / NANO_SEC_CONV));
+      }
+    }
+    void log_time(bool start)
+    {
+      if(acl_log_flags & mask){
+        if (start){
+          gettimeofday(&tv[0], NULL);
+        }
+        else{
+          long time[2];
+          gettimeofday(&tv[1], NULL);
+          time[0] = tv[0].tv_sec * NANO_SEC_CONV + tv[0].tv_usec;
+          time[1] = tv[1].tv_sec * NANO_SEC_CONV + tv[1].tv_usec;
+          printf("%s %.6lf\n", information, (((double)time[1] - time[0]) / NANO_SEC_CONV));
+        }
+      }
+    }
+private:
+  struct timeval tv[2];
+  int mask;
+  char information[256];
+};
+
+#ifdef LAYER_PERF_STAT
+
+struct perf_stat {
+  uint64_t total;
+  uint32_t start;
+  uint32_t end;
+  uint32_t used;
+  uint32_t count;
+
+  perf_stat() : total(0), start(0), end(0), used(0), count(0) {}
+};
+
+
+#endif
+#endif //USE_PROFILING
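Note: logtime_util is an RAII scope timer: the constructor records the start time and
the destructor prints the label and the elapsed seconds, both gated on acl_log_flags.
The ACL_*_INFO arguments used throughout this commit are assumed here to expand to a
matching (mask, label) argument pair. Typical use, as in ACLXPUBaseLayer::run() above:

    {
        logtime_util log_time(ACL_RUN_INFO);  // arms the timer only if the RUN bit is set
        layer->run();                         // the timed region
    }                                         // destructor prints "<label> <seconds>"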
 
 /**
  * @brief An interface for the units of computation which can be composed into a
@@ -123,8 +185,13 @@ class Layer {
    *
    * Your layer should implement Forward_cpu and (optionally) Forward_gpu.
    */
+#ifdef USE_PROFILING
+   Dtype Forward(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+#else
   inline Dtype Forward(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
+#endif //USE_PROFILING
 
   /**
    * @brief Given the top blob error gradients, compute the bottom blob error
@@ -290,7 +357,19 @@ class Layer {
     }
     param_propagate_down_[param_id] = value;
   }
+  
+#ifdef USE_PROFILING
 
+#ifdef LAYER_PERF_STAT
+
+   const vector<Blob<Dtype>*> * saved_top;
+   const vector<Blob<Dtype>*> * saved_bottom;
+   perf_stat * get_time_stat(void) { return &time_stat_;}
+   perf_stat * get_pmu_stat(int index) { return &pmu_stat_[index];}
+
+#endif
+
+#endif //USE_PROFILING
 
  protected:
   /** The protobuf that stores the layer parameters */
@@ -404,8 +483,17 @@ class Layer {
 
  private:
   DISABLE_COPY_AND_ASSIGN(Layer);
+
+#ifdef USE_PROFILING
+#ifdef LAYER_PERF_STAT
+  perf_stat time_stat_;
+  perf_stat pmu_stat_[16];
+#endif
+#endif //USE_PROFILING
 };  // class Layer
 
+
+#ifndef LAYER_PERF_STAT
 // Forward and backward wrappers. You should implement the cpu and
 // gpu specific implementations instead, and should not change these
 // functions.
@@ -445,6 +533,8 @@ inline Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
   return loss;
 }
 
+#endif
+
 template <typename Dtype>
 inline void Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down,
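Note: with LAYER_PERF_STAT enabled, every layer accumulates wall-clock statistics in
time_stat_ (filled in by the Forward wrapper in src/caffe/layer.cpp below). A sketch of
a harness that dumps per-layer averages after a run, assuming a populated Net "net":

    for (size_t i = 0; i < net.layers().size(); ++i) {
      perf_stat* ts = net.layers()[i]->get_time_stat();
      if (ts->count)
        printf("%-16s calls=%u avg=%llu us\n", net.layers()[i]->type(), ts->count,
               (unsigned long long)(ts->total / ts->count));
    }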
diff --git a/include/caffe/layers/acl_absval_layer.hpp b/include/caffe/layers/acl_absval_layer.hpp
new file mode 100644 (file)
index 0000000..c165540
--- /dev/null
@@ -0,0 +1,57 @@
+#ifndef CAFFE_ACL_ABSVAL_LAYER_HPP_
+#define CAFFE_ACL_ABSVAL_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/neuron_layer.hpp"
+#include "caffe/layers/absval_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#include "caffe/layers/acl_base_activation_layer.hpp"
+#endif
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief ACL acceleration of AbsValLayer.
+ *        Falls back to AbsValLayer for some corner cases.
+ */
+template <typename Dtype>
+class ACLAbsValLayer : public ACLBaseActivationLayer<Dtype>,public AbsValLayer<Dtype> {
+ public:
+  explicit ACLAbsValLayer(const LayerParameter& param)
+      : ACLBaseActivationLayer<Dtype>(param),AbsValLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLAbsValLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top, ActivationLayerInfo::ActivationFunction type);
+
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_ABSVAL_LAYER_HPP_
diff --git a/include/caffe/layers/acl_base_activation_layer.hpp b/include/caffe/layers/acl_base_activation_layer.hpp
new file mode 100644 (file)
index 0000000..e2abdaf
--- /dev/null
@@ -0,0 +1,56 @@
+#ifndef CAFFE_ACL_BASE_ACTIVATION_LAYER_HPP_
+#define CAFFE_ACL_BASE_ACTIVATION_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/neuron_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#include "caffe/layers/acl_base_activation_layer.hpp"
+#endif
+
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief Base class for ACL-accelerated activation layers.
+ *        Concrete subclasses fall back to their Caffe counterparts for corner cases.
+ */
+template <typename Dtype>
+class ACLBaseActivationLayer : public ACLBaseLayer<CLActivationLayer,NEActivationLayer> {
+ public:
+  explicit ACLBaseActivationLayer(const LayerParameter& param)
+      {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLBaseActivationLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top,ActivationLayerInfo::ActivationFunction type=ActivationLayerInfo::ActivationFunction::RELU);
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_BASE_ACTIVATION_LAYER_HPP_
diff --git a/include/caffe/layers/acl_bnll_layer.hpp b/include/caffe/layers/acl_bnll_layer.hpp
new file mode 100644 (file)
index 0000000..ea2f8a1
--- /dev/null
@@ -0,0 +1,57 @@
+#ifndef CAFFE_ACL_BNLL_LAYER_HPP_
+#define CAFFE_ACL_BNLL_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/neuron_layer.hpp"
+#include "caffe/layers/bnll_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#include "caffe/layers/acl_base_activation_layer.hpp"
+#endif
+
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief ACL acceleration of BNLLLayer.
+ *        Falls back to BNLLLayer for some corner cases.
+ */
+template <typename Dtype>
+class ACLBNLLLayer : public ACLBaseActivationLayer<Dtype>,public BNLLLayer<Dtype> {
+ public:
+  explicit ACLBNLLLayer(const LayerParameter& param)
+      : ACLBaseActivationLayer<Dtype>(param),BNLLLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLBNLLLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top, ActivationLayerInfo::ActivationFunction type);
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_BNLL_LAYER_HPP_
diff --git a/include/caffe/layers/acl_conv_layer.hpp b/include/caffe/layers/acl_conv_layer.hpp
new file mode 100644 (file)
index 0000000..b4a7584
--- /dev/null
@@ -0,0 +1,56 @@
+#ifndef CAFFE_ACL_CONV_LAYER_HPP_
+#define CAFFE_ACL_CONV_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/conv_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#endif
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief ACL implementation of ConvolutionLayer.
+ *        Falls back to ConvolutionLayer for some corner cases.
+ */
+template <typename Dtype>
+class ACLConvolutionLayer : public ACLBaseLayer<CLConvolutionLayer,NEConvolutionLayer>,public ConvolutionLayer<Dtype> {
+ public:
+  explicit ACLConvolutionLayer(const LayerParameter& param)
+      : ConvolutionLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLConvolutionLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_CONV_LAYER_HPP_
diff --git a/include/caffe/layers/acl_inner_product_layer.hpp b/include/caffe/layers/acl_inner_product_layer.hpp
new file mode 100644 (file)
index 0000000..f42becb
--- /dev/null
@@ -0,0 +1,54 @@
+#ifndef CAFFE_ACL_INNER_PRODUCT_LAYER_HPP_
+#define CAFFE_ACL_INNER_PRODUCT_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/inner_product_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#endif
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief ACL acceleration of InnerProductLayer.
+ *        Falls back to InnerProductLayer for some corner cases.
+ */
+template <typename Dtype>
+class ACLInnerProductLayer : public ACLBaseLayer<CLFullyConnectedLayer,NEFullyConnectedLayer>,public InnerProductLayer<Dtype> {
+ public:
+  explicit ACLInnerProductLayer(const LayerParameter& param)
+      : InnerProductLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLInnerProductLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_INNER_PRODUCT_LAYER_HPP_
diff --git a/include/caffe/layers/acl_lrn_layer.hpp b/include/caffe/layers/acl_lrn_layer.hpp
new file mode 100644 (file)
index 0000000..6fd9fbc
--- /dev/null
@@ -0,0 +1,54 @@
+#ifndef CAFFE_ACL_LRN_LAYER_HPP_
+#define CAFFE_ACL_LRN_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/lrn_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#endif
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief ACL implementation of LRNLayer.
+ *        Falls back to LRNLayer for some corner cases.
+ */
+template <typename Dtype>
+class ACLLRNLayer : public ACLBaseLayer<CLNormalizationLayer,NENormalizationLayer>,public LRNLayer<Dtype> {
+ public:
+  explicit ACLLRNLayer(const LayerParameter& param)
+      : LRNLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLLRNLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_LRN_LAYER_HPP_
diff --git a/include/caffe/layers/acl_pooling_layer.hpp b/include/caffe/layers/acl_pooling_layer.hpp
new file mode 100644 (file)
index 0000000..acca35c
--- /dev/null
@@ -0,0 +1,54 @@
+#ifndef CAFFE_ACL_POOLING_LAYER_HPP_
+#define CAFFE_ACL_POOLING_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/pooling_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#endif
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief ACL implementation of PoolingLayer.
+ *        Falls back to PoolingLayer for some corner cases.
+ */
+template <typename Dtype>
+class ACLPoolingLayer : public ACLBaseLayer<CLPoolingLayer,NEPoolingLayer>,public PoolingLayer<Dtype> {
+ public:
+  explicit ACLPoolingLayer(const LayerParameter& param)
+      : PoolingLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLPoolingLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_POOLING_LAYER_HPP_
diff --git a/include/caffe/layers/acl_relu_layer.hpp b/include/caffe/layers/acl_relu_layer.hpp
new file mode 100644 (file)
index 0000000..041dbec
--- /dev/null
@@ -0,0 +1,56 @@
+#ifndef CAFFE_ACL_RELU_LAYER_HPP_
+#define CAFFE_ACL_RELU_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/neuron_layer.hpp"
+#include "caffe/layers/relu_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#include "caffe/layers/acl_base_activation_layer.hpp"
+#endif
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief ACL acceleration of ReLULayer.
+ *        Falls back to ReLULayer for some corner cases.
+ */
+template <typename Dtype>
+class ACLReLULayer : public ACLBaseActivationLayer<Dtype>,public ReLULayer<Dtype> {
+ public:
+  explicit ACLReLULayer(const LayerParameter& param)
+      : ACLBaseActivationLayer<Dtype>(param), ReLULayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLReLULayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_RELU_LAYER_HPP_
diff --git a/include/caffe/layers/acl_sigmoid_layer.hpp b/include/caffe/layers/acl_sigmoid_layer.hpp
new file mode 100644 (file)
index 0000000..8638f73
--- /dev/null
@@ -0,0 +1,55 @@
+#ifndef CAFFE_ACL_SIGMOID_LAYER_HPP_
+#define CAFFE_ACL_SIGMOID_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/neuron_layer.hpp"
+#include "caffe/layers/sigmoid_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#include "caffe/layers/acl_base_activation_layer.hpp"
+#endif
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief ACL acceleration of SigmoidLayer.
+ */
+template <typename Dtype>
+class ACLSigmoidLayer : public ACLBaseActivationLayer<Dtype>,public SigmoidLayer<Dtype> {
+ public:
+  explicit ACLSigmoidLayer(const LayerParameter& param)
+      : ACLBaseActivationLayer<Dtype>(param),SigmoidLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLSigmoidLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top, ActivationLayerInfo::ActivationFunction type);
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_SIGMOID_LAYER_HPP_
diff --git a/include/caffe/layers/acl_softmax_layer.hpp b/include/caffe/layers/acl_softmax_layer.hpp
new file mode 100644 (file)
index 0000000..9e450f5
--- /dev/null
@@ -0,0 +1,54 @@
+#ifndef CAFFE_ACL_SOFTMAX_LAYER_HPP_
+#define CAFFE_ACL_SOFTMAX_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/softmax_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#endif
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief ACL implementation of SoftmaxLayer.
+ *        Falls back to SoftmaxLayer for some corner cases.
+ */
+template <typename Dtype>
+class ACLSoftmaxLayer : public ACLBaseLayer<CLSoftmaxLayer,NESoftmaxLayer>,public SoftmaxLayer<Dtype> {
+ public:
+  explicit ACLSoftmaxLayer(const LayerParameter& param)
+      : SoftmaxLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLSoftmaxLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+     const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+     const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_SOFTMAX_LAYER_HPP_
diff --git a/include/caffe/layers/acl_tanh_layer.hpp b/include/caffe/layers/acl_tanh_layer.hpp
new file mode 100644 (file)
index 0000000..5a74ce5
--- /dev/null
@@ -0,0 +1,56 @@
+#ifndef CAFFE_ACL_TANH_LAYER_HPP_
+#define CAFFE_ACL_TANH_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/neuron_layer.hpp"
+#include "caffe/layers/tanh_layer.hpp"
+
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+#include "caffe/layers/acl_base_activation_layer.hpp"
+#endif
+
+namespace caffe {
+
+#ifdef USE_ACL
+/**
+ * @brief ACL acceleration of TanHLayer.
+ *        Falls back to TanHLayer for some corner cases.
+ */
+template <typename Dtype>
+class ACLTanHLayer : public ACLBaseActivationLayer<Dtype>,public TanHLayer<Dtype> {
+ public:
+  explicit ACLTanHLayer(const LayerParameter& param)
+      : ACLBaseActivationLayer<Dtype>(param),TanHLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~ACLTanHLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom){
+                 NOT_IMPLEMENTED;
+      }
+  virtual void SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top, ActivationLayerInfo::ActivationFunction type);
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_ACL_TANH_LAYER_HPP_
diff --git a/include/caffe/util/device_alternate.hpp b/include/caffe/util/device_alternate.hpp
index e3fe4fe..e7ab4b7 100644 (file)
@@ -7,7 +7,11 @@
 
 // Stub out GPU calls as unavailable.
 
+#ifdef USE_ACL
+#define NO_GPU  /* with USE_ACL the GPU stubs become no-ops instead of fatal errors */
+#else
 #define NO_GPU LOG(FATAL) << "Cannot use GPU in CPU-only Caffe: check mode."
+#endif
 
 #define STUB_GPU(classname) \
 template <typename Dtype> \
diff --git a/include/caffe/util/hdf5.hpp b/include/caffe/util/hdf5.hpp
index 71549c1..dbd8bb6 100644 (file)
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
 #ifndef CAFFE_UTIL_HDF5_H_
 #define CAFFE_UTIL_HDF5_H_
 
@@ -37,3 +38,4 @@ string hdf5_get_name_by_idx(hid_t loc_id, int idx);
 }  // namespace caffe
 
 #endif   // CAFFE_UTIL_HDF5_H_
+#endif   // USE_HDF5
diff --git a/src/caffe/acl_layer.cpp b/src/caffe/acl_layer.cpp
new file mode 100644 (file)
index 0000000..4c20037
--- /dev/null
@@ -0,0 +1,274 @@
+#ifdef USE_ACL
+#include "caffe/acl_layer.hpp"
+
+unsigned int bypass_acl_class_layer =    (0 | \
+                                          /*0xffffffff |*/ \
+                                          /*FLAGS_ENABLE_ACL_FC |*/ \
+                                          /*FLAGS_ENABLE_ACL_LRN |*/ \
+                                          0 );
+
+#ifdef USE_PROFILING
+
+#include "arm_neon.h"
+
+unsigned int acl_log_flags = (0 | \
+                              MASK_LOG_APP_TIME | \
+                            /*MASK_LOG_ALLOCATE | */\
+                            /*MASK_LOG_RUN      | */\
+                            /*MASK_LOG_CONFIG   | */\
+                            /*MASK_LOG_COPY     | */\
+                              MASK_LOG_ABSVAL   | \
+                              MASK_LOG_BNLL     | \
+                              MASK_LOG_CONV     | \
+                              MASK_LOG_FC       | \
+                              MASK_LOG_LRN      | \
+                              MASK_LOG_POOLING  | \
+                              MASK_LOG_RELU     | \
+                              MASK_LOG_SIGMOID  | \
+                              MASK_LOG_SOFTMAX  | \
+                              MASK_LOG_TANH     | \
+                              0);                                          
+#include <stdio.h>      /* printf */
+#include <stdlib.h>     /* getenv */
+#endif //USE_PROFILING
+
+namespace caffe {
+template <typename GPULayer, typename CPULayer>
+ACLBaseLayer<GPULayer,CPULayer>::ACLBaseLayer()
+    :init_layer_(true),force_bypass_acl_path_(false){
+  const char* pBypassACL;
+  pBypassACL = getenv ("BYPASSACL");
+  if (pBypassACL){
+    unsigned int bacl;
+    sscanf(pBypassACL,"%i", &bacl);
+    if (bacl != bypass_acl_class_layer){
+        bypass_acl_class_layer = bacl;
+        printf("BYPASSACL<%s>\n", pBypassACL);
+        printf("BYPASSACL: %x\n", bypass_acl_class_layer);
+    }
+  }
+#ifdef USE_PROFILING
+  const char* pLogACL;
+  pLogACL    = getenv("LOGACL");
+  if (pLogACL){
+    unsigned int alf;
+    sscanf(pLogACL,"%i", &alf);
+    if (alf != acl_log_flags){
+        acl_log_flags = alf;
+        printf("LOGACL<%s>\n", pLogACL);
+        printf("LOGACL: %x\n", acl_log_flags);
+    }
+  }
+#endif //USE_PROFILING
+}
+template <typename GPULayer, typename CPULayer>
+void ACLBaseLayer<GPULayer,CPULayer>::gpu_run() {
+    gpu_.run(true);
+}
+template <typename GPULayer, typename CPULayer>
+void ACLBaseLayer<GPULayer,CPULayer>::cpu_run() {
+    cpu_.run(false);
+}
+
+template <typename GPULayer, typename CPULayer>
+ACLBaseLayer<GPULayer,CPULayer>::~ACLBaseLayer(){
+}
+template <typename GPULayer, typename CPULayer>
+template <typename ACLTensor> ACLTensor * ACLBaseLayer<GPULayer,CPULayer>::new_tensor(TensorShape shape,void *mem,bool share)
+{
+    ACLTensor * tensor=new ACLTensor(share);
+#if 1    //F32
+    tensor->allocator()->init(TensorInfo(shape, Format::F32));
+#else  //F16
+    tensor->allocator()->init(TensorInfo(shape, Format::F16));
+#endif    
+    tensor->bindmem(mem,share);
+    return tensor;
+}
+
+template <typename ACLTensor>
+void BaseTensor<ACLTensor>::commit(){
+    if (!share_&&mem_) {
+        if (!allocate_){ 
+#ifdef USE_PROFILING
+            logtime_util log_time(ACL_ALLOCATE_INFO);
+#endif //USE_PROFILING
+            ACLTensor::allocator()->allocate(); 
+            allocate_=true;
+        }
+        if (type_!= tensor_output) {
+           tensor_copy(mem_);
+        }
+        mem_=nullptr;
+    }
+}
+
+template <typename ACLTensor>
+int BaseTensor<ACLTensor>::tensor_copy(void * mem,bool toTensor)
+{
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_COPY_INFO);
+#endif //USE_PROFILING
+    arm_compute::Window window;
+    ACLTensor* tensor=this;
+    window.use_tensor_dimensions(tensor->info(), /* first_dimension = */ Window::DimY); // iterate row by row, not element by element
+    int width  = tensor->info()->tensor_shape()[0];
+    int height = tensor->info()->tensor_shape()[1];
+    int deepth = tensor->info()->tensor_shape()[2];
+    map();
+    // Create an iterator:
+    arm_compute::Iterator it(tensor, window);
+    // Except it works for an arbitrary number of dimensions
+    if (toTensor) { //mem->tensor
+        arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates & id)
+        {
+#if 0 //F16
+            if (tensor->info()->element_size() ==2)
+            {
+                for(int i = 0; i < width; i+= 4){
+                    auto pa = (float32x4_t*)((char*)mem) + ((id[3] * (width * height * deepth) + id.z() * (width * height) + id.y() * width + id.x() + i) * 4);
+                    *(float16x4_t*)(((char*)it.ptr()) + i*2) = vcvt_f16_f32(*pa);
+                }
+            }
+            else{
+#endif
+                memcpy(it.ptr(), ((char*)mem) + ((id[3] * (width * height * deepth) + id.z() * (width * height) + id.y() * width + id.x()) * tensor->info()->element_size()), width * tensor->info()->element_size());
+#if 0 //F16
+            }
+#endif
+        },
+        it);
+    }else{ //tensor-->mem
+        arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates & id)
+        {
+#if 0 //F16            
+            if (tensor->info()->element_size() ==2)
+            {
+                for(int i = 0; i < width; i+= 4){
+                    auto pa = (float32x4_t*)(((char*)mem) + ((id[3] * (width * height * deepth) + id.z() * (width * height) + id.y() * width + id.x() + i) * 4));
+                    *pa = vcvt_f32_f16(*(float16x4_t*)(((char*)it.ptr()) + i*2));
+                }
+            }
+            else{
+#endif                 
+                memcpy(((char*)mem) + ((id[3] * (width * height * deepth) + id.z() * (width * height) + id.y() * width) * tensor->info()->element_size()), it.ptr(), width * tensor->info()->element_size());
+#if 0 //F16                            
+            }
+#endif                 
+        },
+        it);
+    }
+    unmap();
+
+    return 0;
+}
+
+template <typename GPULayer, typename CPULayer>
+template <typename ACLTensor> bool  ACLBaseLayer<GPULayer,CPULayer>::tensor_mem(ACLTensor *tensor,void *mem,bool share)
+{
+    tensor->bindmem(mem,share);
+    return true;
+}
+
+template <typename GPULayer, typename CPULayer>
+template <typename ACLTensor> bool  ACLBaseLayer<GPULayer,CPULayer>::tensor_mem(void *mem,ACLTensor *tensor,bool share)
+{
+    if (mem==tensor->buffer()) return true;
+    if (!share) {
+     tensor->tensor_copy(mem,false);
+    }
+    return true;
+}
+
+
+template <typename GPULayer, typename CPULayer>
+bool ACLBaseLayer<GPULayer,CPULayer>::checkreshape(TensorShape shape,bool gpu, TensorType type)
+{
+    if (gpu) {
+        init_layer_ = gpu_.reshape(shape,type);
+    }else{
+        init_layer_ = cpu_.reshape(shape,type);
+    }
+    return init_layer_;
+}
+
+template <typename GPULayer, typename CPULayer>
+GPULayer * ACLBaseLayer<GPULayer,CPULayer>::new_gpulayer(){
+        gpu_.layer= new GPULayer;
+        return gpu_.layer;
+}
+template <typename GPULayer, typename CPULayer>
+CPULayer * ACLBaseLayer<GPULayer,CPULayer>::new_cpulayer(){
+        cpu_.layer= new CPULayer;
+        return cpu_.layer;
+}
+template <typename ACLLayer,typename ACLTensor>
+bool ACLXPUBaseLayer<ACLLayer,ACLTensor>::reshape(TensorShape &shape,TensorType type)
+{
+    TensorShape _shape;
+    if (!layer) return true;
+#ifdef USE_CONV_CACHE
+    if (tensor_input == type){
+        _shape = input->info()->tensor_shape();
+        if (_shape.total_size()==shape.total_size() && _shape[0]==shape[0] && _shape[1]==shape[1]) {
+            return false;
+        }
+        for(int i = 0; i < 16; ++i){
+            if(cache.input[i] == nullptr) break;
+            _shape = cache.input[i]->info()->tensor_shape();
+            if (_shape.total_size()==shape.total_size() && _shape[0]==shape[0] && _shape[1]==shape[1]) {
+                this->layer = cache.layer[i];
+                this->input = cache.input[i];
+                this->output = cache.output[i];
+                this->weights = cache.weights[i];
+                this->biases = cache.biases[i]; 
+                return false;
+            }
+        }
+    }
+#endif //USE_CONV_CACHE    
+    switch (type) {
+    case tensor_biases:
+        _shape = biases->info()->tensor_shape();
+        break;
+    case tensor_weights:
+        _shape = weights->info()->tensor_shape();
+        break;
+    case tensor_output:
+        _shape = output->info()->tensor_shape();
+        break;
+    case tensor_input:
+    default:
+        _shape = input->info()->tensor_shape();
+        break;
+    }
+    if (_shape.total_size()==shape.total_size() && _shape[0]==shape[0] && _shape[1]==shape[1]) {
+        return false;
+    }
+    freelayer();
+    return true;
+}
+
+INSTANTIATE_ACLBASECLASS(CLNormalizationLayer,NENormalizationLayer); 
+  INSTANTIATE_ACLBASE_FUNCTION(CLNormalizationLayer,NENormalizationLayer,GPUTensor);
+  INSTANTIATE_ACLBASE_FUNCTION(CLNormalizationLayer,NENormalizationLayer,CPUTensor);
+INSTANTIATE_ACLBASECLASS(CLActivationLayer,NEActivationLayer); 
+  INSTANTIATE_ACLBASE_FUNCTION(CLActivationLayer,NEActivationLayer,GPUTensor);
+  INSTANTIATE_ACLBASE_FUNCTION(CLActivationLayer,NEActivationLayer,CPUTensor);
+INSTANTIATE_ACLBASECLASS(CLPoolingLayer,NEPoolingLayer); 
+  INSTANTIATE_ACLBASE_FUNCTION(CLPoolingLayer,NEPoolingLayer,GPUTensor);
+  INSTANTIATE_ACLBASE_FUNCTION(CLPoolingLayer,NEPoolingLayer,CPUTensor);
+INSTANTIATE_ACLBASECLASS(CLSoftmaxLayer,NESoftmaxLayer); 
+  INSTANTIATE_ACLBASE_FUNCTION(CLSoftmaxLayer,NESoftmaxLayer,GPUTensor);
+  INSTANTIATE_ACLBASE_FUNCTION(CLSoftmaxLayer,NESoftmaxLayer,CPUTensor);
+INSTANTIATE_ACLBASECLASS(CLFullyConnectedLayer,NEFullyConnectedLayer); 
+  INSTANTIATE_ACLBASE_FUNCTION(CLFullyConnectedLayer,NEFullyConnectedLayer,GPUTensor);
+  INSTANTIATE_ACLBASE_FUNCTION(CLFullyConnectedLayer,NEFullyConnectedLayer,CPUTensor);
+INSTANTIATE_ACLBASECLASS(CLConvolutionLayer,NEConvolutionLayer); 
+  INSTANTIATE_ACLBASE_FUNCTION(CLConvolutionLayer,NEConvolutionLayer,GPUTensor);
+  INSTANTIATE_ACLBASE_FUNCTION(CLConvolutionLayer,NEConvolutionLayer,CPUTensor);
+
+}
+
+#endif
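Note: both knobs are plain environment variables parsed with sscanf("%i"), so decimal,
octal, and 0x-prefixed hex all work. They must be set before the first ACLBaseLayer is
constructed, since its constructor is where they are read; a sketch:

    #include <cstdlib>
    setenv("BYPASSACL", "0xffffffff", 1);  // route every wrapped layer back to the Caffe path
    setenv("LOGACL",    "0xffffffff", 1);  // enable all profiling log categories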
diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp
index 4f6f9bc..63097ca 100644 (file)
@@ -6,6 +6,10 @@
 
 #include "caffe/common.hpp"
 #include "caffe/util/rng.hpp"
+#ifdef USE_ACL
+#include "arm_compute/runtime/CL/CLScheduler.h"
+using namespace arm_compute;
+#endif
 
 namespace caffe {
 
@@ -52,8 +56,12 @@ void GlobalInit(int* pargc, char*** pargv) {
 #ifdef CPU_ONLY  // CPU-only Caffe.
 
 Caffe::Caffe()
-    : random_generator_(), mode_(Caffe::CPU),
-      solver_count_(1), solver_rank_(0), multiprocess_(false) { }
+    : random_generator_(), mode_(Caffe::CPU),use_mali_gpu_(false),
+      solver_count_(1), solver_rank_(0), multiprocess_(false) {
+#ifdef USE_ACL
+    CLScheduler::get().default_init();
+#endif
+}
 
 Caffe::~Caffe() { }
 
diff --git a/src/caffe/layer.cpp b/src/caffe/layer.cpp
index 684ae88..677ed10 100644 (file)
@@ -1,7 +1,81 @@
 #include "caffe/layer.hpp"
 
+#ifdef USE_PROFILING
+
+#ifdef LAYER_PERF_STAT
+#include <time.h>
+
+#endif
+#endif //USE_PROFILING
+
 namespace caffe {
 
 INSTANTIATE_CLASS(Layer);
 
+#ifdef USE_PROFILING
+#ifdef LAYER_PERF_STAT
+
+/* current timestamp in us */
+unsigned long get_cur_time(void)
+{
+   struct timespec tm;
+
+   clock_gettime(CLOCK_MONOTONIC_COARSE, &tm);
+
+   return (tm.tv_sec*1000000+tm.tv_nsec/1000);
+}
+
+
+// Forward and backward wrappers. You should implement the cpu and
+// gpu specific implementations instead, and should not change these
+// functions.
+template <typename Dtype>
+Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  Dtype loss = 0;
+  Reshape(bottom, top);
+
+  saved_top = &top;
+  saved_bottom = &bottom;
+
+  time_stat_.count++;
+  time_stat_.start = get_cur_time();
+
+  switch (Caffe::mode()) {
+  case Caffe::CPU:
+    Forward_cpu(bottom, top);
+    for (int top_id = 0; top_id < top.size(); ++top_id) {
+      if (!this->loss(top_id)) { continue; }
+      const int count = top[top_id]->count();
+      const Dtype* data = top[top_id]->cpu_data();
+      const Dtype* loss_weights = top[top_id]->cpu_diff();
+      loss += caffe_cpu_dot(count, data, loss_weights);
+    }
+    break;
+  case Caffe::GPU:
+    Forward_gpu(bottom, top);
+#ifndef CPU_ONLY
+    for (int top_id = 0; top_id < top.size(); ++top_id) {
+      if (!this->loss(top_id)) { continue; }
+      const int count = top[top_id]->count();
+      const Dtype* data = top[top_id]->gpu_data();
+      const Dtype* loss_weights = top[top_id]->gpu_diff();
+      Dtype blob_loss = 0;
+      caffe_gpu_dot(count, data, loss_weights, &blob_loss);
+      loss += blob_loss;
+    }
+#endif
+    break;
+  default:
+    LOG(FATAL) << "Unknown caffe mode.";
+  }
+  time_stat_.end = get_cur_time();
+  time_stat_.used = time_stat_.end - time_stat_.start;
+  time_stat_.total += time_stat_.used;
+  return loss;
+}
+
+#endif
+#endif //USE_PROFILING
+
 }  // namespace caffe
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index f14253a..f7bf886 100644 (file)
 #include "caffe/layers/cudnn_tanh_layer.hpp"
 #endif
 
+#ifdef USE_ACL
+#include "caffe/layers/acl_absval_layer.hpp"
+#include "caffe/layers/acl_bnll_layer.hpp"
+#include "caffe/layers/acl_conv_layer.hpp"
+#include "caffe/layers/acl_inner_product_layer.hpp"
+#include "caffe/layers/acl_lrn_layer.hpp"
+#include "caffe/layers/acl_pooling_layer.hpp"
+#include "caffe/layers/acl_relu_layer.hpp"
+#include "caffe/layers/acl_sigmoid_layer.hpp"
+#include "caffe/layers/acl_softmax_layer.hpp"
+#include "caffe/layers/acl_tanh_layer.hpp"
+#endif
+
 #ifdef WITH_PYTHON_LAYER
 #include "caffe/layers/python_layer.hpp"
 #endif
@@ -39,6 +52,9 @@ shared_ptr<Layer<Dtype> > GetConvolutionLayer(
     const LayerParameter& param) {
   ConvolutionParameter conv_param = param.convolution_param();
   ConvolutionParameter_Engine engine = conv_param.engine();
+#ifdef USE_ACL
+  return shared_ptr<Layer<Dtype> >(new ACLConvolutionLayer<Dtype>(param));
+#endif  
 #ifdef USE_CUDNN
   bool use_dilation = false;
   for (int i = 0; i < conv_param.dilation_size(); ++i) {
@@ -77,6 +93,9 @@ REGISTER_LAYER_CREATOR(Convolution, GetConvolutionLayer);
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
   PoolingParameter_Engine engine = param.pooling_param().engine();
+#ifdef USE_ACL
+  return shared_ptr<Layer<Dtype> >(new ACLPoolingLayer<Dtype>(param));
+#endif  
   if (engine == PoolingParameter_Engine_DEFAULT) {
     engine = PoolingParameter_Engine_CAFFE;
 #ifdef USE_CUDNN
@@ -115,7 +134,9 @@ REGISTER_LAYER_CREATOR(Pooling, GetPoolingLayer);
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetLRNLayer(const LayerParameter& param) {
   LRNParameter_Engine engine = param.lrn_param().engine();
-
+#ifdef USE_ACL
+  return shared_ptr<Layer<Dtype> >(new ACLLRNLayer<Dtype>(param));
+#endif  
   if (engine == LRNParameter_Engine_DEFAULT) {
 #ifdef USE_CUDNN
     engine = LRNParameter_Engine_CUDNN;
@@ -153,6 +174,9 @@ REGISTER_LAYER_CREATOR(LRN, GetLRNLayer);
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetReLULayer(const LayerParameter& param) {
   ReLUParameter_Engine engine = param.relu_param().engine();
+#ifdef USE_ACL
+  return shared_ptr<Layer<Dtype> >(new ACLReLULayer<Dtype>(param));
+#endif    
   if (engine == ReLUParameter_Engine_DEFAULT) {
     engine = ReLUParameter_Engine_CAFFE;
 #ifdef USE_CUDNN
@@ -177,6 +201,9 @@ REGISTER_LAYER_CREATOR(ReLU, GetReLULayer);
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetSigmoidLayer(const LayerParameter& param) {
   SigmoidParameter_Engine engine = param.sigmoid_param().engine();
+#ifdef USE_ACL
+  return shared_ptr<Layer<Dtype> >(new ACLSigmoidLayer<Dtype>(param));
+#endif    
   if (engine == SigmoidParameter_Engine_DEFAULT) {
     engine = SigmoidParameter_Engine_CAFFE;
 #ifdef USE_CUDNN
@@ -201,6 +228,9 @@ REGISTER_LAYER_CREATOR(Sigmoid, GetSigmoidLayer);
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetSoftmaxLayer(const LayerParameter& param) {
   SoftmaxParameter_Engine engine = param.softmax_param().engine();
+#ifdef USE_ACL
+  return shared_ptr<Layer<Dtype> >(new ACLSoftmaxLayer<Dtype>(param));
+#endif    
   if (engine == SoftmaxParameter_Engine_DEFAULT) {
     engine = SoftmaxParameter_Engine_CAFFE;
 #ifdef USE_CUDNN
@@ -225,6 +255,9 @@ REGISTER_LAYER_CREATOR(Softmax, GetSoftmaxLayer);
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetTanHLayer(const LayerParameter& param) {
   TanHParameter_Engine engine = param.tanh_param().engine();
+#ifdef USE_ACL
+  return shared_ptr<Layer<Dtype> >(new ACLTanHLayer<Dtype>(param));
+#endif    
   if (engine == TanHParameter_Engine_DEFAULT) {
     engine = TanHParameter_Engine_CAFFE;
 #ifdef USE_CUDNN
@@ -245,6 +278,33 @@ shared_ptr<Layer<Dtype> > GetTanHLayer(const LayerParameter& param) {
 
 REGISTER_LAYER_CREATOR(TanH, GetTanHLayer);
 
+#ifdef USE_ACL
+// Get AbsVal layer; under USE_ACL this always returns the ACL implementation.
+template <typename Dtype>
+shared_ptr<Layer<Dtype> > GetAbsValLayer(const LayerParameter& param) {
+  return shared_ptr<Layer<Dtype> >(new ACLAbsValLayer<Dtype>(param));
+}
+
+REGISTER_LAYER_CREATOR(AbsVal, GetAbsValLayer);
+
+// Get BNLL layer; under USE_ACL this always returns the ACL implementation.
+template <typename Dtype>
+shared_ptr<Layer<Dtype> > GetBNLLLayer(const LayerParameter& param) {
+  return shared_ptr<Layer<Dtype> >(new ACLBNLLLayer<Dtype>(param));
+}
+
+REGISTER_LAYER_CREATOR(BNLL, GetBNLLLayer);
+
+// Get InnerProduct layer; under USE_ACL this always returns the ACL implementation.
+template <typename Dtype>
+shared_ptr<Layer<Dtype> > GetInnerProductLayer(const LayerParameter& param) {
+  return shared_ptr<Layer<Dtype> >(new ACLInnerProductLayer<Dtype>(param));
+}
+
+REGISTER_LAYER_CREATOR(InnerProduct, GetInnerProductLayer);
+
+#endif // USE_ACL
+
 #ifdef WITH_PYTHON_LAYER
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetPythonLayer(const LayerParameter& param) {
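Note: each USE_ACL block above returns before the engine switch is consulted, so under
USE_ACL the creators always yield the ACL variant and the CUDNN/CAFFE engine logic
below them is effectively dead. Any layer built through the registry therefore gets
the ACL path; a sketch:

    LayerParameter param;
    param.set_type("Convolution");
    // Dispatches through GetConvolutionLayer; with USE_ACL this is an ACLConvolutionLayer.
    shared_ptr<Layer<float> > layer = LayerRegistry<float>::CreateLayer(param);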
diff --git a/src/caffe/layers/absval_layer.cpp b/src/caffe/layers/absval_layer.cpp
index 855bf0b..65980e4 100644 (file)
@@ -39,6 +39,8 @@ STUB_GPU(AbsValLayer);
 #endif
 
 INSTANTIATE_CLASS(AbsValLayer);
+#ifndef USE_ACL
 REGISTER_LAYER_CLASS(AbsVal);
+#endif
 
 }  // namespace caffe
diff --git a/src/caffe/layers/acl_absval_layer.cpp b/src/caffe/layers/acl_absval_layer.cpp
new file mode 100644 (file)
index 0000000..b0b0304
--- /dev/null
@@ -0,0 +1,64 @@
+#ifdef USE_ACL
+#include <vector>
+
+#include "caffe/layers/acl_absval_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLAbsValLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  AbsValLayer<Dtype>::LayerSetUp(bottom, top);
+  ACLBaseActivationLayer<Dtype>::LayerSetUp(bottom, top);
+  this->force_bypass_acl_path_= bypass_acl_class_layer & FLAGS_ENABLE_ACL_ABSVAL;
+}
+
+template <typename Dtype>
+void ACLAbsValLayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top,ActivationLayerInfo::ActivationFunction type){
+    ACLBaseActivationLayer<Dtype>::SetupACLLayer(bottom, top,ActivationLayerInfo::ActivationFunction::ABS);
+}
+
+template <typename Dtype>
+void ACLAbsValLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  AbsValLayer<Dtype>::Reshape(bottom, top);
+  ACLBaseActivationLayer<Dtype>::Reshape(bottom, top);
+}
+
+template <typename Dtype>
+void ACLAbsValLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_ABSVAL_INFO);
+#endif //USE_PROFILING
+    if (this->force_bypass_acl_path_) {
+        AbsValLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+    }
+    ACLBaseActivationLayer<Dtype>::Forward_cpu(bottom,top);
+}
+
+template <typename Dtype>
+void ACLAbsValLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_ABSVAL_INFO);
+#endif //USE_PROFILING
+    if (this->force_bypass_acl_path_) {
+        AbsValLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+    }
+    ACLBaseActivationLayer<Dtype>::Forward_gpu(bottom,top);
+}
+
+template <typename Dtype>
+ACLAbsValLayer<Dtype>::~ACLAbsValLayer() {
+}
+
+INSTANTIATE_CLASS(ACLAbsValLayer);
+
+}  // namespace caffe
+
+#endif  // USE_ACL
diff --git a/src/caffe/layers/acl_base_activation_layer.cpp b/src/caffe/layers/acl_base_activation_layer.cpp
new file mode 100644 (file)
index 0000000..2595a95
--- /dev/null
@@ -0,0 +1,97 @@
+#ifdef USE_ACL
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layers/acl_base_activation_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLBaseActivationLayer<Dtype>::LayerSetUp(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+}
+template <typename Dtype>
+void ACLBaseActivationLayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top,ActivationLayerInfo::ActivationFunction type){
+
+    const unsigned int count  = bottom[0]->count();
+    const unsigned int count_ = top[0]->count();
+    TensorShape input_shape(count);
+    TensorShape output_shape(count_);
+    checkreshape(input_shape,Caffe::arm_gpu_mode());
+    if (!this->init_layer_) return;
+    this->init_layer_=false;
+    // Initialize ACL.
+    if (Caffe::arm_gpu_mode()) {
+        new_gpulayer();
+    }else{
+        new_cpulayer();
+    }
+
+    this->force_bypass_acl_path_=false;
+    ActivationLayerInfo act_info(type);
+     
+    if(type== ActivationLayerInfo::ActivationFunction::TANH)
+      act_info=ActivationLayerInfo(type,1.0,1.0);
+
+   
+
+    if (Caffe::arm_gpu_mode()) {
+        Dtype *top_data = top[0]->mutable_gpu_data(); 
+        const Dtype* bottom_data = bottom[0]->gpu_data();
+        this->gpu().input=new_tensor<GPUTensor>(input_shape,(void*)bottom_data);
+        this->gpu().output=new_tensor<GPUTensor>(output_shape,(void*)top_data);
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->gpu().layer->configure(this->gpu().input,this->gpu().output,act_info);
+    }else{
+        Dtype *top_data = top[0]->mutable_cpu_data(); 
+        const Dtype* bottom_data = bottom[0]->cpu_data();
+        this->cpu().input=new_tensor<CPUTensor>(input_shape,(void*)bottom_data);
+        this->cpu().output=new_tensor<CPUTensor>(output_shape,(void*)top_data);
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->cpu().layer->configure(this->cpu().input,this->cpu().output,act_info);
+    }
+}
+template <typename Dtype>
+void ACLBaseActivationLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+}
+
+template <typename Dtype>
+void ACLBaseActivationLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+    if(Caffe::arm_gpu_mode()){
+        Forward_gpu(bottom, top);
+        return;
+    }        
+    Dtype* top_data = top[0]->mutable_cpu_data();
+    const Dtype* bottom_data = bottom[0]->cpu_data();
+    SetupACLLayer(bottom,top);
+    tensor_mem(this->cpu().input,(void*)(bottom_data));
+    cpu_run();
+    tensor_mem((void*)(top_data),this->cpu().output);
+}
+
+template <typename Dtype>
+void ACLBaseActivationLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+    Dtype* top_data = top[0]->mutable_gpu_data();
+    const Dtype* bottom_data = bottom[0]->gpu_data();
+    SetupACLLayer(bottom,top);
+    tensor_mem(this->gpu().input,(void*)(bottom_data));
+    gpu_run();
+    tensor_mem((void*)(top_data),this->gpu().output);
+}
+
+template <typename Dtype>
+ACLBaseActivationLayer<Dtype>::~ACLBaseActivationLayer() {
+}
+
+INSTANTIATE_CLASS(ACLBaseActivationLayer);
+
+}  // namespace caffe
+#endif  // USE_ACL
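Note: the TANH special case in SetupACLLayer above exists because ACL parameterizes its
tanh activation with two coefficients; passing a = b = 1 reduces it to Caffe's plain
tanh. The other activation mappings in this commit (ABS for AbsVal, SOFT_RELU for BNLL)
take no extra parameters, so the one-argument ActivationLayerInfo suffices.

    f(x) = a * tanh(b * x),  a = b = 1  =>  f(x) = tanh(x)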
diff --git a/src/caffe/layers/acl_bnll_layer.cpp b/src/caffe/layers/acl_bnll_layer.cpp
new file mode 100644 (file)
index 0000000..86f0983
--- /dev/null
@@ -0,0 +1,61 @@
+#ifdef USE_ACL
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layers/acl_bnll_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLBNLLLayer<Dtype>::LayerSetUp(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  BNLLLayer<Dtype>::LayerSetUp(bottom, top);
+  ACLBaseActivationLayer<Dtype>::LayerSetUp(bottom, top);
+  this->force_bypass_acl_path_= bypass_acl_class_layer & FLAGS_ENABLE_ACL_BNLL;
+}
+template <typename Dtype>
+void ACLBNLLLayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top, ActivationLayerInfo::ActivationFunction type){
+    ACLBaseActivationLayer<Dtype>::SetupACLLayer(bottom, top,ActivationLayerInfo::ActivationFunction::SOFT_RELU);
+}
+template <typename Dtype>
+void ACLBNLLLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  BNLLLayer<Dtype>::Reshape(bottom, top);
+  ACLBaseActivationLayer<Dtype>::Reshape(bottom, top);
+}
+
+template <typename Dtype>
+void ACLBNLLLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_BNLL_INFO);
+#endif //USE_PROFILING
+    if (this->force_bypass_acl_path_) {
+        BNLLLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+    }
+    ACLBaseActivationLayer<Dtype>::Forward_cpu(bottom,top);
+}
+
+template <typename Dtype>
+void ACLBNLLLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_BNLL_INFO);
+#endif //USE_PROFILING
+    if (this->force_bypass_acl_path_) {
+        BNLLLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+    }
+    ACLBaseActivationLayer<Dtype>::Forward_gpu(bottom,top);
+}
+
+template <typename Dtype>
+ACLBNLLLayer<Dtype>::~ACLBNLLLayer() {
+}
+
+INSTANTIATE_CLASS(ACLBNLLLayer);
+
+}  // namespace caffe
+#endif  // USE_ACL
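Note: the SOFT_RELU mapping above is exact: Caffe's BNLL computes the softplus function
log(1 + e^x) (in a numerically stable split form), and that is precisely ACL's
SOFT_RELU. A standalone check of the identity, a sketch independent of this patch:

    #include <cassert>
    #include <cmath>
    int main() {
      for (float x = -8.f; x <= 8.f; x += 0.25f) {
        // Caffe's stable softplus, as computed in bnll_layer.cpp
        float bnll = x > 0 ? x + std::log1p(std::exp(-x)) : std::log1p(std::exp(x));
        float soft_relu = std::log1p(std::exp(x));  // ACL SOFT_RELU: log(1 + e^x)
        assert(std::fabs(bnll - soft_relu) < 1e-5f);
      }
      return 0;
    }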
diff --git a/src/caffe/layers/acl_conv_layer.cpp b/src/caffe/layers/acl_conv_layer.cpp
new file mode 100644 (file)
index 0000000..02732fb
--- /dev/null
@@ -0,0 +1,218 @@
+#ifdef USE_ACL
+#include <algorithm>
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/layers/acl_conv_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLConvolutionLayer<Dtype>::LayerSetUp(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  ConvolutionLayer<Dtype>::LayerSetUp(bottom, top);
+  this->force_bypass_acl_path_= bypass_acl_class_layer & FLAGS_ENABLE_ACL_CONV;
+}
+
+template <typename Dtype>
+void ACLConvolutionLayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top){
+
+    TensorShape input_shape((unsigned int)bottom[0]->width(), (unsigned int)bottom[0]->height(),(unsigned int)bottom[0]->channels(),(unsigned int)bottom[0]->num());
+    checkreshape(input_shape,Caffe::arm_gpu_mode());
+    if (!this->init_layer_) return;
+    this->init_layer_=false;
+  // Initialize ACL.
+    if (Caffe::arm_gpu_mode()) {
+        new_gpulayer();
+    }else{
+        new_cpulayer();
+    }
+    this->force_bypass_acl_path_=false;
+    ConvolutionParameter conv_param = this->layer_param_.convolution_param();
+    int stride_x = this->stride_.mutable_cpu_data()[1];
+    int stride_y = this->stride_.mutable_cpu_data()[0];
+    int pad_x = this->pad_.mutable_cpu_data()[1];
+    int pad_y = this->pad_.mutable_cpu_data()[0];
+    unsigned int kernel_x = this->kernel_shape_.mutable_cpu_data()[1];
+    unsigned int kernel_y = this->kernel_shape_.mutable_cpu_data()[0];
+    PadStrideInfo conv_info(stride_x,stride_y,pad_x,pad_y);
+    TensorShape weights_shape(kernel_x,kernel_y,(unsigned int)this->channels_, (unsigned int)this->num_output_);
+    TensorShape biases_shape ((unsigned int)this->num_output_);
+    TensorShape output_shape((unsigned int)top[0]->width(), (unsigned int)top[0]->height(),(unsigned int)top[0]->channels(),(unsigned int)top[0]->num());
+
+    if (Caffe::arm_gpu_mode()) {
+        Dtype *top_data = top[0]->mutable_gpu_data(); 
+        const Dtype* bottom_data = bottom[0]->gpu_data();
+        //[kernel_x, kernel_y, IFM, OFM]
+        this->gpu().weights=new_tensor<GPUTensor>(weights_shape,(void*)(this->blobs_[0].get()->mutable_gpu_data()));
+        tensor_mem(this->gpu().weights,(void*)(this->blobs_[0].get()->mutable_gpu_data()));
+        //[OFM]
+        if (this->bias_term_) {
+            this->gpu().biases=new_tensor<GPUTensor>(biases_shape,(void*)(this->blobs_[1].get()->mutable_gpu_data()));
+            tensor_mem(this->gpu().biases,(void*)(this->blobs_[1].get()->mutable_gpu_data()));
+        }
+
+        //[width, height, IFM]
+        this->gpu().input=new_tensor<GPUTensor>(input_shape,(void*)bottom_data);
+        //[width, height, OFM]
+        this->gpu().output=new_tensor<GPUTensor>(output_shape,(void*)top_data);
+#ifdef USE_PROFILING
+        {
+            logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->gpu().layer->configure(this->gpu().input,this->gpu().weights,this->gpu().biases,this->gpu().output,conv_info);
+#ifdef USE_PROFILING
+        }
+#endif //USE_PROFILING
+#ifdef USE_CONV_CACHE
+        for(int i = 0; i < 16; ++i){
+            fprintf(stderr, "<GPU>check cache[%d]\n", i);
+            if(this->gpu().cache.layer[i] == nullptr){
+                this->gpu().cache.layer[i] = this->gpu().layer;
+                this->gpu().cache.input[i] = this->gpu().input;
+                this->gpu().cache.output[i] = this->gpu().output;
+                this->gpu().cache.weights[i] = this->gpu().weights;
+                this->gpu().cache.biases[i] = this->gpu().biases;
+                break;
+            }
+        }    
+#endif //USE_CONV_CACHE                
+    }else{
+        Dtype *top_data = top[0]->mutable_cpu_data(); 
+        const Dtype* bottom_data = bottom[0]->cpu_data();
+        //[kernel_x, kernel_y, IFM, OFM]
+        this->cpu().weights=new_tensor<CPUTensor>(weights_shape,(void*)(this->blobs_[0].get()->mutable_cpu_data()));
+        tensor_mem(this->cpu().weights,(void*)(this->blobs_[0].get()->mutable_cpu_data()));
+        //[OFM]
+        if (this->bias_term_) {
+            this->cpu().biases=new_tensor<CPUTensor>(biases_shape,(void*)(this->blobs_[1].get()->mutable_cpu_data()));
+            tensor_mem(this->cpu().biases,(void*)(this->blobs_[1].get()->mutable_cpu_data()));
+        }
+
+        //[width, height, IFM]
+        this->cpu().input=new_tensor<CPUTensor>(input_shape,(void*)bottom_data);
+        //[width, height, OFM]
+        this->cpu().output=new_tensor<CPUTensor>(output_shape,(void*)top_data);
+#ifdef USE_PROFILING
+        {
+            logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->cpu().layer->configure(this->cpu().input,this->cpu().weights,this->cpu().biases,this->cpu().output,conv_info);
+#ifdef USE_PROFILING
+        }
+#endif //USE_PROFILING
+#ifdef USE_CONV_CACHE
+        for(int i = 0; i < 16; ++i){
+            fprintf(stderr, "<CPU>check cache[%d]\n", i);
+            if(this->cpu().cache.layer[i] == nullptr){
+                this->cpu().cache.layer[i] = this->cpu().layer;
+                this->cpu().cache.input[i] = this->cpu().input;
+                this->cpu().cache.output[i] = this->cpu().output;
+                this->cpu().cache.weights[i] = this->cpu().weights;
+                this->cpu().cache.biases[i] = this->cpu().biases;
+                break;
+            }
+        }    
+#endif //USE_CONV_CACHE                
+    }
+}
+template <typename Dtype>
+void ACLConvolutionLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  ConvolutionLayer<Dtype>::Reshape(bottom, top);
+}
+
+template <typename Dtype>
+void ACLConvolutionLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+    if(Caffe::arm_gpu_mode()){
+        Forward_gpu(bottom, top);
+        return;
+    }         
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_CONV_INFO);
+#endif //USE_PROFILING
+    if (this->force_bypass_acl_path_|| this->group_!=1) {
+        ConvolutionLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+    }
+
+    ConvolutionParameter conv_param = this->layer_param_.convolution_param();
+    if (conv_param.kernel_size_size()>2 || this->num_spatial_axes_>2 || this->num_spatial_axes_==0) {
+        ConvolutionLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+    }
+    /* check dilation */
+    int dilated=0;
+
+    for(int i=0;i<this->num_spatial_axes_;i++)
+    {
+        const int *p=this->dilation_.cpu_data();
+
+        if(p[i]!=1) 
+           dilated=1;
+    }
+    if(dilated) {
+        ConvolutionLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+     }
+    
+    SetupACLLayer(bottom,top);
+    for (int i = 0; i < bottom.size(); ++i) {
+        const Dtype* bottom_data = bottom[i]->cpu_data();
+        Dtype* top_data = top[i]->mutable_cpu_data();
+        tensor_mem(this->cpu().input,(void*)bottom_data);
+        cpu_run();
+        tensor_mem((void*)top_data,this->cpu().output);
+  }
+}
+
+template <typename Dtype>
+void ACLConvolutionLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_CONV_INFO);
+#endif //USE_PROFILING
+    ConvolutionParameter conv_param = this->layer_param_.convolution_param();
+    if (this->force_bypass_acl_path_|| this->group_!=1) {
+        ConvolutionLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+    }
+    if (conv_param.kernel_size_size()>2 || this->num_spatial_axes_>2 || this->num_spatial_axes_==0) {
+        ConvolutionLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+    }
+    /* check dilation */
+    int dilated=0;
+
+    for(int i=0;i<this->num_spatial_axes_;i++)
+    {
+        const int *p=this->dilation_.gpu_data();
+
+        if(p[i]!=1) 
+           dilated=1;
+    }
+
+    if(dilated) {
+        ConvolutionLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+     }
+    SetupACLLayer(bottom,top);
+    for (int i = 0; i < bottom.size(); ++i) {
+      const Dtype* bottom_data = bottom[i]->gpu_data();
+      Dtype* top_data = top[i]->mutable_gpu_data();
+      tensor_mem(this->gpu().input,(void*)bottom_data);
+      gpu_run();
+      tensor_mem((void*)top_data,this->gpu().output);
+    }
+}
+
+template <typename Dtype>
+ACLConvolutionLayer<Dtype>::~ACLConvolutionLayer() {
+}
+
+INSTANTIATE_CLASS(ACLConvolutionLayer);
+
+}   // namespace caffe
+#endif  // USE_ACL
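
The forward paths above offload to ACL only when the convolution matches what this port wires up, and bypass to stock Caffe otherwise. A minimal sketch of that decision, distilled from Forward_cpu/Forward_gpu (illustrative helper, not CaffeOnACL API):

    #include <vector>

    // True when the ACL path can run: one group, exactly two spatial axes,
    // a kernel described by at most two sizes, and no dilation.
    bool can_use_acl_conv(int group, int num_spatial_axes, int kernel_dims,
                          const std::vector<int>& dilation) {
      if (group != 1) return false;             // grouped conv: bypass
      if (kernel_dims > 2) return false;        // >2 kernel sizes: bypass
      if (num_spatial_axes != 2) return false;  // only 2-D conv supported
      for (int d : dilation)
        if (d != 1) return false;               // dilated conv: bypass
      return true;
    }
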
diff --git a/src/caffe/layers/acl_inner_product_layer.cpp b/src/caffe/layers/acl_inner_product_layer.cpp
new file mode 100644 (file)
index 0000000..6e7a46e
--- /dev/null
@@ -0,0 +1,131 @@
+#ifdef USE_ACL
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/layers/acl_inner_product_layer.hpp"
+#include "caffe/util/math_functions.hpp"
+
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLInnerProductLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  InnerProductLayer<Dtype>::LayerSetUp(bottom, top);
+  this->force_bypass_acl_path_= bypass_acl_class_layer & FLAGS_ENABLE_ACL_FC;
+}
+template <typename Dtype>
+void ACLInnerProductLayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top){
+
+    TensorShape weights_shape_t((unsigned int)this->K_, (unsigned int)this->N_);
+    TensorShape weights_shape((unsigned int)this->N_, (unsigned int)this->K_);
+    TensorShape biases_shape((unsigned int)this->N_);
+    TensorShape input_shape((unsigned int)this->K_, (unsigned int)this->M_);
+    TensorShape output_shape((unsigned int)this->N_, (unsigned int)this->M_);
+    checkreshape(input_shape,Caffe::arm_gpu_mode());
+    if (!this->init_layer_) return;
+    this->init_layer_=false;
+    // Initialize ACL.
+    if (Caffe::arm_gpu_mode()) {
+        new_gpulayer();
+    }else{
+        new_cpulayer();
+    }
+
+    // Caffe's flag says how the weights are stored; ACL's asks whether to
+    // transpose them on use, hence the negation.
+    bool transpose = !this->layer_param_.inner_product_param().transpose();
+    this->force_bypass_acl_path_ = false; 
+    if (Caffe::arm_gpu_mode()) {
+        Dtype *top_data = top[0]->mutable_gpu_data(); 
+        const Dtype* bottom_data = bottom[0]->gpu_data();
+        if (transpose) {
+            this->gpu().weights=new_tensor<GPUTensor>(weights_shape_t,(void*)(this->blobs_[0].get()->mutable_gpu_data()));
+        }else{
+            this->gpu().weights=new_tensor<GPUTensor>(weights_shape,(void*)(this->blobs_[0].get()->mutable_gpu_data()));
+        }
+        tensor_mem(this->gpu().weights,(void*)(this->blobs_[0].get()->mutable_gpu_data()));
+        if (this->bias_term_) {
+            this->gpu().biases=new_tensor<GPUTensor>(biases_shape,(void*)(this->blobs_[1].get()->mutable_gpu_data()));
+            tensor_mem(this->gpu().biases,(void*)(this->blobs_[1].get()->mutable_gpu_data()));
+        }
+        this->gpu().input=new_tensor<GPUTensor>(input_shape,(void*)bottom_data);
+        this->gpu().output=new_tensor<GPUTensor>(output_shape,(void*)top_data);
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->gpu().layer->configure(this->gpu().input,this->gpu().weights,this->gpu().biases,this->gpu().output,transpose);
+    }else{
+        Dtype *top_data = top[0]->mutable_cpu_data(); 
+        const Dtype* bottom_data = bottom[0]->cpu_data();
+        if (transpose) {
+            this->cpu().weights=new_tensor<CPUTensor>(weights_shape_t,(void*)(this->blobs_[0].get()->mutable_cpu_data()));
+        }else{
+            this->cpu().weights=new_tensor<CPUTensor>(weights_shape,(void*)(this->blobs_[0].get()->mutable_cpu_data()));
+        }
+        tensor_mem(this->cpu().weights,(void*)(this->blobs_[0].get()->mutable_cpu_data()));
+        if (this->bias_term_) {
+            this->cpu().biases=new_tensor<CPUTensor>(biases_shape,(void*)(this->blobs_[1].get()->mutable_cpu_data()));
+            tensor_mem(this->cpu().biases,(void*)(this->blobs_[1].get()->mutable_cpu_data()));
+        }
+        this->cpu().input=new_tensor<CPUTensor>(input_shape,(void*)bottom_data);
+        this->cpu().output=new_tensor<CPUTensor>(output_shape,(void*)top_data);
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->cpu().layer->configure(this->cpu().input,this->cpu().weights,this->cpu().biases,this->cpu().output,transpose);
+    }
+}
+template <typename Dtype>
+void ACLInnerProductLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  InnerProductLayer<Dtype>::Reshape(bottom, top);
+}
+
+template <typename Dtype>
+void ACLInnerProductLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  if(Caffe::arm_gpu_mode()){
+       Forward_gpu(bottom, top);
+       return;
+  }         
+#ifdef USE_PROFILING
+  logtime_util log_time(ACL_FC_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_) {
+       InnerProductLayer<Dtype>::Forward_cpu(bottom,top);
+       return;
+  }
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  SetupACLLayer(bottom,top);
+  tensor_mem(this->cpu().input,(void*)(bottom_data));
+  cpu_run();
+  tensor_mem((void*)(top_data),this->cpu().output);
+}
+
+template <typename Dtype>
+void ACLInnerProductLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+  logtime_util log_time(ACL_FC_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_) {
+        InnerProductLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+  }
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  SetupACLLayer(bottom,top);
+  tensor_mem(this->gpu().input,(void*)(bottom_data));
+  gpu_run();
+  tensor_mem((void*)(top_data),this->gpu().output);
+}
+
+template <typename Dtype>
+ACLInnerProductLayer<Dtype>::~ACLInnerProductLayer() {
+}
+
+INSTANTIATE_CLASS(ACLInnerProductLayer);
+
+}  // namespace caffe
+#endif // USE_ACL
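
Note the inverted flag in SetupACLLayer above: Caffe's inner_product_param().transpose() describes how the weights are stored, while the boolean handed to ACL's fully-connected configure() asks whether ACL should transpose them, so one is the negation of the other and the weight TensorShape is chosen to match. A small sketch of the pairing (shapes as in the code above; ACL flag semantics assumed):

    #include <cstdio>

    int main() {
      const unsigned K = 4096, N = 1000;      // K_: input dim, N_: output dim
      bool caffe_transpose = false;           // inner_product_param().transpose()
      bool acl_transpose = !caffe_transpose;  // flag passed to configure()
      // Weight shape handed to ACL: (K, N) when ACL transposes, else (N, K).
      std::printf("weights (%u, %u), transpose=%d\n",
                  acl_transpose ? K : N, acl_transpose ? N : K,
                  static_cast<int>(acl_transpose));
      return 0;
    }
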
diff --git a/src/caffe/layers/acl_lrn_layer.cpp b/src/caffe/layers/acl_lrn_layer.cpp
new file mode 100644 (file)
index 0000000..9c03cad
--- /dev/null
@@ -0,0 +1,144 @@
+#ifdef USE_ACL
+#include <vector>
+
+#include "caffe/layers/acl_lrn_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLLRNLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  LRNLayer<Dtype>::LayerSetUp(bottom, top);
+  this->force_bypass_acl_path_= bypass_acl_class_layer & FLAGS_ENABLE_ACL_LRN;
+}
+template <typename Dtype>
+void ACLLRNLayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top){
+
+    TensorShape shape((unsigned int)this->width_,(unsigned int)this->height_, (unsigned int)this->channels_);
+    checkreshape(shape,Caffe::arm_gpu_mode());
+    if (!this->init_layer_) return;
+    this->init_layer_=false;
+    // Initialize ACL.
+    if (Caffe::arm_gpu_mode()) {
+        new_gpulayer();
+    }else{
+        new_cpulayer();
+    }
+
+    //this->force_bypass_acl_path_=false;
+    NormalizationLayerInfo *norm_info;
+    if(this->layer_param_.lrn_param().norm_region() == LRNParameter_NormRegion_WITHIN_CHANNEL)
+       norm_info=new NormalizationLayerInfo(NormType::IN_MAP, this->size_, this->alpha_, this->beta_, this->k_);
+    else
+       norm_info=new NormalizationLayerInfo(NormType::CROSS_MAP, this->size_, this->alpha_, this->beta_, this->k_);
+
+    if (Caffe::arm_gpu_mode()) {
+        Dtype *top_data = top[0]->mutable_gpu_data(); 
+        const Dtype* bottom_data = bottom[0]->gpu_data();
+        this->gpu().input=new_tensor<GPUTensor>(shape,(void*)bottom_data);
+        this->gpu().output=new_tensor<GPUTensor>(shape,(void*)top_data);
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->gpu().layer->configure(this->gpu().input,this->gpu().output,*norm_info);
+    }else{
+        Dtype *top_data = top[0]->mutable_cpu_data(); 
+        const Dtype* bottom_data = bottom[0]->cpu_data();
+        this->cpu().input=new_tensor<CPUTensor>(shape,(void*)bottom_data);
+        this->cpu().output=new_tensor<CPUTensor>(shape,(void*)top_data);
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->cpu().layer->configure(this->cpu().input,this->cpu().output,*norm_info);
+    }
+    delete norm_info;
+}
+template <typename Dtype>
+void ACLLRNLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  LRNLayer<Dtype>::Reshape(bottom, top);
+  return;
+}
+
+template <typename Dtype>
+void ACLLRNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  if(Caffe::arm_gpu_mode()){
+    Forward_gpu(bottom, top);
+    return;
+  }         
+#ifdef USE_PROFILING
+  logtime_util log_time(ACL_LRN_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_ || this->layer_param_.lrn_param().norm_region() == LRNParameter_NormRegion_WITHIN_CHANNEL) {
+      LRNLayer<Dtype>::Forward_cpu(bottom,top);
+      return;
+  }
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  SetupACLLayer(bottom,top);
+  switch (this->layer_param_.lrn_param().norm_region()) {
+  case LRNParameter_NormRegion_ACROSS_CHANNELS:
+      for (int n = 0; n < this->num_; ++n) {
+          tensor_mem(this->cpu().input,(void*)(bottom_data+ bottom[0]->offset(n)));
+          cpu_run();
+          tensor_mem((void*)(top_data + top[0]->offset(n)),this->cpu().output);
+      }
+    break;
+  case LRNParameter_NormRegion_WITHIN_CHANNEL:
+      for (int n = 0; n < bottom[0]->num(); ++n) {
+            tensor_mem(this->cpu().input,(void*)(bottom_data));
+            cpu_run();
+            tensor_mem((void*)(top_data),this->cpu().output);
+            bottom_data += bottom[0]->offset(0, 1);
+            top_data += top[0]->offset(0, 1);
+      }
+    break;
+  default:
+    LOG(FATAL) << "Unknown normalization region.";
+  }
+}
+
+template <typename Dtype>
+void ACLLRNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+  logtime_util log_time(ACL_LRN_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_) {
+       LRNLayer<Dtype>::Forward_cpu(bottom,top);
+       return;
+  }
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  SetupACLLayer(bottom,top);
+  switch (this->layer_param_.lrn_param().norm_region()) {
+  case LRNParameter_NormRegion_ACROSS_CHANNELS:
+      for (int n = 0; n < this->num_; ++n) {
+          tensor_mem(this->gpu().input,(void*)(bottom_data+ bottom[0]->offset(n)));
+          gpu_run();
+          tensor_mem((void*)(top_data + top[0]->offset(n)),this->gpu().output);
+      }
+    break;
+  case LRNParameter_NormRegion_WITHIN_CHANNEL:
+      for (int n = 0; n < bottom[0]->num(); ++n) {
+            tensor_mem(this->gpu().input,(void*)(bottom_data));
+            gpu_run();
+            tensor_mem((void*)(top_data),this->gpu().output);
+            bottom_data += bottom[0]->offset(0, 1);
+            top_data += top[0]->offset(0, 1);
+      }
+    break;
+  default:
+    LOG(FATAL) << "Unknown normalization region.";
+  }
+}
+
+template <typename Dtype>
+ACLLRNLayer<Dtype>::~ACLLRNLayer() {
+}
+
+INSTANTIATE_CLASS(ACLLRNLayer);
+
+}   // namespace caffe
+#endif  // USE_ACL
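
For reference, NormalizationLayerInfo(NormType::CROSS_MAP, size, alpha, beta, k) is expected to reproduce Caffe's across-channel LRN, out[c] = in[c] / (k + alpha/size * sum of squares over the channel window)^beta. A self-contained sketch of that formula, included only as a checking aid (it assumes ACL scales alpha by the window size the way Caffe does):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    std::vector<float> lrn_across_channels(const std::vector<float>& in,
                                           int channels, int spatial, int size,
                                           float alpha, float beta, float k) {
      std::vector<float> out(in.size());
      const int half = size / 2;
      for (int c = 0; c < channels; ++c)
        for (int s = 0; s < spatial; ++s) {
          float acc = 0.f;
          for (int cc = std::max(0, c - half);
               cc <= std::min(channels - 1, c + half); ++cc) {
            const float v = in[cc * spatial + s];
            acc += v * v;   // sum of squares over the local channel window
          }
          out[c * spatial + s] =
              in[c * spatial + s] / std::pow(k + alpha / size * acc, beta);
        }
      return out;
    }
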
diff --git a/src/caffe/layers/acl_pooling_layer.cpp b/src/caffe/layers/acl_pooling_layer.cpp
new file mode 100644 (file)
index 0000000..951de3e
--- /dev/null
@@ -0,0 +1,150 @@
+#ifdef USE_ACL
+#include <vector>
+
+#include "caffe/layers/acl_pooling_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  PoolingLayer<Dtype>::LayerSetUp(bottom, top);
+  this->force_bypass_acl_path_= bypass_acl_class_layer & FLAGS_ENABLE_ACL_POOLING;
+}
+template <typename Dtype>
+void ACLPoolingLayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top){
+
+    TensorShape in_shape ((unsigned int)this->width_, (unsigned int)this->height_);
+    TensorShape out_shape((unsigned int)this->pooled_width_, (unsigned int)this->pooled_height_);
+    checkreshape(in_shape,Caffe::arm_gpu_mode());
+    if (!this->init_layer_) return;
+    this->init_layer_=false;
+    // Initialize ACL.
+    if (Caffe::arm_gpu_mode()) {
+        new_gpulayer();
+    }else{
+        new_cpulayer();
+    }
+
+    this->force_bypass_acl_path_=false;
+    PoolingLayerInfo *pool_info;
+    if(this->layer_param_.pooling_param().pool()==PoolingParameter_PoolMethod_MAX)
+       pool_info=new PoolingLayerInfo(PoolingType::MAX, this->kernel_w_, PadStrideInfo(this->stride_w_,this->stride_h_,this->pad_w_,this->pad_h_,DimensionRoundingType::CEIL));
+    else
+       pool_info=new PoolingLayerInfo(PoolingType::AVG, this->kernel_w_, PadStrideInfo(this->stride_w_,this->stride_h_,this->pad_w_,this->pad_h_,DimensionRoundingType::CEIL));
+
+    if (Caffe::arm_gpu_mode()) {
+        Dtype *top_data = top[0]->mutable_gpu_data(); 
+        const Dtype* bottom_data = bottom[0]->gpu_data();
+        this->gpu().input=new_tensor<GPUTensor>(in_shape,(void*)bottom_data);
+        this->gpu().output=new_tensor<GPUTensor>(out_shape,(void*)top_data);
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->gpu().layer->configure(this->gpu().input,this->gpu().output,*pool_info);
+    }else{
+        Dtype *top_data = top[0]->mutable_cpu_data(); 
+        const Dtype* bottom_data = bottom[0]->cpu_data();
+        this->cpu().input=new_tensor<CPUTensor>(in_shape,(void*)bottom_data);
+        this->cpu().output=new_tensor<CPUTensor>(out_shape,(void*)top_data);
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->cpu().layer->configure(this->cpu().input,this->cpu().output,*pool_info);
+    }
+    delete pool_info;
+}
+template <typename Dtype>
+void ACLPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  PoolingLayer<Dtype>::Reshape(bottom, top);
+
+}
+
+template <typename Dtype>
+void ACLPoolingLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  if(Caffe::arm_gpu_mode()){
+      Forward_gpu(bottom, top);
+      return;
+  }         
+#ifdef USE_PROFILING
+  logtime_util log_time(ACL_POOLING_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_) {
+      PoolingLayer<Dtype>::Forward_cpu(bottom,top);
+      return;
+  }
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  if (this->layer_param_.pooling_param().pool()!=PoolingParameter_PoolMethod_MAX && 
+      this->layer_param_.pooling_param().pool()!=PoolingParameter_PoolMethod_AVE) {
+      PoolingLayer<Dtype>::Forward_cpu(bottom,top);
+      return ;
+  }
+  if (this->kernel_h_!=this->kernel_w_ || top.size()>1) {
+      PoolingLayer<Dtype>::Forward_cpu(bottom,top);
+      return ;
+  }
+  if (this->kernel_h_!=2 && this->kernel_h_!=3) {
+      PoolingLayer<Dtype>::Forward_cpu(bottom,top);
+      return ;
+  }
+  SetupACLLayer(bottom,top);
+  for (int n = 0; n < bottom[0]->num(); ++n) {
+    for (int c = 0; c < this->channels_; ++c) {
+        tensor_mem(this->cpu().input,(void*)(bottom_data));
+        cpu_run();
+        tensor_mem((void*)(top_data),this->cpu().output);
+        bottom_data += bottom[0]->offset(0, 1);
+        top_data += top[0]->offset(0, 1);
+    }
+  }
+}
+
+template <typename Dtype>
+void ACLPoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+  logtime_util log_time(ACL_POOLING_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_) {
+      PoolingLayer<Dtype>::Forward_cpu(bottom,top);
+      return;
+  }
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  if (this->layer_param_.pooling_param().pool()!=PoolingParameter_PoolMethod_MAX && 
+      this->layer_param_.pooling_param().pool()!=PoolingParameter_PoolMethod_AVE) {
+      PoolingLayer<Dtype>::Forward_cpu(bottom,top);
+      return ;
+  }
+  if (this->kernel_h_!=this->kernel_w_) {
+      PoolingLayer<Dtype>::Forward_cpu(bottom,top);
+      return ;
+  }
+  if (this->kernel_h_!=2 && this->kernel_h_!=3) {
+      PoolingLayer<Dtype>::Forward_cpu(bottom,top);
+      return ;
+  }
+  SetupACLLayer(bottom,top);
+  for (int n = 0; n < bottom[0]->num(); ++n) {
+    for (int c = 0; c < this->channels_; ++c) {
+        tensor_mem(this->gpu().input,(void*)(bottom_data));
+        gpu_run();
+        tensor_mem((void*)(top_data),this->gpu().output);
+        bottom_data += bottom[0]->offset(0, 1);
+        top_data += top[0]->offset(0, 1);
+    }
+  }
+}
+
+template <typename Dtype>
+ACLPoolingLayer<Dtype>::~ACLPoolingLayer() {
+}
+
+INSTANTIATE_CLASS(ACLPoolingLayer);
+
+}   // namespace caffe
+#endif  // USE_ACL
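
The pooling path is the most constrained of the ported layers: only MAX and AVE map to ACL, the kernel must be square with size 2 or 3, and a second top blob (the max-pool mask) forces a bypass; the loop then runs ACL once per (image, channel) plane. The guard, distilled (illustrative helper, not CaffeOnACL API):

    // True when ACLPoolingLayer keeps the ACL path, per Forward_cpu above.
    bool can_use_acl_pooling(bool is_max_or_ave, int kernel_h, int kernel_w,
                             int num_top_blobs) {
      if (!is_max_or_ave) return false;         // e.g. STOCHASTIC: bypass
      if (num_top_blobs > 1) return false;      // mask output not supported
      if (kernel_h != kernel_w) return false;   // ACL kernel size is square
      return kernel_h == 2 || kernel_h == 3;    // sizes wired up in this port
    }
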
diff --git a/src/caffe/layers/acl_relu_layer.cpp b/src/caffe/layers/acl_relu_layer.cpp
new file mode 100644 (file)
index 0000000..0319453
--- /dev/null
@@ -0,0 +1,70 @@
+#ifdef USE_ACL
+#include <vector>
+
+#include "caffe/layers/acl_relu_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  ReLULayer<Dtype>::LayerSetUp(bottom, top);
+  ACLBaseActivationLayer<Dtype>::LayerSetUp(bottom, top);
+  this->force_bypass_acl_path_= bypass_acl_class_layer & FLAGS_ENABLE_ACL_RELU;
+}
+template <typename Dtype>
+void ACLReLULayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top){
+    ACLBaseActivationLayer<Dtype>::SetupACLLayer(bottom, top,ActivationLayerInfo::ActivationFunction::RELU);
+}
+template <typename Dtype>
+void ACLReLULayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  ReLULayer<Dtype>::Reshape(bottom, top);
+  ACLBaseActivationLayer<Dtype>::Reshape(bottom, top);
+}
+
+template <typename Dtype>
+void ACLReLULayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+  logtime_util log_time(ACL_RELU_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_) {
+      ReLULayer<Dtype>::Forward_cpu(bottom,top);
+      return;
+  }
+  // Fallback to standard Caffe for leaky ReLU.
+  if (ReLULayer<Dtype>::layer_param_.relu_param().negative_slope() != 0) {
+    ReLULayer<Dtype>::Forward_cpu(bottom, top);
+    return;
+  }
+  ACLBaseActivationLayer<Dtype>::Forward_cpu(bottom,top);
+}
+
+template <typename Dtype>
+void ACLReLULayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+  logtime_util log_time(ACL_RELU_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_) {
+      ReLULayer<Dtype>::Forward_cpu(bottom,top);
+      return;
+  }
+  // Fallback to standard Caffe for leaky ReLU.
+  if (ReLULayer<Dtype>::layer_param_.relu_param().negative_slope() != 0) {
+    ReLULayer<Dtype>::Forward_cpu(bottom, top);
+    return;
+  }
+  ACLBaseActivationLayer<Dtype>::Forward_gpu(bottom,top);
+}
+
+template <typename Dtype>
+ACLReLULayer<Dtype>::~ACLReLULayer() {
+}
+
+INSTANTIATE_CLASS(ACLReLULayer);
+
+}  // namespace caffe
+#endif // USE_ACL
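
Every ACL layer above computes force_bypass_acl_path_ the same way: bypass_acl_class_layer is a runtime bitmask and the FLAGS_ENABLE_ACL_* constants are per-layer bits, so setting a layer's bit forces the original Caffe implementation. A sketch of the convention (bit assignments here are assumptions for illustration, not the values in acl_layer.hpp):

    #include <cstdint>

    enum : std::uint64_t {
      kBypassConv = 1u << 0,   // assumed bit positions, illustration only
      kBypassRelu = 1u << 1,
    };

    // Nonzero result -> take the stock Caffe ReLU path.
    bool relu_bypassed(std::uint64_t bypass_mask) {
      return (bypass_mask & kBypassRelu) != 0;
    }
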
diff --git a/src/caffe/layers/acl_sigmoid_layer.cpp b/src/caffe/layers/acl_sigmoid_layer.cpp
new file mode 100644 (file)
index 0000000..eac1565
--- /dev/null
@@ -0,0 +1,61 @@
+#ifdef USE_ACL
+#include <vector>
+
+#include "caffe/layers/acl_sigmoid_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLSigmoidLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  SigmoidLayer<Dtype>::LayerSetUp(bottom, top);
+  ACLBaseActivationLayer<Dtype>::LayerSetUp(bottom, top);
+  this->force_bypass_acl_path_= bypass_acl_class_layer & FLAGS_ENABLE_ACL_SIGMOID;
+}
+
+template <typename Dtype>
+void ACLSigmoidLayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top,ActivationLayerInfo::ActivationFunction type){
+    ACLBaseActivationLayer<Dtype>::SetupACLLayer(bottom, top,ActivationLayerInfo::ActivationFunction::LOGISTIC);
+}
+template <typename Dtype>
+void ACLSigmoidLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  SigmoidLayer<Dtype>::Reshape(bottom, top);
+  ACLBaseActivationLayer<Dtype>::Reshape(bottom, top);
+}
+
+template <typename Dtype>
+void ACLSigmoidLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_SIGMOID_INFO);
+#endif //USE_PROFILING
+    if (this->force_bypass_acl_path_) {
+        SigmoidLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+    }
+    ACLBaseActivationLayer<Dtype>::Forward_cpu(bottom,top);
+}
+
+template <typename Dtype>
+void ACLSigmoidLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_SIGMOID_INFO);
+#endif //USE_PROFILING
+    if (this->force_bypass_acl_path_) {
+        SigmoidLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+    }
+    ACLBaseActivationLayer<Dtype>::Forward_gpu(bottom,top);
+}
+
+template <typename Dtype>
+ACLSigmoidLayer<Dtype>::~ACLSigmoidLayer() {
+}
+
+INSTANTIATE_CLASS(ACLSigmoidLayer);
+
+}  // namespace caffe
+#endif // USE_ACL
diff --git a/src/caffe/layers/acl_softmax_layer.cpp b/src/caffe/layers/acl_softmax_layer.cpp
new file mode 100644 (file)
index 0000000..d32460b
--- /dev/null
@@ -0,0 +1,117 @@
+#ifdef USE_ACL
+#include <vector>
+#include "caffe/layers/acl_softmax_layer.hpp"
+#include <unistd.h>
+
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLSoftmaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  SoftmaxLayer<Dtype>::LayerSetUp(bottom, top);
+  this->force_bypass_acl_path_= bypass_acl_class_layer & FLAGS_ENABLE_ACL_SOFTMAX;
+}
+template <typename Dtype>
+void ACLSoftmaxLayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top){
+
+    unsigned int channels = bottom[0]->shape(this->softmax_axis_); 
+    TensorShape shape(channels*this->inner_num_);
+    checkreshape(shape,Caffe::arm_gpu_mode());
+    if (!this->init_layer_) return;
+    this->init_layer_=false;
+
+    // Initialize ACL.
+    if (Caffe::arm_gpu_mode()) {
+        new_gpulayer();
+    }else{
+        new_cpulayer();
+    }
+
+    //this->force_bypass_acl_path_=false;
+    if (Caffe::arm_gpu_mode()) {
+        Dtype *top_data = top[0]->mutable_gpu_data(); 
+        const Dtype* bottom_data = bottom[0]->gpu_data();
+        this->gpu().input=new_tensor<GPUTensor>(shape,(void*)bottom_data);
+        this->gpu().output=new_tensor<GPUTensor>(shape,(void*)top_data);
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->gpu().layer->configure(this->gpu().input,this->gpu().output);
+    }else{
+        Dtype *top_data = top[0]->mutable_cpu_data(); 
+        const Dtype* bottom_data = bottom[0]->cpu_data();
+        this->cpu().input=new_tensor<CPUTensor>(shape,(void*)bottom_data);
+        this->cpu().output=new_tensor<CPUTensor>(shape,(void*)top_data);
+#ifdef USE_PROFILING
+        logtime_util log_time(ACL_CONFIG_INFO);
+#endif //USE_PROFILING
+        this->cpu().layer->configure(this->cpu().input,this->cpu().output);
+    }
+}
+template <typename Dtype>
+void ACLSoftmaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  SoftmaxLayer<Dtype>::Reshape(bottom, top);
+}
+
+template <typename Dtype>
+void ACLSoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  if(Caffe::arm_gpu_mode()){
+      Forward_gpu(bottom, top);
+      return;
+  }         
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_SOFTMAX_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_ || this->inner_num_>1) {
+      SoftmaxLayer<Dtype>::Forward_cpu(bottom,top);
+      return ;
+  }
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  SetupACLLayer(bottom,top);
+
+  int channels = bottom[0]->shape(this->softmax_axis_);
+
+  for (int i = 0; i < this->outer_num_; ++i) {
+      tensor_mem(this->cpu().input,(void*)(bottom_data));
+      cpu_run();
+      tensor_mem((void*)(top_data),this->cpu().output);
+      top_data += channels;
+      bottom_data += channels;
+  }
+}
+
+template <typename Dtype>
+void ACLSoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+  logtime_util log_time(ACL_SOFTMAX_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_|| this->inner_num_>1) {
+        SoftmaxLayer<Dtype>::Forward_cpu(bottom,top);
+        return;
+  }
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  SetupACLLayer(bottom,top);
+  int channels = bottom[0]->shape(this->softmax_axis_);
+  for (int i = 0; i < this->outer_num_; ++i) {
+      tensor_mem(this->gpu().input,(void*)(bottom_data));
+      gpu_run();
+      tensor_mem((void*)(top_data),this->gpu().output);
+      // Step one full slice, as in Forward_cpu; inner_num_ is 1 on this
+      // path, so striding by inner_num_ alone would reprocess the first slice.
+      top_data += channels;
+      bottom_data += channels;
+  }
+}
+
+template <typename Dtype>
+ACLSoftmaxLayer<Dtype>::~ACLSoftmaxLayer() {
+}
+
+INSTANTIATE_CLASS(ACLSoftmaxLayer);
+}  // namespace caffe
+
+#endif  // USE_ACL
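
Both softmax paths only engage when inner_num_ == 1, so each of the outer_num_ slices is exactly one channels-long vector and the pointers step by channels per iteration (the tensor itself was shaped channels * inner_num_). A sketch of the slice walk (illustrative driver, not CaffeOnACL API):

    // Run one softmax per outer slice; valid because inner_num_ == 1 here.
    void softmax_slices(const float* bottom, float* top, int outer_num,
                        int channels,
                        void (*run_one)(const float*, float*, int)) {
      for (int i = 0; i < outer_num; ++i) {
        run_one(bottom, top, channels);
        bottom += channels;   // advance one full slice, not inner_num_
        top += channels;
      }
    }
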
diff --git a/src/caffe/layers/acl_tanh_layer.cpp b/src/caffe/layers/acl_tanh_layer.cpp
new file mode 100644 (file)
index 0000000..a1bb632
--- /dev/null
@@ -0,0 +1,63 @@
+#ifdef USE_ACL
+#include <vector>
+
+#include "caffe/layers/acl_tanh_layer.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void ACLTanHLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  TanHLayer<Dtype>::LayerSetUp(bottom, top);
+  ACLBaseActivationLayer<Dtype>::LayerSetUp(bottom, top);
+  this->force_bypass_acl_path_= bypass_acl_class_layer & FLAGS_ENABLE_ACL_TANH;
+}
+
+template <typename Dtype>
+void ACLTanHLayer<Dtype>::SetupACLLayer(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top, ActivationLayerInfo::ActivationFunction type){
+    ACLBaseActivationLayer<Dtype>::SetupACLLayer(bottom, top,ActivationLayerInfo::ActivationFunction::TANH);
+}
+
+template <typename Dtype>
+void ACLTanHLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  TanHLayer<Dtype>::Reshape(bottom, top);
+  ACLBaseActivationLayer<Dtype>::Reshape(bottom, top);
+}
+
+template <typename Dtype>
+void ACLTanHLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+  logtime_util log_time(ACL_TANH_INFO);
+#endif //USE_PROFILING
+  if (this->force_bypass_acl_path_) {
+       TanHLayer<Dtype>::Forward_cpu(bottom,top);
+       return;
+  }
+  ACLBaseActivationLayer<Dtype>::Forward_cpu(bottom,top);
+}
+
+template <typename Dtype>
+void ACLTanHLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+#ifdef USE_PROFILING
+    logtime_util log_time(ACL_TANH_INFO);
+#endif //USE_PROFILING
+    if (this->force_bypass_acl_path_) {
+         TanHLayer<Dtype>::Forward_cpu(bottom,top);
+         return;
+    }
+    ACLBaseActivationLayer<Dtype>::Forward_gpu(bottom,top);
+}
+
+template <typename Dtype>
+ACLTanHLayer<Dtype>::~ACLTanHLayer() {
+}
+
+INSTANTIATE_CLASS(ACLTanHLayer);
+
+}  // namespace caffe
+
+#endif  // USE_ACL
index 448d86d..5293373 100644 (file)
@@ -42,6 +42,8 @@ STUB_GPU(BNLLLayer);
 #endif
 
 INSTANTIATE_CLASS(BNLLLayer);
+#ifndef USE_ACL
 REGISTER_LAYER_CLASS(BNLL);
+#endif
 
 }  // namespace caffe
index 00716a9..7668854 100644 (file)
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
 /*
 TODO:
 - load file in a separate thread ("prefetch")
@@ -184,3 +185,4 @@ INSTANTIATE_CLASS(HDF5DataLayer);
 REGISTER_LAYER_CLASS(HDF5Data);
 
 }  // namespace caffe
+#endif  // USE_HDF5
index 33eebd4..70cd9f3 100644 (file)
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
 /*
 TODO:
 - only load parts of the file, in accordance with a prototxt param "max_mem"
@@ -34,3 +35,4 @@ void HDF5DataLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 INSTANTIATE_LAYER_GPU_FUNCS(HDF5DataLayer);
 
 }  // namespace caffe
+#endif  // USE_HDF5
index f8f1edc..28c453a 100644 (file)
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
 #include <vector>
 
 #include "hdf5.h"
@@ -72,3 +73,4 @@ INSTANTIATE_CLASS(HDF5OutputLayer);
 REGISTER_LAYER_CLASS(HDF5Output);
 
 }  // namespace caffe
+#endif  // USE_HDF5
index c1685cd..891aea0 100644 (file)
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
 #include <vector>
 
 #include "hdf5.h"
@@ -37,3 +38,4 @@ void HDF5OutputLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
 INSTANTIATE_LAYER_GPU_FUNCS(HDF5OutputLayer);
 
 }  // namespace caffe
+#endif  // USE_HDF5
index e65349f..be3791e 100644 (file)
@@ -145,6 +145,8 @@ STUB_GPU(InnerProductLayer);
 #endif
 
 INSTANTIATE_CLASS(InnerProductLayer);
+#ifndef USE_ACL
 REGISTER_LAYER_CLASS(InnerProduct);
+#endif
 
 }  // namespace caffe
index 353c2f9..7bf33e1 100644 (file)
@@ -5,7 +5,9 @@
 #include <utility>
 #include <vector>
 
+#ifdef USE_HDF5
 #include "hdf5.h"
+#endif  // USE_HDF5
 
 #include "caffe/common.hpp"
 #include "caffe/layer.hpp"
@@ -769,7 +771,12 @@ void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
 
 template <typename Dtype>
 void Net<Dtype>::CopyTrainedLayersFrom(const string trained_filename) {
+#ifdef USE_HDF5
   if (H5Fis_hdf5(trained_filename.c_str())) {
+#else
+  if (trained_filename.size() >= 3 &&
+      trained_filename.compare(trained_filename.size() - 3, 3, ".h5") == 0) {
+#endif
     CopyTrainedLayersFromHDF5(trained_filename);
   } else {
     CopyTrainedLayersFromBinaryProto(trained_filename);
@@ -786,6 +793,7 @@ void Net<Dtype>::CopyTrainedLayersFromBinaryProto(
 
 template <typename Dtype>
 void Net<Dtype>::CopyTrainedLayersFromHDF5(const string trained_filename) {
+#ifdef USE_HDF5
   hid_t file_hid = H5Fopen(trained_filename.c_str(), H5F_ACC_RDONLY,
                            H5P_DEFAULT);
   CHECK_GE(file_hid, 0) << "Couldn't open " << trained_filename;
@@ -832,6 +840,10 @@ void Net<Dtype>::CopyTrainedLayersFromHDF5(const string trained_filename) {
   }
   H5Gclose(data_hid);
   H5Fclose(file_hid);
+#else
+  LOG(FATAL) << "CopyTrainedLayersFromHDF5 requires hdf5;"
+             << " compile with USE_HDF5.";
+#endif  // USE_HDF5
 }
 
 template <typename Dtype>
@@ -848,6 +860,7 @@ void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) const {
 
 template <typename Dtype>
 void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {
+#ifdef USE_HDF5
   hid_t file_hid = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, H5P_DEFAULT,
       H5P_DEFAULT);
   CHECK_GE(file_hid, 0)
@@ -901,6 +914,9 @@ void Net<Dtype>::ToHDF5(const string& filename, bool write_diff) const {
     H5Gclose(diff_hid);
   }
   H5Fclose(file_hid);
+#else
+  LOG(FATAL) << "ToHDF5 requires hdf5; compile with USE_HDF5.";
+#endif  // USE_HDF5
 }
 
 template <typename Dtype>
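
When hdf5 is compiled out, the H5Fis_hdf5() probe above is replaced by a plain file-name test, so an HDF5 snapshot without a ".h5" suffix would be routed to the protobuf loader (and then fail to parse). The test, standalone:

    #include <string>

    // Mirrors the #else branch above: true iff the name ends in ".h5".
    bool looks_like_hdf5(const std::string& name) {
      return name.size() >= 3 &&
             name.compare(name.size() - 3, 3, ".h5") == 0;
    }
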
index ad6abe5..725602a 100644 (file)
@@ -278,6 +278,7 @@ void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
 template <typename Dtype>
 void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
     const string& model_filename) {
+#ifdef USE_HDF5
   string snapshot_filename =
       Solver<Dtype>::SnapshotFilename(".solverstate.h5");
   LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
@@ -299,6 +300,10 @@ void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
   }
   H5Gclose(history_hid);
   H5Fclose(file_hid);
+#else
+  LOG(FATAL) << "SnapshotSolverStateToHDF5 requires hdf5;"
+             << " compile with USE_HDF5.";
+#endif  // USE_HDF5
 }
 
 template <typename Dtype>
@@ -323,6 +328,7 @@ void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
 
 template <typename Dtype>
 void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
+#ifdef USE_HDF5
   hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
   CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
   this->iter_ = hdf5_load_int(file_hid, "iter");
@@ -344,6 +350,10 @@ void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
   }
   H5Gclose(history_hid);
   H5Fclose(file_hid);
+#else
+  LOG(FATAL) << "RestoreSolverStateFromHDF5 requires hdf5;"
+             << " compile with USE_HDF5.";
+#endif  // USE_HDF5
 }
 
 INSTANTIATE_CLASS(SGDSolver);
index 88d9b78..d74ad31 100644 (file)
@@ -54,8 +54,10 @@ inline void SyncedMemory::to_cpu() {
     caffe_gpu_memcpy(size_, gpu_ptr_, cpu_ptr_);
     head_ = SYNCED;
 #else
+#ifndef USE_ACL
     NO_GPU;
 #endif
+#endif
     break;
   case HEAD_AT_CPU:
   case SYNCED:
@@ -113,9 +115,14 @@ const void* SyncedMemory::gpu_data() {
   to_gpu();
   return (const void*)gpu_ptr_;
 #else
+#ifdef USE_ACL
+  to_cpu();
+  return (const void*)cpu_ptr_;
+#else
   NO_GPU;
   return NULL;
 #endif
+#endif
 }
 
 void SyncedMemory::set_gpu_data(void* data) {
@@ -129,8 +136,14 @@ void SyncedMemory::set_gpu_data(void* data) {
   head_ = HEAD_AT_GPU;
   own_gpu_data_ = false;
 #else
+#ifdef USE_ACL
+  gpu_ptr_ = data;
+  head_ = HEAD_AT_GPU;
+  own_gpu_data_ = false;
+#else
   NO_GPU;
 #endif
+#endif
 }
 
 void* SyncedMemory::mutable_cpu_data() {
@@ -147,9 +160,15 @@ void* SyncedMemory::mutable_gpu_data() {
   head_ = HEAD_AT_GPU;
   return gpu_ptr_;
 #else
+#ifdef USE_ACL
+  to_cpu();
+  head_ = HEAD_AT_GPU;
+  return cpu_ptr_;
+#else
   NO_GPU;
   return NULL;
 #endif
+#endif
 }
 
 #ifndef CPU_ONLY
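
The SyncedMemory changes above make a CPU_ONLY + USE_ACL build treat "GPU" data as an alias of host memory: gpu_data() and mutable_gpu_data() synchronize to the CPU buffer and return it, since ACL's OpenCL tensors do their own device mapping. A minimal stand-in showing the aliasing contract (not the real SyncedMemory):

    #include <cassert>
    #include <cstdlib>

    class HostOnlySyncedMemory {
     public:
      explicit HostOnlySyncedMemory(std::size_t size)
          : cpu_ptr_(std::malloc(size)) {}
      ~HostOnlySyncedMemory() { std::free(cpu_ptr_); }
      const void* cpu_data() const { return cpu_ptr_; }
      const void* gpu_data() const { return cpu_ptr_; }  // alias, as in the patch
     private:
      void* cpu_ptr_;
    };

    int main() {
      HostOnlySyncedMemory mem(64);
      assert(mem.cpu_data() == mem.gpu_data());  // one backing store
      return 0;
    }
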
index 2bc2de1..3b0139d 100644 (file)
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
 #include <string>
 #include <vector>
 
@@ -121,3 +122,4 @@ TYPED_TEST(HDF5OutputLayerTest, TestForward) {
 }
 
 }  // namespace caffe
+#endif //USE_HDF5
\ No newline at end of file
index 487f517..7d27d77 100644 (file)
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
 #include <string>
 #include <vector>
 
@@ -164,3 +165,4 @@ TYPED_TEST(HDF5DataLayerTest, TestSkip) {
 }
 
 }  // namespace caffe
+#endif //USE_HDF5
\ No newline at end of file
index ed73742..cefd853 100644 (file)
@@ -1,3 +1,4 @@
+#ifdef USE_HDF5
 #include "caffe/util/hdf5.hpp"
 
 #include <string>
@@ -207,3 +208,4 @@ string hdf5_get_name_by_idx(hid_t loc_id, int idx) {
 }
 
 }  // namespace caffe
+#endif  // USE_HDF5
index 71c0227..b0dd302 100644 (file)
@@ -85,6 +85,9 @@ void caffe_add_scalar(const int N, const double alpha, double* Y) {
 template <typename Dtype>
 void caffe_copy(const int N, const Dtype* X, Dtype* Y) {
   if (X != Y) {
+#ifdef USE_ACL
+      memcpy(Y, X, sizeof(Dtype) * N);  // NOLINT(caffe/alt_fn)
+#else
     if (Caffe::mode() == Caffe::GPU) {
 #ifndef CPU_ONLY
       // NOLINT_NEXT_LINE(caffe/alt_fn)
@@ -95,6 +98,7 @@ void caffe_copy(const int N, const Dtype* X, Dtype* Y) {
     } else {
       memcpy(Y, X, sizeof(Dtype) * N);  // NOLINT(caffe/alt_fn)
     }
+#endif
   }
 }
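
With USE_ACL defined, caffe_copy() loses its GPU branch entirely: every copy becomes a host memcpy, which is safe because gpu_data() aliases host memory in this build (see the syncedmem.cpp hunk above). Equivalent standalone form:

    #include <cstring>

    // What caffe_copy() reduces to under USE_ACL.
    template <typename Dtype>
    void acl_caffe_copy(int n, const Dtype* x, Dtype* y) {
      if (x != y) std::memcpy(y, x, sizeof(Dtype) * n);
    }
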
 
diff --git a/unit_tests/Makefile b/unit_tests/Makefile
new file mode 100644 (file)
index 0000000..034ee02
--- /dev/null
@@ -0,0 +1,87 @@
+
+include ../Makefile.config
+
+CXX_SRCS+=test_pooling_layer.cpp
+CXX_SRCS+=test_softmax_layer.cpp
+CXX_SRCS+= test_inner_product_layer.cpp
+CXX_SRCS+=test_neuron_layer.cpp
+CXX_SRCS+=test_lrn_layer.cpp  #fails on single-channel LRN
+#C_SRCS=pmu.c testbed.c
+CXX_SRCS+= test_convolution_layer.cpp  
+#CXX_SRCS= test_fail.cpp
+BIN_SRCS=test_caffe_main.cpp test.cpp
+
+HOME=/home/firefly
+#
+#  caffe related stuff
+#
+
+CAFFE_ROOT=$(HOME)/caffeOnACL
+CAFFE_INCS = -I$(CAFFE_ROOT)/include -I$(CAFFE_ROOT)/distribute/include/
+CAFFE_LIBS = -L$(CAFFE_ROOT)/distribute/lib -lcaffe  -lglog -lgflags -lprotobuf -lboost_system -lboost_filesystem
+CAFFE_RPATH =$(CAFFE_ROOT)/distribute/lib
+
+
+#
+# google test related stuff
+#
+
+GTEST_ROOT=/usr/local
+GTEST_LIBS=$(GTEST_ROOT)/lib/libgtest_main.a $(GTEST_ROOT)/lib/libgtest.a
+GTEST_INCS=-I$(GTEST_ROOT)/include/
+
+
+#
+# arm compute
+#
+
+ACL_ROOT=$(HOME)/ComputeLibrary
+ACL_LIBS=-L$(ACL_ROOT)/build -L$(ACL_ROOT)/build/arm_compute -larm_compute  -lOpenCL
+ACL_RPATH=$(ACL_ROOT)/build:$(ACL_ROOT)/build/arm_compute
+
+COMM_CFLAGS=$(GTEST_INCS) $(CAFFE_INCS) -Wall -g  -DCPU_ONLY -DUSE_ACL 
+# USE_PROFILING -- gather profiling information; controlled by LOGACL
+# LAYER_PERF_STAT -- Haitao's per-net profiling information
+ifeq ($(USE_PROFILING), 1)
+        COMM_CFLAGS += -DUSE_PROFILING -DLAYER_PERF_STAT
+endif
+
+CXXFLAGS=$(COMM_CFLAGS) -Wno-sign-compare
+CFLAGS=$(COMM_CFLAGS)
+
+CC=gcc
+CXX=g++
+
+COMM_OBJS=$(CXX_SRCS:.cpp=.o) $(C_SRCS:.c=.o)
+BIN_OBJS+=$(BIN_SRCS:.cpp=.o)
+BIN_EXES=$(BIN_OBJS:.o=)
+
+LIBS+=$(CAFFE_LIBS) $(GTEST_LIBS) -lpthread -lopenblas $(ACL_LIBS)
+
+RT_PATH=-Wl,-rpath,$(CAFFE_RPATH):$(ACL_RPATH)
+
+LDFLAGS+=$(RT_PATH)
+
+
+all : $(BIN_EXES)
+
+$(BIN_EXES):%:%.o
+
+$(BIN_EXES):$(COMM_OBJS)
+
+
+
+clean:
+       rm -f $(BIN_EXES) *.o *.so
+
+.PHONY : all clean
+
+%.o : %.c
+       $(CC) $(CFLAGS)  -c $< -o $@ 
+
+%.o : %.cpp
+       $(CXX) $(CXXFLAGS)  -c $< -o $@ 
+
+%: %.o
+       $(CXX) $(LDFLAGS) $< $(COMM_OBJS) -o $@ $(LIBS)
+
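
A typical build of these tests, assuming the hard-coded layout above, is `make` (or `make USE_PROFILING=1` to compile in the PMU and LAYER_PERF_STAT instrumentation). Note that HOME is pinned to /home/firefly, so CAFFE_ROOT, ACL_ROOT and GTEST_ROOT must be adjusted for any other checkout layout.
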
diff --git a/unit_tests/pmu.c b/unit_tests/pmu.c
new file mode 100644 (file)
index 0000000..8c2f6b3
--- /dev/null
@@ -0,0 +1,376 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "pmu.h"
+
+#define MAX_SLOT_NUM           64      /* defined by the ARMv8 spec */
+#define ARMV8_PMCR_N_SHIFT      11      /* Number of counters supported */
+#define ARMV8_PMCR_N_MASK       0x1f
+
+
+static __thread  struct pmu_event_record * g_rec_ptr[MAX_SLOT_NUM];
+static __thread  int max_counter_slot;
+
+/* start and stop counter */
+
+static void stop_event_profile(struct pmu_event_record * p_record)
+{
+       int slot=p_record->p_evt->slot;
+
+       p_record->p_evt->enabled=0;
+
+       stop_pmu_counter(slot);
+}
+
+
+static void init_pmu_event_record(struct pmu_event * p_evt, struct pmu_event_record * p_record)
+{
+       struct prof_stat *p_stat;
+       int i=0;
+
+       p_record->p_evt=p_evt;
+       p_record->last_val=p_evt->init_val;
+       p_record->base_val=p_evt->init_val;
+       p_stat=p_record->prof_stat;
+
+       for(i=0;i<MAX_PROF_POINTS;i++)
+       {
+               p_stat[i].prof_seq=i;
+               p_stat[i].max_val=0;
+               p_stat[i].min_val=-1U;
+               p_stat[i].raw_val=0xdeadbeef;
+               p_stat[i].cur_val=0;
+               p_stat[i].total_val=0;
+               p_stat[i].enter_count=0;         
+       }
+}
+
+static void start_event_profile(struct pmu_event_record * p_record)
+{
+       int slot=p_record->p_evt->slot;
+       struct prof_stat *p_stat;
+        int i;
+
+       p_record->p_evt->enabled=1;
+
+       p_stat=p_record->prof_stat;
+
+       for(i=0;i<MAX_PROF_POINTS;i++)
+       {
+               p_stat[i].prof_seq=i;
+               p_stat[i].max_val=0;
+               p_stat[i].min_val=-1U;
+               p_stat[i].raw_val=0xdeadbeef;
+               p_stat[i].cur_val=0;
+               p_stat[i].total_val=0;
+               p_stat[i].enter_count=0;         
+        }
+
+       write_pmu_counter(slot,p_record->p_evt->init_val);
+       start_pmu_counter(slot);
+
+}
+
+/* create event and profile */
+
+
+int setup_event_counter(int slot, int event_id)
+{
+
+       if(slot==31)    /* slot 31 is the fixed cycle counter; nothing to program */
+               return 0;
+
+       if(event_id>1023)
+               return -1;
+
+       write_32bit_sysreg(PMSELR_EL0,slot);
+       write_32bit_sysreg(PMXEVTYPER_EL0,event_id);
+
+       return 0;
+}
+
+
+static struct pmu_event_record * create_pmu_event_record(char *name, int slot, 
+               int event_id, uint32_t init_val, char * note)
+{
+       struct pmu_event * p_evt;
+       struct pmu_event_record * p_record;
+
+       if(setup_event_counter(slot,event_id)<0)
+               return NULL;
+
+       p_evt=malloc(sizeof(struct pmu_event));
+
+       if(p_evt==NULL)
+               return NULL;
+
+       p_evt->name=name;
+       p_evt->slot=slot;
+       p_evt->event_id=event_id;
+       p_evt->init_val=init_val;
+       p_evt->note=note;
+       p_evt->enabled=0;
+
+       p_record=malloc(sizeof(struct pmu_event_record));
+
+       if(p_record==NULL)
+       {
+               free(p_evt);
+               return NULL;
+       }
+
+       p_record->p_evt=p_evt;
+
+       init_pmu_event_record(p_evt,p_record);
+
+       return p_record;
+}
+
+
+static void record_event_prof(struct pmu_event_record * p_record, 
+               int prof_seq, int cal_offset, int update_last)
+{
+       struct prof_stat * p_stat;
+       uint32_t evt_val;
+
+       evt_val=read_pmu_counter(p_record->p_evt->slot);
+
+       p_stat=&p_record->prof_stat[prof_seq];
+
+       p_stat->cal_offset=cal_offset;
+       p_stat->update_last=update_last;
+       p_stat->raw_val=evt_val;
+
+       if(cal_offset)
+               p_stat->cur_val=evt_val-p_record->last_val;
+       else
+               p_stat->cur_val=evt_val-p_record->base_val;
+
+       if(update_last)
+               p_record->last_val=evt_val;
+
+       p_stat->total_val+=p_stat->cur_val;
+
+       if(p_stat->cur_val>p_stat->max_val)
+               p_stat->max_val=p_stat->cur_val;
+
+       if(p_stat->cur_val<p_stat->min_val)
+               p_stat->min_val=p_stat->cur_val;
+
+       p_stat->enter_count++;
+}
+
+
+static void release_pmu_event_record(struct pmu_event_record * p_record)
+{
+     struct pmu_event * p_evt;
+
+     p_evt=p_record->p_evt;
+
+     if(p_evt->enabled)
+         stop_pmu_counter(p_evt->slot);
+
+     free(p_evt);
+     free(p_record);
+}
+
+
+
+
+/* debugging */
+
+static void dump_pmu_event(struct pmu_event * p_evt)
+{
+        
+       printf("event[%s/0x%x]: slot [%d] init_val[0x%x] enabled[%d]",
+                       p_evt->name,p_evt->event_id,p_evt->slot,p_evt->init_val,
+                       p_evt->enabled);
+
+       if(p_evt->note)
+               printf(" note[%s]\n",p_evt->note);
+       else
+               printf("\n");
+}
+
+static void dump_pmu_event_record(struct pmu_event_record * p_record)
+{
+       int i;
+       struct prof_stat * p_stat;
+        uint64_t total_avg_val=0;
+        int count=0;
+        uint32_t avg;
+
+        printf("------------------------------------------------------------------------\n");
+
+       dump_pmu_event(p_record->p_evt);
+
+       p_stat=&p_record->prof_stat[0];
+
+       for(i=0;i<MAX_PROF_POINTS;i++)
+       {
+               if(p_stat[i].enter_count==0)
+                       continue;
+
+                avg=(uint32_t)(p_stat[i].total_val/p_stat[i].enter_count);
+
+               printf("stat [%d]: max/min/avg [0x%x/0x%x/0x%x] total [0x%lx] count[%u]\n",
+                               i,p_stat[i].max_val,p_stat[i].min_val,
+                               avg,
+                               p_stat[i].total_val,p_stat[i].enter_count);
+               printf("         raw_val[0x%x] cal_offset[%d] update_last[%d]\n",
+                               p_stat[i].raw_val,p_stat[i].cal_offset,p_stat[i].update_last);
+
+                count++;
+
+                total_avg_val+=avg;
+
+       }
+
+        printf("total [%d] points, sum of per-point averages: [0x%lx]\n\n",count,total_avg_val);
+}
+
+/* output interface */
+
+void init_pmu_registers(void)
+{
+       /* enable the PMU: set the E (enable) bit in PMCR_EL0 */
+       write_32bit_sysreg(PMCR_EL0,0x1);
+       max_counter_slot=(read_32bit_sysreg(PMCR_EL0) >> ARMV8_PMCR_N_SHIFT)&ARMV8_PMCR_N_MASK;
+}
+
+
+#define dump_32bit_sysreg(reg) \
+       printf(__stringify(reg) " is [0x%08x]\n",read_32bit_sysreg(reg))
+
+#define dump_64bit_sysreg(reg) \
+       printf(__stringify(reg) " is [0x%016llx]\n",read_64bit_sysreg(reg))
+
+
+void dump_pmu_registers(void)
+{
+       dump_32bit_sysreg(PMCEID0_EL0);
+       dump_32bit_sysreg(PMCEID1_EL0);
+       dump_32bit_sysreg(PMOVSSET_EL0);
+       dump_32bit_sysreg(PMCR_EL0);
+       dump_32bit_sysreg(PMUSERENR_EL0);
+        dump_32bit_sysreg(PMCNTENSET_EL0);
+}
+
+struct pmu_event_record *  get_pmu_event_record(int slot)
+{
+       return  g_rec_ptr[slot];
+}
+
+int create_pmu_event(char *name,int event_id,
+               uint32_t init_val, char * note)
+{
+       int i;
+
+       struct pmu_event_record * p_record;
+
+       for(i=0;i<max_counter_slot;i++)
+       {
+               if(g_rec_ptr[i]==NULL)
+                       break;
+       }
+
+       if(i==max_counter_slot)
+               return -1;
+
+       p_record=create_pmu_event_record(name,i,event_id,init_val,note);
+
+       if(p_record==NULL)
+               return -1;
+
+       g_rec_ptr[i]=p_record;
+
+       return i;
+}
+
+void release_pmu_event(int slot)
+{
+       struct pmu_event_record * p_record;
+
+       p_record=g_rec_ptr[slot];
+
+       if(p_record)
+               release_pmu_event_record(p_record);
+
+       g_rec_ptr[slot]=NULL;
+}
+
+void start_pmu_event(int slot)
+{
+       struct pmu_event_record * p_record;
+
+       p_record=g_rec_ptr[slot];
+
+       start_event_profile(p_record);
+}
+
+void stop_pmu_event(int slot)
+{
+       struct pmu_event_record * p_record;
+
+       p_record=g_rec_ptr[slot];
+
+       stop_event_profile(p_record);
+}
+
+void record_pmu_event(int slot, int seq, int cal_offset, int update_last)
+{
+       struct pmu_event_record * p_record;
+
+       p_record=g_rec_ptr[slot];
+
+       record_event_prof(p_record,seq,cal_offset,update_last);
+}
+
+void dump_pmu_event_stat(int slot)
+{
+       struct pmu_event_record * p_record;
+
+       p_record=g_rec_ptr[slot];
+
+       dump_pmu_event_record(p_record);
+}
+
+
+uint32_t get_pmu_stat_avg(int slot)
+{
+       struct pmu_event_record * p_record;
+        struct prof_stat * p_stat;
+        uint32_t total_avg=0;
+        uint32_t avg;
+        int i;
+
+       p_record=g_rec_ptr[slot];
+
+        for(i=0;i<MAX_PROF_POINTS;i++)
+        {
+           p_stat=&p_record->prof_stat[i];
+
+           if(p_stat->enter_count==0)
+                 continue;
+            avg=p_stat->total_val/p_stat->enter_count;
+            total_avg+=avg;
+        }
+        return total_avg;
+}
+
+void set_pmu_event_base(int slot)
+{
+      struct pmu_event_record * p_record;
+
+      uint32_t val;
+
+      p_record=g_rec_ptr[slot];
+      val=read_pmu_counter(slot);
+
+       p_record->last_val=val;
+       p_record->base_val=val;
+
+}
diff --git a/unit_tests/pmu.h b/unit_tests/pmu.h
new file mode 100644 (file)
index 0000000..9c25206
--- /dev/null
@@ -0,0 +1,130 @@
+#ifndef AARCH64_PMU_H
+#define AARCH64_PMU_H
+
+#include <stdint.h>
+
+
+#define __stringify_1(x...)    #x
+#define __stringify(x...)      __stringify_1(x)
+
+#define read_32bit_sysreg(reg) \
+    ({\
+       uint32_t val;\
+       __asm__ __volatile__ (\
+          "mrs %0," __stringify(reg):"=r"(val));\
+        val;\
+    })
+
+
+#define read_64bit_sysreg(reg) \
+    ({\
+       uint64_t val;\
+       __asm__ __volatile__ (\
+          "mrs %0," __stringify(reg):"=r"(val));\
+        val;\
+    })
+
+#define write_32bit_sysreg(reg,val) \
+    ({\
+         __asm__ __volatile__ (\
+          "msr " __stringify(reg) " ,%0"::"r"(val));\
+      })
+
+#define write_64bit_sysreg(reg,val) write_32bit_sysreg(reg,val)
+
+#define MAX_PROF_POINTS 16
+
+struct pmu_event
+{
+  int  slot;
+  int  event_id;
+  char * name;
+  uint32_t init_val;
+  int enabled;
+  char * note;
+};
+
+struct  prof_stat
+{
+   int     prof_seq;
+   uint32_t max_val;
+   uint32_t min_val;
+   uint32_t cur_val;
+   uint32_t raw_val;
+   uint64_t total_val;
+   uint32_t enter_count;
+   int  cal_offset;
+   int  update_last;
+};
+
+
+struct pmu_event_record
+{
+  struct pmu_event*  p_evt;
+  uint32_t last_val;
+  uint32_t base_val;
+  struct prof_stat prof_stat[MAX_PROF_POINTS];
+};
+
+/* all functions in the group must be called on the same CPU */
+
+extern  void init_pmu_registers(void);
+extern void dump_pmu_registers(void);
+
+/* create one event with event_id, return slot number in success */
+extern int create_pmu_event(char *name,int event_id, 
+                          uint32_t init_val, char * note);
+
+extern void release_pmu_event(int slot);
+
+extern void start_pmu_event(int slot);
+
+extern void stop_pmu_event(int slot);
+
+extern void set_pmu_event_base(int slot);
+
+extern void record_pmu_event(int slot, int seq, int cal_offset, int update_last);
+
+extern void dump_pmu_event_stat(int slot);
+
+extern struct pmu_event_record *  get_pmu_event_record(int slot);
+
+extern uint32_t get_pmu_stat_avg(int slot); /* adding all phase avg together */
+
+/* register-level interface */
+
+extern int setup_event_counter(int slot, int event_id);
+
+static inline void start_pmu_counter(int slot)
+{
+           uint32_t mask=1<<slot;
+
+           write_32bit_sysreg(PMCNTENSET_EL0,mask);
+}
+
+static inline void stop_pmu_counter(int slot)
+{
+   uint32_t mask=1<<slot;
+
+   write_32bit_sysreg(PMCNTENCLR_EL0,mask);
+}
+
+
+static inline void write_pmu_counter(int slot,uint32_t val)
+{
+   write_32bit_sysreg(PMSELR_EL0,slot);
+
+   if(slot<31)
+      write_32bit_sysreg(PMXEVCNTR_EL0, val);
+   else
+      write_64bit_sysreg(PMXEVCNTR_EL0,val);
+
+}
+
+static inline uint32_t read_pmu_counter(int slot)
+{
+   write_32bit_sysreg(PMSELR_EL0,slot);
+   return read_32bit_sysreg(PMXEVCNTR_EL0);
+}
+
+#endif
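
Putting the two files together, the intended flow is: initialize once per CPU, create an event in a free counter slot, bracket the measured region with start/record calls, then dump the per-point statistics. A sketch (event 0x11 is the ARMv8 CPU_CYCLES event; this assumes the kernel has granted EL0 access to the PMU registers, otherwise the mrs/msr instructions trap):

    extern "C" {
    #include "pmu.h"
    }

    int main() {
      init_pmu_registers();                 // set PMCR.E, read slot count
      int slot = create_pmu_event(const_cast<char*>("cycles"), 0x11,
                                  0, const_cast<char*>("measured loop"));
      if (slot < 0) return 1;               // no free counter slot
      start_pmu_event(slot);
      for (volatile int i = 0; i < 1000000; ++i) {}   // region under test
      record_pmu_event(slot, /*seq=*/0, /*cal_offset=*/0, /*update_last=*/1);
      stop_pmu_event(slot);
      dump_pmu_event_stat(slot);            // max/min/avg per profiling point
      release_pmu_event(slot);
      return 0;
    }
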
diff --git a/unit_tests/prof_convolution_layer.cpp b/unit_tests/prof_convolution_layer.cpp
new file mode 100644 (file)
index 0000000..01f4fcb
--- /dev/null
@@ -0,0 +1,302 @@
+#include <vector>
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layers/conv_layer.hpp"
+
+#include <glog/logging.h>
+
+extern "C" {
+#include "testbed.h"
+}
+
+
+// Stub out the gtest macros so the test bodies below compile standalone as
+// profiling drivers: TYPED_TEST becomes a plain member definition and the
+// EXPECT_* checks become no-ops.
+#define TYPED_TEST(a,b) template <typename TypeParam> void a <TypeParam>:: b (void)
+#define EXPECT_NEAR(a,b,c) {}
+#define EXPECT_EQ(a,b) {}
+
+namespace caffe {
+
+template <typename TypeParam>
+struct CPUDevice {
+  typedef TypeParam Dtype;
+  static const Caffe::Brew device = Caffe::CPU;
+};
+
+
+template <typename TypeParam>
+class ConvolutionLayerTest {
+  typedef typename TypeParam::Dtype Dtype;
+
+public:
+
+  void TestSimpleConvolution(void);
+
+  void TestDilatedConvolution(void);
+
+  void Test0DConvolution(void);
+
+  void TestSimple3DConvolution(void);
+
+  void TestDilated3DConvolution(void);
+
+  void Test1x1Convolution(void);
+
+  void TestSimpleConvolutionGroup(void);
+  
+  void TestNDAgainst2D(void);
+
+  void RunConvolution(void);
+
+  ConvolutionLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 6, 4)),
+        blob_bottom_2_(new Blob<Dtype>(2, 3, 6, 4)),
+        blob_top_(new Blob<Dtype>()),
+        blob_top_2_(new Blob<Dtype>()) {}
+  virtual void SetUp() {
+    // fill the values
+    FillerParameter filler_param;
+    filler_param.set_value(1.);
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    filler.Fill(this->blob_bottom_2_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+
+  virtual ~ConvolutionLayerTest() {
+    delete blob_bottom_;
+    delete blob_bottom_2_;
+    delete blob_top_;
+    delete blob_top_2_;
+  }
+
+  virtual Blob<Dtype>* MakeReferenceTop(Blob<Dtype>* top) {
+    this->ref_blob_top_.reset(new Blob<Dtype>());
+    this->ref_blob_top_->ReshapeLike(*top);
+    return this->ref_blob_top_.get();
+  }
+
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_bottom_2_;
+  Blob<Dtype>* const blob_top_;
+  Blob<Dtype>* const blob_top_2_;
+  shared_ptr<Blob<Dtype> > ref_blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+  shared_ptr<Layer<Dtype> > layer;
+};
+
+TYPED_TEST(ConvolutionLayerTest, RunConvolution) {
+
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+
+TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+  layer=shared_ptr<Layer<Dtype> > (new ConvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+
+TYPED_TEST(ConvolutionLayerTest, TestDilatedConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  vector<int> bottom_shape;
+  bottom_shape.push_back(2);
+  bottom_shape.push_back(3);
+  bottom_shape.push_back(8);
+  bottom_shape.push_back(7);
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_dilation(2);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+  layer=shared_ptr<Layer<Dtype> > (new ConvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+TYPED_TEST(ConvolutionLayerTest, Test0DConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  const int kNumOutput = 3;
+  convolution_param->set_num_output(kNumOutput);
+  convolution_param->set_axis(3);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+  layer=shared_ptr<Layer<Dtype> > (
+      new ConvolutionLayer<Dtype>(layer_param));
+  vector<int> top_shape = this->blob_bottom_->shape();
+  top_shape[3] = kNumOutput;
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(top_shape, this->blob_top_->shape());
+}
+
+TYPED_TEST(ConvolutionLayerTest, TestSimple3DConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  vector<int> bottom_shape(5);
+  bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0);
+  bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1);
+  bottom_shape[2] = 5;
+  bottom_shape[3] = this->blob_bottom_vec_[0]->shape(2);
+  bottom_shape[4] = this->blob_bottom_vec_[0]->shape(3);
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+  layer=shared_ptr<Layer<Dtype> > (
+      new ConvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+TYPED_TEST(ConvolutionLayerTest, TestDilated3DConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  vector<int> bottom_shape(5);
+  bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0);
+  bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1);
+  bottom_shape[2] = 6;
+  bottom_shape[3] = 7;
+  bottom_shape[4] = 8;
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_dilation(2);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+  layer=shared_ptr<Layer<Dtype> > (
+      new ConvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+TYPED_TEST(ConvolutionLayerTest, Test1x1Convolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(1);
+  convolution_param->add_stride(1);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+  layer=shared_ptr<Layer<Dtype> > (
+      new ConvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(3);
+  convolution_param->set_group(3);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+  layer=shared_ptr<Layer<Dtype> > (
+      new ConvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+}
+
+}  // namespace caffe
+
+using namespace caffe;
+ConvolutionLayerTest<CPUDevice<float> > * g_convptr;
+
+void single_forward(void * dummy )
+{
+  g_convptr->RunConvolution();
+}
+
+void forward_convolution(void)
+{
+   run_test(16,0,single_forward,NULL);
+}
+
+#define RUN_FUNC(test_case) test_ ## test_case ()
+
+#define DEF_TEST_FUNC(test_case) \
+void test_## test_case (void)\
+{\
+   std::cout<<__FUNCTION__<<"  start ..."<<std::endl;\
+   g_convptr=new ConvolutionLayerTest<CPUDevice<float> >;\
+   g_convptr->SetUp();\
+   g_convptr->Test ## test_case ();\
+   forward_convolution();\
+   delete  g_convptr;\
+   std::cout<<__FUNCTION__<<"  DONE"<<std::endl;\
+}
+
+DEF_TEST_FUNC(SimpleConvolution)
+DEF_TEST_FUNC(DilatedConvolution)
+DEF_TEST_FUNC(0DConvolution)
+DEF_TEST_FUNC(Simple3DConvolution)
+DEF_TEST_FUNC(Dilated3DConvolution)
+DEF_TEST_FUNC(1x1Convolution)
+DEF_TEST_FUNC(SimpleConvolutionGroup)
+
+
+int main(int argc, char * argv[])
+{
+    caffe::GlobalInit(&argc, &argv);
+
+    init_testbed();
+
+    RUN_FUNC(SimpleConvolution);
+    RUN_FUNC(DilatedConvolution);
+    RUN_FUNC(0DConvolution);
+    RUN_FUNC(Simple3DConvolution);
+    RUN_FUNC(Dilated3DConvolution);
+    RUN_FUNC(1x1Convolution);
+    RUN_FUNC(SimpleConvolutionGroup);
+
+    release_testbed();
+    return 0;
+}
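
For reference, each DEF_TEST_FUNC(X) above expands mechanically into a test_X() driver that builds the fixture, runs the named setup body, and then times the forward pass. DEF_TEST_FUNC(SimpleConvolution), for example, produces:

    void test_SimpleConvolution (void)
    {
       std::cout<<__FUNCTION__<<"  start ..."<<std::endl;
       g_convptr=new ConvolutionLayerTest<CPUDevice<float> >;
       g_convptr->SetUp();
       g_convptr->TestSimpleConvolution();
       forward_convolution();
       delete  g_convptr;
       std::cout<<__FUNCTION__<<"  DONE"<<std::endl;
    }

forward_convolution() then hands single_forward to run_test; its first argument (16) is presumably an iteration count, though testbed.h defining run_test is not shown in this diff.
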
diff --git a/unit_tests/sgemm.cpp b/unit_tests/sgemm.cpp
new file mode 100644 (file)
index 0000000..11a80a4
--- /dev/null
@@ -0,0 +1,74 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <malloc.h>
+
+#include <cblas.h>
+
+extern "C" {
+#include "testbed.h"
+}
+
+float * A;
+float * B;
+float * C;
+int M;
+int N;
+int K;
+
+
+void init_matrix(float * m, int size)
+{
+    int i;
+
+    /* fill every element so the GEMM never reads uninitialized memory:
+       even indices carry the test pattern, odd indices are zero */
+    for(i=0;i<size;i++)
+        m[i]=(i%2==0)?(float)i:0.0f;
+}
+
+void init_data(void)
+{
+    int sizeA;
+    int sizeB;
+    int sizeC;
+
+    sizeA=M*K;
+    sizeB=K*N;
+    sizeC=M*N;
+
+    A=(float *)malloc(sizeA*sizeof(float));
+    B=(float *)malloc(sizeB*sizeof(float));
+    C=(float *)malloc(sizeC*sizeof(float));
+
+    init_matrix(A,sizeA);
+    init_matrix(B,sizeB);
+    init_matrix(C,sizeC);
+}
+
+
+
+void run_sgemm(void * dummy)
+{
+   /* one C = 1.0*A*B + 0.0*C pass; the testbed drives repetition and timing */
+   cblas_sgemm(CblasRowMajor,CblasNoTrans,CblasNoTrans,M,N,K,
+               1.0f,A,K,B,N,0.0f,C,N);
+}
+
+int main(int argc, char * argv[])
+{
+     M=27;
+     K=9;
+     N=37632;
+
+     init_data();
+     init_testbed();
+
+     run_test(16,1,run_sgemm,NULL);
+
+     release_testbed();
+
+     return 0;
+}
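
run_sgemm only times the BLAS call and never validates C. A self-contained correctness sketch for the same row-major layout, using a naive triple loop as the reference (the 1e-3 tolerance is an assumption, not taken from this commit):

    #include <math.h>

    /* naive row-major reference: R = A*B, A is MxK, B is KxN */
    static void ref_gemm(const float *A, const float *B, float *R,
                         int M, int N, int K)
    {
        for (int i = 0; i < M; i++)
            for (int j = 0; j < N; j++) {
                float acc = 0.0f;
                for (int k = 0; k < K; k++)
                    acc += A[i*K + k] * B[k*N + j];
                R[i*N + j] = acc;
            }
    }

    /* returns 1 when every element of C matches the reference */
    static int check_gemm(const float *C, const float *R, int count)
    {
        for (int i = 0; i < count; i++)
            if (fabsf(C[i] - R[i]) > 1e-3f)  /* tolerance is an assumption */
                return 0;
        return 1;
    }
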
diff --git a/unit_tests/test.cpp b/unit_tests/test.cpp
new file mode 100644 (file)
index 0000000..80adcc9
--- /dev/null
@@ -0,0 +1,37 @@
+#include "gtest/gtest.h"
+
+
+template <typename TypeParam>
+class foo : public ::testing::Test {
+
+public:
+   foo() {}
+  ~foo() {}
+
+   TypeParam data;
+};
+
+
+typedef ::testing::Types<int,float > TestDtype;
+
+TYPED_TEST_CASE(foo,TestDtype);
+
+TYPED_TEST(foo,test1)
+{
+   TypeParam a=10;
+
+   this->data=10;
+
+   EXPECT_EQ(this->data,a);
+}
+
+
+int main(int argc, char * argv[])
+{
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/unit_tests/test_caffe_main.cpp b/unit_tests/test_caffe_main.cpp
new file mode 100644 (file)
index 0000000..ee0dcd4
--- /dev/null
@@ -0,0 +1,34 @@
+#include "caffe/caffe.hpp"
+#include "caffe/test/test_caffe_main.hpp"
+
+extern "C" {
+#include "testbed.h"
+}
+
+class testbed_env: public ::testing::Environment {
+
+  public:
+      testbed_env(){};
+      ~testbed_env() {};
+
+    void SetUp(void) 
+    { 
+         std::cout<<"setting up testbed resource"<<std::endl;
+    }
+
+    void TearDown(void) 
+    { 
+        std::cout<<"release testbed resource"<<std::endl;
+    }
+
+};
+
+
+int main(int argc, char** argv) {
+  ::testing::InitGoogleTest(&argc, argv);
+  caffe::GlobalInit(&argc, &argv);
+
+  ::testing::AddGlobalTestEnvironment(new testbed_env);
+  // invoke the test.
+  return RUN_ALL_TESTS();
+}
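
The environment above only logs. If the testbed from testbed.h is meant to live for the whole gtest run, one plausible wiring (a sketch, not what this commit does) is to move the init/release calls into the environment hooks:

    class testbed_env_with_resources : public ::testing::Environment {
     public:
      /* bring up PMU/testbed state once before any test runs */
      virtual void SetUp() { init_testbed(); }
      /* tear it down after the last test finishes */
      virtual void TearDown() { release_testbed(); }
    };
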
diff --git a/unit_tests/test_common.cpp b/unit_tests/test_common.cpp
new file mode 100644 (file)
index 0000000..58ae5c6
--- /dev/null
@@ -0,0 +1,64 @@
+#include "gtest/gtest.h"
+
+#include "caffe/common.hpp"
+#include "caffe/syncedmem.hpp"
+#include "caffe/util/math_functions.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+class CommonTest : public ::testing::Test {};
+
+#ifndef CPU_ONLY  // GPU Caffe singleton test.
+
+TEST_F(CommonTest, TestCublasHandlerGPU) {
+  int cuda_device_id;
+  CUDA_CHECK(cudaGetDevice(&cuda_device_id));
+  EXPECT_TRUE(Caffe::cublas_handle());
+}
+
+#endif
+
+TEST_F(CommonTest, TestBrewMode) {
+  Caffe::set_mode(Caffe::CPU);
+  EXPECT_EQ(Caffe::mode(), Caffe::CPU);
+  Caffe::set_mode(Caffe::GPU);
+  EXPECT_EQ(Caffe::mode(), Caffe::GPU);
+}
+
+TEST_F(CommonTest, TestRandSeedCPU) {
+  SyncedMemory data_a(10 * sizeof(int));
+  SyncedMemory data_b(10 * sizeof(int));
+  Caffe::set_random_seed(1701);
+  caffe_rng_bernoulli(10, 0.5, static_cast<int*>(data_a.mutable_cpu_data()));
+
+  Caffe::set_random_seed(1701);
+  caffe_rng_bernoulli(10, 0.5, static_cast<int*>(data_b.mutable_cpu_data()));
+
+  for (int i = 0; i < 10; ++i) {
+    EXPECT_EQ(static_cast<const int*>(data_a.cpu_data())[i],
+        static_cast<const int*>(data_b.cpu_data())[i]);
+  }
+}
+
+#ifndef CPU_ONLY  // GPU Caffe singleton test.
+
+TEST_F(CommonTest, TestRandSeedGPU) {
+  SyncedMemory data_a(10 * sizeof(unsigned int));
+  SyncedMemory data_b(10 * sizeof(unsigned int));
+  Caffe::set_random_seed(1701);
+  CURAND_CHECK(curandGenerate(Caffe::curand_generator(),
+        static_cast<unsigned int*>(data_a.mutable_gpu_data()), 10));
+  Caffe::set_random_seed(1701);
+  CURAND_CHECK(curandGenerate(Caffe::curand_generator(),
+        static_cast<unsigned int*>(data_b.mutable_gpu_data()), 10));
+  for (int i = 0; i < 10; ++i) {
+    EXPECT_EQ(((const unsigned int*)(data_a.cpu_data()))[i],
+        ((const unsigned int*)(data_b.cpu_data()))[i]);
+  }
+}
+
+#endif
+
+}  // namespace caffe
diff --git a/unit_tests/test_convolution_layer.cpp b/unit_tests/test_convolution_layer.cpp
new file mode 100644 (file)
index 0000000..b2db63e
--- /dev/null
@@ -0,0 +1,888 @@
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layers/conv_layer.hpp"
+
+#ifdef USE_CUDNN
+#include "caffe/layers/cudnn_conv_layer.hpp"
+#endif
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+static void dump_blob(const Blob<Dtype> * blob, const char * outfile)
+{
+   std::ofstream os;
+   os.open(outfile);
+
+   for(int i=0;i<blob->LegacyShape(0);i++)
+   {
+     os<<"batch: "<<i<<std::endl;
+
+     for(int j=0;j<blob->LegacyShape(1);j++)
+     {
+        os<<"channel: "<<j<<std::endl;
+
+        for(int k=0;k<blob->LegacyShape(2);k++)
+        {
+            for(int l=0;l<blob->LegacyShape(3);l++)
+            {
+                Dtype data=blob->data_at(i,j,k,l);
+                os<<" "<<data;
+            }
+            os<<std::endl;
+        }
+      os<<std::endl;
+    }
+
+   }
+
+   os.close();
+
+}
+
+
+template <typename Dtype>
+static void fill_blob_data(Blob<Dtype >* bottom, int fixed, float val)
+{
+    for(int i=0;i<bottom->num();i++)
+      for(int j=0;j<bottom->channels();j++)
+        for(int l=0;l<bottom->height();l++)
+          for(int k=0;k<bottom->width();k++)
+        {
+           int offset;
+           Dtype * ptr;
+
+            offset=i*bottom->channels()*bottom->height()*bottom->width()+
+                    j*bottom->height()*bottom->width()+
+                   l*bottom->width()+k;
+
+           ptr=bottom->mutable_cpu_data();
+
+           if(fixed)
+              ptr[offset]=val;
+           else
+              ptr[offset]=offset;
+
+        }
+
+
+}
+
+
+// Reference convolution for checking results:
+// accumulate through explicit loops over input, output, and filters.
+template <typename Dtype>
+void caffe_conv(const Blob<Dtype>* in, ConvolutionParameter* conv_param,
+    const vector<shared_ptr<Blob<Dtype> > >& weights,
+    Blob<Dtype>* out) {
+  const bool has_depth = (out->num_axes() == 5);
+  if (!has_depth) { CHECK_EQ(4, out->num_axes()); }
+  // Kernel size, stride, and pad
+  int kernel_h, kernel_w;
+  if (conv_param->has_kernel_h() || conv_param->has_kernel_w()) {
+    kernel_h = conv_param->kernel_h();
+    kernel_w = conv_param->kernel_w();
+  } else {
+    kernel_h = kernel_w = conv_param->kernel_size(0);
+  }
+  int pad_h, pad_w;
+  if (conv_param->has_pad_h() || conv_param->has_pad_w()) {
+    pad_h = conv_param->pad_h();
+    pad_w = conv_param->pad_w();
+  } else {
+    pad_h = pad_w = conv_param->pad_size() ? conv_param->pad(0) : 0;
+  }
+  int stride_h, stride_w;
+  if (conv_param->has_stride_h() || conv_param->has_stride_w()) {
+    stride_h = conv_param->stride_h();
+    stride_w = conv_param->stride_w();
+  } else {
+    stride_h = stride_w = conv_param->stride_size() ? conv_param->stride(0) : 1;
+  }
+  int dilation_h, dilation_w;
+  dilation_h = dilation_w = conv_param->dilation_size() ?
+                            conv_param->dilation(0) : 1;
+  int kernel_d, pad_d, stride_d, dilation_d;
+  if (has_depth) {
+    kernel_d = kernel_h;
+    stride_d = stride_h;
+    pad_d = pad_h;
+    dilation_d = dilation_h;
+  } else {
+    kernel_d = stride_d = dilation_d = 1;
+    pad_d = 0;
+  }
+  // Groups
+  int groups = conv_param->group();
+  int o_g = out->shape(1) / groups;
+  int k_g = in->shape(1) / groups;
+  int o_head, k_head;
+  // Convolution
+  vector<int> weight_offset(4 + has_depth);
+  vector<int> in_offset(4 + has_depth);
+  vector<int> out_offset(4 + has_depth);
+  Dtype* out_data = out->mutable_cpu_data();
+  for (int n = 0; n < out->shape(0); n++) {
+    for (int g = 0; g < groups; g++) {
+      o_head = o_g * g;
+      k_head = k_g * g;
+      for (int o = 0; o < o_g; o++) {
+        for (int k = 0; k < k_g; k++) {
+          for (int z = 0; z < (has_depth ? out->shape(2) : 1); z++) {
+            for (int y = 0; y < out->shape(2 + has_depth); y++) {
+              for (int x = 0; x < out->shape(3 + has_depth); x++) {
+                for (int r = 0; r < kernel_d; r++) {
+                  for (int p = 0; p < kernel_h; p++) {
+                    for (int q = 0; q < kernel_w; q++) {
+                      int in_z = z * stride_d - pad_d + r * dilation_d;
+                      int in_y = y * stride_h - pad_h + p * dilation_h;
+                      int in_x = x * stride_w - pad_w + q * dilation_w;
+                      if (in_z >= 0 && in_z < (has_depth ? in->shape(2) : 1)
+                          && in_y >= 0 && in_y < in->shape(2 + has_depth)
+                          && in_x >= 0 && in_x < in->shape(3 + has_depth)) {
+                        weight_offset[0] = o + o_head;
+                        weight_offset[1] = k;
+                        if (has_depth) { weight_offset[2] = r; }
+                        weight_offset[2 + has_depth] = p;
+                        weight_offset[3 + has_depth] = q;
+                        in_offset[0] = n;
+                        in_offset[1] = k + k_head;
+                        if (has_depth) { in_offset[2] = in_z; }
+                        in_offset[2 + has_depth] = in_y;
+                        in_offset[3 + has_depth] = in_x;
+                        out_offset[0] = n;
+                        out_offset[1] = o + o_head;
+                        if (has_depth) { out_offset[2] = z; }
+                        out_offset[2 + has_depth] = y;
+                        out_offset[3 + has_depth] = x;
+                        out_data[out->offset(out_offset)] +=
+                            in->data_at(in_offset)
+                            * weights[0]->data_at(weight_offset);
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  // Bias
+  if (conv_param->bias_term()) {
+    const Dtype* bias_data = weights[1]->cpu_data();
+    for (int n = 0; n < out->shape(0); n++) {
+      for (int o = 0; o < out->shape(1); o++) {
+        for (int z = 0; z < (has_depth ? out->shape(2) : 1); z++) {
+          for (int y = 0; y < out->shape(2 + has_depth); y++) {
+            for (int x = 0; x < out->shape(3 + has_depth); x++) {
+              out_offset[0] = n;
+              out_offset[1] = o;
+              if (has_depth) { out_offset[2] = z; }
+              out_offset[2 + has_depth] = y;
+              out_offset[3 + has_depth] = x;
+              out_data[out->offset(out_offset)] += bias_data[o];
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+template void caffe_conv(const Blob<float>* in,
+    ConvolutionParameter* conv_param,
+    const vector<shared_ptr<Blob<float> > >& weights,
+    Blob<float>* out);
+template void caffe_conv(const Blob<double>* in,
+    ConvolutionParameter* conv_param,
+    const vector<shared_ptr<Blob<double> > >& weights,
+    Blob<double>* out);
+
+template <typename TypeParam>
+class ConvolutionLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  ConvolutionLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 6, 4)),
+        blob_bottom_2_(new Blob<Dtype>(2, 3, 6, 4)),
+        blob_top_(new Blob<Dtype>()),
+        blob_top_2_(new Blob<Dtype>()) {}
+  virtual void SetUp() {
+    // fill the values
+    FillerParameter filler_param;
+    filler_param.set_value(1.);
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    filler.Fill(this->blob_bottom_2_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+
+  virtual ~ConvolutionLayerTest() {
+    delete blob_bottom_;
+    delete blob_bottom_2_;
+    delete blob_top_;
+    delete blob_top_2_;
+  }
+
+  virtual Blob<Dtype>* MakeReferenceTop(Blob<Dtype>* top) {
+    this->ref_blob_top_.reset(new Blob<Dtype>());
+    this->ref_blob_top_->ReshapeLike(*top);
+    return this->ref_blob_top_.get();
+  }
+
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_bottom_2_;
+  Blob<Dtype>* const blob_top_;
+  Blob<Dtype>* const blob_top_2_;
+  shared_ptr<Blob<Dtype> > ref_blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+
+// Run these tests for the CPU/float configuration only: redefine the usual
+// TestDtypesAndDevices list with a single-entry type list.
+typedef ::testing::Types<CPUDevice<float> > float_only;
+
+#define TestDtypesAndDevices float_only
+TYPED_TEST_CASE(ConvolutionLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(ConvolutionLayerTest, TestSetup) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(4);
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+
+  layer_param.set_type("Convolution");
+  shared_ptr<Layer<Dtype> > layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 4);
+  EXPECT_EQ(this->blob_top_->height(), 2);
+  EXPECT_EQ(this->blob_top_->width(), 1);
+  EXPECT_EQ(this->blob_top_2_->num(), 2);
+  EXPECT_EQ(this->blob_top_2_->channels(), 4);
+  EXPECT_EQ(this->blob_top_2_->height(), 2);
+  EXPECT_EQ(this->blob_top_2_->width(), 1);
+  // setting group should not change the shape
+  convolution_param->set_num_output(3);
+  convolution_param->set_group(3);
+  layer.reset(new ConvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 3);
+  EXPECT_EQ(this->blob_top_->height(), 2);
+  EXPECT_EQ(this->blob_top_->width(), 1);
+  EXPECT_EQ(this->blob_top_2_->num(), 2);
+  EXPECT_EQ(this->blob_top_2_->channels(), 3);
+  EXPECT_EQ(this->blob_top_2_->height(), 2);
+  EXPECT_EQ(this->blob_top_2_->width(), 1);
+}
+
+
+TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+
+
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(3);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+
+  vector<int> bottom_shape;
+  bottom_shape.push_back(2);
+  bottom_shape.push_back(3);
+  bottom_shape.push_back(5);
+  bottom_shape.push_back(5);
+
+  this->blob_bottom_->Reshape(bottom_shape);
+  this->blob_bottom_2_->Reshape(bottom_shape);
+
+  fill_blob_data(this->blob_bottom_,0,1);
+  fill_blob_data(this->blob_bottom_2_,1,1);
+
+  layer_param.set_type("Convolution");
+
+  shared_ptr<Layer<Dtype> > layer=
+   LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+
+   //fill_blob_data(layer->blobs()[0].get(),1,1);
+
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+#ifdef LAYER_PERF_STAT
+  perf_stat * p_time_stat;
+  p_time_stat=layer->get_time_stat();
+  std::cout<<"start: "<<p_time_stat->start;
+  std::cout<<" end: "<<p_time_stat->end;
+  std::cout<<" used: "<<p_time_stat->used;
+  std::cout<<" total: "<<p_time_stat->total;
+  std::cout<<" count: "<<p_time_stat->count<<std::endl;
+#endif
+
+  // Check against reference convolution.
+  const Dtype* top_data;
+  const Dtype* ref_top_data;
+  caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(),
+      this->MakeReferenceTop(this->blob_top_));
+  top_data = this->blob_top_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+
+
+//   dump_blob(this->blob_bottom_,"bottom.data");
+//   dump_blob(this->blob_top_,"top.data");
+//   dump_blob(this->ref_blob_top_.get(),"reftop.data");
+//   dump_blob(layer->blobs()[0].get(),"weight.data");
+//   dump_blob(layer->blobs()[1].get(),"bias.data");
+#if 1
+  caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(),
+      this->MakeReferenceTop(this->blob_top_2_));
+  top_data = this->blob_top_2_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+#endif
+}
+
+#if 0
+
+TYPED_TEST(ConvolutionLayerTest, TestDilatedConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  vector<int> bottom_shape;
+  bottom_shape.push_back(2);
+  bottom_shape.push_back(3);
+  bottom_shape.push_back(8);
+  bottom_shape.push_back(7);
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_dilation(2);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+
+ layer_param.set_type("Convolution");
+  shared_ptr<Layer<Dtype> > layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Check against reference convolution.
+  const Dtype* top_data;
+  const Dtype* ref_top_data;
+  caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(),
+             this->MakeReferenceTop(this->blob_top_));
+  top_data = this->blob_top_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+  caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(),
+             this->MakeReferenceTop(this->blob_top_2_));
+  top_data = this->blob_top_2_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+}
+
+TYPED_TEST(ConvolutionLayerTest, Test0DConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  const int kNumOutput = 3;
+  convolution_param->set_num_output(kNumOutput);
+  convolution_param->set_axis(3);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+
+
+ layer_param.set_type("Convolution");
+  shared_ptr<Layer<Dtype> > layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  vector<int> top_shape = this->blob_bottom_->shape();
+  top_shape[3] = kNumOutput;
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(top_shape, this->blob_top_->shape());
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Check against reference convolution.
+  vector<int> weight_offset(2);
+  const Blob<Dtype>* weight = layer->blobs()[0].get();
+  const Blob<Dtype>* bias = layer->blobs()[1].get();
+  const int num = this->blob_top_->count(3);
+  const int dim = this->blob_top_->shape(3);
+  const int bottom_dim = this->blob_bottom_->shape(3);
+  for (int n = 0; n < num; ++n) {
+    for (int d = 0; d < dim; ++d) {
+      weight_offset[0] = d;
+      Dtype value = bias->cpu_data()[d];
+      for (int bottom_d = 0; bottom_d < bottom_dim; ++bottom_d) {
+        weight_offset[1] = bottom_d;
+        value += weight->data_at(weight_offset) *
+                 this->blob_bottom_->cpu_data()[n * bottom_dim + bottom_d];
+      }
+      EXPECT_NEAR(value, this->blob_top_->cpu_data()[n * dim + d], 1e-4);
+    }
+  }
+}
+
+TYPED_TEST(ConvolutionLayerTest, TestSimple3DConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  vector<int> bottom_shape(5);
+  bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0);
+  bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1);
+  bottom_shape[2] = 5;
+  bottom_shape[3] = this->blob_bottom_vec_[0]->shape(2);
+  bottom_shape[4] = this->blob_bottom_vec_[0]->shape(3);
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+
+  layer_param.set_type("Convolution");
+  shared_ptr<Layer<Dtype> > layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+    
+  
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Check against reference convolution.
+  const Dtype* top_data;
+  const Dtype* ref_top_data;
+  caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(),
+      this->MakeReferenceTop(this->blob_top_));
+  top_data = this->blob_top_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+  caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(),
+      this->MakeReferenceTop(this->blob_top_2_));
+  top_data = this->blob_top_2_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+}
+
+TYPED_TEST(ConvolutionLayerTest, TestDilated3DConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  vector<int> bottom_shape(5);
+  bottom_shape[0] = this->blob_bottom_vec_[0]->shape(0);
+  bottom_shape[1] = this->blob_bottom_vec_[0]->shape(1);
+  bottom_shape[2] = 6;
+  bottom_shape[3] = 7;
+  bottom_shape[4] = 8;
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_dilation(2);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+
+  layer_param.set_type("Convolution");
+  shared_ptr<Layer<Dtype> > layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+    
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Check against reference convolution.
+  const Dtype* top_data;
+  const Dtype* ref_top_data;
+  caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(),
+             this->MakeReferenceTop(this->blob_top_));
+  top_data = this->blob_top_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+  caffe_conv(this->blob_bottom_2_, convolution_param, layer->blobs(),
+             this->MakeReferenceTop(this->blob_top_2_));
+  top_data = this->blob_top_2_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+}
+
+#endif
+
+TYPED_TEST(ConvolutionLayerTest, Test1x1Convolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+#if 0    
+  convolution_param->add_kernel_size(1);
+  convolution_param->set_num_output(2);
+
+  vector<int> bottom_shape;
+  bottom_shape.push_back(1);
+  bottom_shape.push_back(32);
+  bottom_shape.push_back(133);
+  bottom_shape.push_back(98);
+
+  this->blob_bottom_vec_[0]->Reshape(bottom_shape);
+
+#else 
+  convolution_param->add_kernel_size(1);
+  convolution_param->add_stride(1);
+  convolution_param->set_num_output(4);
+#endif
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(1);
+
+  layer_param.set_type("Convolution");
+  shared_ptr<Layer<Dtype> > layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  
+  fill_blob_data(this->blob_bottom_,1,1);
+  fill_blob_data(layer->blobs()[0].get(),1,1);
+
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  fill_blob_data(this->blob_bottom_,1,3);
+
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Check against reference convolution.
+  const Dtype* top_data;
+  const Dtype* ref_top_data;
+  caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(),
+      this->MakeReferenceTop(this->blob_top_));
+  top_data = this->blob_top_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  //   std::cout<<i<<": "<< top_data[i]<<" "<<ref_top_data[i]<<std::endl;
+  }
+}
+
+
+#if 0
+TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolutionGroup) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(3);
+  convolution_param->set_group(3);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+
+  layer_param.set_type("Convolution");
+  shared_ptr<Layer<Dtype> > layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+    
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Check against reference convolution.
+  const Dtype* top_data;
+  const Dtype* ref_top_data;
+  caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(),
+      this->MakeReferenceTop(this->blob_top_));
+  top_data = this->blob_top_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+}
+
+TYPED_TEST(ConvolutionLayerTest, TestSobelConvolution) {
+  // Test separable convolution by computing the Sobel operator
+  // as a single filter then comparing the result
+  // as the convolution of two rectangular filters.
+  typedef typename TypeParam::Dtype Dtype;
+  // Fill bottoms with identical Gaussian noise.
+  shared_ptr<GaussianFiller<Dtype> > filler;
+  FillerParameter filler_param;
+  filler_param.set_value(1.);
+  filler.reset(new GaussianFiller<Dtype>(filler_param));
+  filler->Fill(this->blob_bottom_);
+  this->blob_bottom_2_->CopyFrom(*this->blob_bottom_);
+  // Compute Sobel G_x operator as 3 x 3 convolution.
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(1);
+  convolution_param->set_bias_term(false);
+
+  layer_param.set_type("Convolution");
+  shared_ptr<Layer<Dtype> > layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+    
+
+  layer->blobs().resize(1);
+  layer->blobs()[0].reset(new Blob<Dtype>(1, 3, 3, 3));
+  Dtype* weights = layer->blobs()[0]->mutable_cpu_data();
+  for (int c = 0; c < 3; ++c) {
+    int i = c * 9;  // 3 x 3 filter
+    weights[i +  0] = -1;
+    weights[i +  1] =  0;
+    weights[i +  2] =  1;
+    weights[i +  3] = -2;
+    weights[i +  4] =  0;
+    weights[i +  5] =  2;
+    weights[i +  6] = -1;
+    weights[i +  7] =  0;
+    weights[i +  8] =  1;
+  }
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Compute Sobel G_x operator as separable 3 x 1 and 1 x 3 convolutions.
+  // (1) the [1 2 1] column filter
+  vector<Blob<Dtype>*> sep_blob_bottom_vec;
+  vector<Blob<Dtype>*> sep_blob_top_vec;
+  shared_ptr<Blob<Dtype> > blob_sep(new Blob<Dtype>());
+  sep_blob_bottom_vec.push_back(this->blob_bottom_2_);
+  sep_blob_top_vec.push_back(this->blob_top_2_);
+  convolution_param->clear_kernel_size();
+  convolution_param->clear_stride();
+  convolution_param->set_kernel_h(3);
+  convolution_param->set_kernel_w(1);
+  convolution_param->set_stride_h(2);
+  convolution_param->set_stride_w(1);
+  convolution_param->set_num_output(1);
+  convolution_param->set_bias_term(false);
+  layer.reset(new ConvolutionLayer<Dtype>(layer_param));
+  layer->blobs().resize(1);
+  layer->blobs()[0].reset(new Blob<Dtype>(1, 3, 3, 1));
+  Dtype* weights_1 = layer->blobs()[0]->mutable_cpu_data();
+  for (int c = 0; c < 3; ++c) {
+    int i = c * 3;  // 3 x 1 filter
+    weights_1[i +  0] = 1;
+    weights_1[i +  1] = 2;
+    weights_1[i +  2] = 1;
+  }
+  layer->SetUp(sep_blob_bottom_vec, sep_blob_top_vec);
+  layer->Forward(sep_blob_bottom_vec, sep_blob_top_vec);
+  // (2) the [-1 0 1] row filter
+  blob_sep->CopyFrom(*this->blob_top_2_, false, true);
+  sep_blob_bottom_vec.clear();
+  sep_blob_bottom_vec.push_back(blob_sep.get());
+  convolution_param->set_kernel_h(1);
+  convolution_param->set_kernel_w(3);
+  convolution_param->set_stride_h(1);
+  convolution_param->set_stride_w(2);
+  convolution_param->set_num_output(1);
+  convolution_param->set_bias_term(false);
+  layer.reset(new ConvolutionLayer<Dtype>(layer_param));
+  layer->blobs().resize(1);
+  layer->blobs()[0].reset(new Blob<Dtype>(1, 1, 1, 3));
+  Dtype* weights_2 = layer->blobs()[0]->mutable_cpu_data();
+  weights_2[0] = -1;
+  weights_2[1] =  0;
+  weights_2[2] =  1;
+  layer->SetUp(sep_blob_bottom_vec, sep_blob_top_vec);
+  layer->Forward(sep_blob_bottom_vec, sep_blob_top_vec);
+  // Test equivalence of full and separable filters.
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  const Dtype* sep_top_data = this->blob_top_2_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], sep_top_data[i], 1e-4);
+  }
+}
+
+TYPED_TEST(ConvolutionLayerTest, TestNDAgainst2D) {
+  typedef typename TypeParam::Dtype Dtype;
+  const int kernel_h = 11;
+  const int kernel_w = 13;
+  vector<int> bottom_shape(4);
+  bottom_shape[0] = 15;
+  bottom_shape[1] = 18;
+  bottom_shape[2] = kernel_h * 2;
+  bottom_shape[3] = kernel_w * 2;
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->set_num_output(12);
+  convolution_param->set_bias_term(false);
+  convolution_param->set_group(6);
+  convolution_param->set_kernel_h(kernel_h);
+  convolution_param->set_kernel_w(kernel_w);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  Blob<Dtype> weights;
+  Blob<Dtype> top_diff;
+  // Shape and fill weights and top_diff.
+  bool copy_diff;
+  bool reshape;
+  {
+    ConvolutionLayer<Dtype> layer(layer_param);
+    layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    top_diff.ReshapeLike(*this->blob_top_);
+    filler.Fill(&top_diff);
+    ASSERT_EQ(1, layer.blobs().size());
+    copy_diff = false; reshape = true;
+    weights.CopyFrom(*layer.blobs()[0], copy_diff, reshape);
+  }
+  vector<bool> propagate_down(1, true);
+  Blob<Dtype> result_2d;
+  Blob<Dtype> backward_result_2d;
+  Blob<Dtype> backward_weight_result_2d;
+  // Test with 2D im2col
+  {
+    caffe_set(this->blob_top_->count(), Dtype(0),
+              this->blob_top_->mutable_cpu_data());
+    caffe_set(this->blob_bottom_->count(), Dtype(0),
+              this->blob_bottom_->mutable_cpu_diff());
+    caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff());
+    // Do SetUp and Forward; save Forward result in result_2d.
+    convolution_param->set_force_nd_im2col(false);
+
+    layer_param.set_type("Convolution");
+
+    Layer<Dtype> & layer_2d=*LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+
+    layer_2d.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    ASSERT_EQ(1, layer_2d.blobs().size());
+    copy_diff = false; reshape = false;
+    layer_2d.blobs()[0]->CopyFrom(weights, copy_diff, reshape);
+    layer_2d.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    copy_diff = false; reshape = true;
+    result_2d.CopyFrom(*this->blob_top_, copy_diff, reshape);
+    // Copy pre-generated top diff into actual top diff;
+    // do Backward and save result in backward_result_2d.
+    ASSERT_EQ(this->blob_top_->shape(), top_diff.shape());
+    caffe_copy(top_diff.count(), top_diff.cpu_data(),
+               this->blob_top_->mutable_cpu_diff());
+    layer_2d.Backward(this->blob_top_vec_, propagate_down,
+                      this->blob_bottom_vec_);
+    copy_diff = true; reshape = true;
+    backward_result_2d.CopyFrom(*this->blob_bottom_, copy_diff, reshape);
+    backward_weight_result_2d.CopyFrom(weights, copy_diff, reshape);
+  }
+  Blob<Dtype> result_nd;
+  Blob<Dtype> backward_result_nd;
+  Blob<Dtype> backward_weight_result_nd;
+  // Test with ND im2col
+  {
+    caffe_set(this->blob_top_->count(), Dtype(0),
+              this->blob_top_->mutable_cpu_data());
+    caffe_set(this->blob_bottom_->count(), Dtype(0),
+              this->blob_bottom_->mutable_cpu_diff());
+    caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff());
+    // Do SetUp and Forward; save Forward result in result_nd.
+    convolution_param->set_force_nd_im2col(true);
+
+   layer_param.set_type("Convolution");
+
+    Layer<Dtype>& layer_nd=*LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+    layer_nd.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    ASSERT_EQ(1, layer_nd.blobs().size());
+    copy_diff = false; reshape = false;
+    layer_nd.blobs()[0]->CopyFrom(weights, copy_diff, reshape);
+    layer_nd.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    copy_diff = false; reshape = true;
+    result_nd.CopyFrom(*this->blob_top_, copy_diff, reshape);
+    // Copy pre-generated top diff into actual top diff;
+    // do Backward and save result in backward_result_nd.
+    ASSERT_EQ(this->blob_top_->shape(), top_diff.shape());
+    caffe_copy(top_diff.count(), top_diff.cpu_data(),
+               this->blob_top_->mutable_cpu_diff());
+    layer_nd.Backward(this->blob_top_vec_, propagate_down,
+                      this->blob_bottom_vec_);
+    copy_diff = true; reshape = true;
+    backward_result_nd.CopyFrom(*this->blob_bottom_, copy_diff, reshape);
+    backward_weight_result_nd.CopyFrom(weights, copy_diff, reshape);
+  }
+  ASSERT_EQ(result_nd.count(), result_2d.count());
+  for (int i = 0; i < result_2d.count(); ++i)  {
+    EXPECT_EQ(result_2d.cpu_data()[i], result_nd.cpu_data()[i]);
+  }
+  ASSERT_EQ(backward_result_nd.count(), backward_result_2d.count());
+  for (int i = 0; i < backward_result_2d.count(); ++i) {
+    EXPECT_EQ(backward_result_2d.cpu_diff()[i],
+              backward_result_nd.cpu_diff()[i]);
+  }
+  ASSERT_EQ(backward_weight_result_nd.count(),
+            backward_weight_result_2d.count());
+  for (int i = 0; i < backward_weight_result_2d.count(); ++i) {
+    EXPECT_EQ(backward_weight_result_2d.cpu_diff()[i],
+              backward_weight_result_nd.cpu_diff()[i]);
+  }
+}
+#endif
+
+
+}  // namespace caffe
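
The LAYER_PERF_STAT block in TestSimpleConvolution above reads start, end, used, total and count from the layer's timing statistics. The real definition ships with this port's layer changes and is not in this hunk; purely for illustration, the fields used there imply a record along these lines:

    #include <cstdint>

    // Hypothetical shape of the per-layer timing record read above.
    struct perf_stat {
        uint64_t start;  // timestamp at the last Forward() entry
        uint64_t end;    // timestamp at the last Forward() exit
        uint64_t used;   // end - start for the last call
        uint64_t total;  // accumulated time across calls
        uint32_t count;  // number of Forward() calls measured
    };
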
diff --git a/unit_tests/test_fail.cpp b/unit_tests/test_fail.cpp
new file mode 100644 (file)
index 0000000..0bc5495
--- /dev/null
@@ -0,0 +1,419 @@
+#include <fstream>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layers/conv_layer.hpp"
+
+#ifdef USE_CUDNN
+#include "caffe/layers/cudnn_conv_layer.hpp"
+#endif
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+#include <iomanip>
+
+namespace caffe {
+
+template <typename Dtype>
+void dump_blob(const Blob<Dtype> * blob, const char * outfile)
+{
+   std::ofstream os;
+   os.open(outfile);
+
+   os<<std::setiosflags(std::ios::fixed);
+
+   for(int i=0;i<blob->LegacyShape(0);i++)
+   {
+
+     for(int j=0;j<blob->LegacyShape(1);j++)
+     {
+
+        for(int k=0;k<blob->LegacyShape(2);k++)
+        {
+            for(int l=0;l<blob->LegacyShape(3);l++)
+            {
+                Dtype data=blob->data_at(i,j,k,l);
+                os<<std::setprecision(12)<<data<<", ";
+            }
+            os<<std::endl;
+        }
+      os<<std::endl;
+    }
+     os<<std::endl;
+   }
+
+   os.close();
+
+}
+
+
+template <typename Dtype>
+void fill_blob_data(Blob<Dtype >* bottom, int fixed, float val)
+{
+    for(int i=0;i<bottom->num();i++)
+      for(int j=0;j<bottom->channels();j++)
+        for(int l=0;l<bottom->height();l++)
+          for(int k=0;k<bottom->width();k++)
+        {
+           int offset;
+           Dtype * ptr;
+
+            offset=i*bottom->channels()*bottom->height()*bottom->width()+
+                    j*bottom->height()*bottom->width()+
+                   l*bottom->width()+k;
+
+           ptr=bottom->mutable_cpu_data();
+
+           if(fixed)
+              ptr[offset]=val;
+           else
+              ptr[offset]=offset+100;
+
+        }
+
+
+}
+
+
+template <typename Dtype>
+void load_blob_data(Blob<Dtype >* bottom, Dtype * p_data)
+{
+    for(int i=0;i<bottom->num();i++)
+      for(int j=0;j<bottom->channels();j++)
+        for(int l=0;l<bottom->height();l++)
+          for(int k=0;k<bottom->width();k++)
+        {
+           int offset;
+           Dtype * ptr;
+
+            offset=i*bottom->channels()*bottom->height()*bottom->width()+
+                    j*bottom->height()*bottom->width()+
+                   l*bottom->width()+k;
+
+           ptr=bottom->mutable_cpu_data();
+
+            ptr[offset]=p_data[offset];
+
+        }
+
+}
+
+
+
+// Reference convolution for checking results:
+// accumulate through explicit loops over input, output, and filters.
+template <typename Dtype>
+void caffe_conv(const Blob<Dtype>* in, ConvolutionParameter* conv_param,
+    const vector<shared_ptr<Blob<Dtype> > >& weights,
+    Blob<Dtype>* out) {
+  const bool has_depth = (out->num_axes() == 5);
+  if (!has_depth) { CHECK_EQ(4, out->num_axes()); }
+  // Kernel size, stride, and pad
+  int kernel_h, kernel_w;
+  if (conv_param->has_kernel_h() || conv_param->has_kernel_w()) {
+    kernel_h = conv_param->kernel_h();
+    kernel_w = conv_param->kernel_w();
+  } else {
+    kernel_h = kernel_w = conv_param->kernel_size(0);
+  }
+  int pad_h, pad_w;
+  if (conv_param->has_pad_h() || conv_param->has_pad_w()) {
+    pad_h = conv_param->pad_h();
+    pad_w = conv_param->pad_w();
+  } else {
+    pad_h = pad_w = conv_param->pad_size() ? conv_param->pad(0) : 0;
+  }
+  int stride_h, stride_w;
+  if (conv_param->has_stride_h() || conv_param->has_stride_w()) {
+    stride_h = conv_param->stride_h();
+    stride_w = conv_param->stride_w();
+  } else {
+    stride_h = stride_w = conv_param->stride_size() ? conv_param->stride(0) : 1;
+  }
+  int dilation_h, dilation_w;
+  dilation_h = dilation_w = conv_param->dilation_size() ?
+                            conv_param->dilation(0) : 1;
+  int kernel_d, pad_d, stride_d, dilation_d;
+  if (has_depth) {
+    kernel_d = kernel_h;
+    stride_d = stride_h;
+    pad_d = pad_h;
+    dilation_d = dilation_h;
+  } else {
+    kernel_d = stride_d = dilation_d = 1;
+    pad_d = 0;
+  }
+  // Groups
+  int groups = conv_param->group();
+  int o_g = out->shape(1) / groups;
+  int k_g = in->shape(1) / groups;
+  int o_head, k_head;
+  // Convolution
+  vector<int> weight_offset(4 + has_depth);
+  vector<int> in_offset(4 + has_depth);
+  vector<int> out_offset(4 + has_depth);
+  Dtype* out_data = out->mutable_cpu_data();
+  for (int n = 0; n < out->shape(0); n++) {
+    for (int g = 0; g < groups; g++) {
+      o_head = o_g * g;
+      k_head = k_g * g;
+      for (int o = 0; o < o_g; o++) {
+        for (int k = 0; k < k_g; k++) {
+          for (int z = 0; z < (has_depth ? out->shape(2) : 1); z++) {
+            for (int y = 0; y < out->shape(2 + has_depth); y++) {
+              for (int x = 0; x < out->shape(3 + has_depth); x++) {
+                for (int r = 0; r < kernel_d; r++) {
+                  for (int p = 0; p < kernel_h; p++) {
+                    for (int q = 0; q < kernel_w; q++) {
+                      int in_z = z * stride_d - pad_d + r * dilation_d;
+                      int in_y = y * stride_h - pad_h + p * dilation_h;
+                      int in_x = x * stride_w - pad_w + q * dilation_w;
+                      if (in_z >= 0 && in_z < (has_depth ? in->shape(2) : 1)
+                          && in_y >= 0 && in_y < in->shape(2 + has_depth)
+                          && in_x >= 0 && in_x < in->shape(3 + has_depth)) {
+                        weight_offset[0] = o + o_head;
+                        weight_offset[1] = k;
+                        if (has_depth) { weight_offset[2] = r; }
+                        weight_offset[2 + has_depth] = p;
+                        weight_offset[3 + has_depth] = q;
+                        in_offset[0] = n;
+                        in_offset[1] = k + k_head;
+                        if (has_depth) { in_offset[2] = in_z; }
+                        in_offset[2 + has_depth] = in_y;
+                        in_offset[3 + has_depth] = in_x;
+                        out_offset[0] = n;
+                        out_offset[1] = o + o_head;
+                        if (has_depth) { out_offset[2] = z; }
+                        out_offset[2 + has_depth] = y;
+                        out_offset[3 + has_depth] = x;
+                        out_data[out->offset(out_offset)] +=
+                            in->data_at(in_offset)
+                            * weights[0]->data_at(weight_offset);
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+  // Bias
+  if (conv_param->bias_term()) {
+    const Dtype* bias_data = weights[1]->cpu_data();
+    for (int n = 0; n < out->shape(0); n++) {
+      for (int o = 0; o < out->shape(1); o++) {
+        for (int z = 0; z < (has_depth ? out->shape(2) : 1); z++) {
+          for (int y = 0; y < out->shape(2 + has_depth); y++) {
+            for (int x = 0; x < out->shape(3 + has_depth); x++) {
+              out_offset[0] = n;
+              out_offset[1] = o;
+              if (has_depth) { out_offset[2] = z; }
+              out_offset[2 + has_depth] = y;
+              out_offset[3 + has_depth] = x;
+              out_data[out->offset(out_offset)] += bias_data[o];
+            }
+          }
+        }
+      }
+    }
+  }
+}
+
+template void caffe_conv(const Blob<float>* in,
+    ConvolutionParameter* conv_param,
+    const vector<shared_ptr<Blob<float> > >& weights,
+    Blob<float>* out);
+template void caffe_conv(const Blob<double>* in,
+    ConvolutionParameter* conv_param,
+    const vector<shared_ptr<Blob<double> > >& weights,
+    Blob<double>* out);
+
+template <typename TypeParam>
+class ConvolutionLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  ConvolutionLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 6, 4)),
+        blob_bottom_2_(new Blob<Dtype>(2, 3, 6, 4)),
+        blob_top_(new Blob<Dtype>()),
+        blob_top_2_(new Blob<Dtype>()) {}
+  virtual void SetUp() {
+    // fill the values
+    FillerParameter filler_param;
+    filler_param.set_value(1.);
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    filler.Fill(this->blob_bottom_2_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+
+  virtual ~ConvolutionLayerTest() {
+    delete blob_bottom_;
+    delete blob_bottom_2_;
+    delete blob_top_;
+    delete blob_top_2_;
+  }
+
+  virtual Blob<Dtype>* MakeReferenceTop(Blob<Dtype>* top) {
+    this->ref_blob_top_.reset(new Blob<Dtype>());
+    this->ref_blob_top_->ReshapeLike(*top);
+    return this->ref_blob_top_.get();
+  }
+
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_bottom_2_;
+  Blob<Dtype>* const blob_top_;
+  Blob<Dtype>* const blob_top_2_;
+  shared_ptr<Blob<Dtype> > ref_blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+
+typedef ::testing::Types<CPUDevice<float> > float_only;
+
+#define TestDtypesAndDevices float_only
+TYPED_TEST_CASE(ConvolutionLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(ConvolutionLayerTest, TestSetup) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(4);
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+
+  layer_param.set_type("Convolution");
+  shared_ptr<Layer<Dtype> > layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 4);
+  EXPECT_EQ(this->blob_top_->height(), 2);
+  EXPECT_EQ(this->blob_top_->width(), 1);
+  EXPECT_EQ(this->blob_top_2_->num(), 2);
+  EXPECT_EQ(this->blob_top_2_->channels(), 4);
+  EXPECT_EQ(this->blob_top_2_->height(), 2);
+  EXPECT_EQ(this->blob_top_2_->width(), 1);
+  // setting group should not change the shape
+  convolution_param->set_num_output(3);
+  convolution_param->set_group(3);
+  layer.reset(new ConvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 3);
+  EXPECT_EQ(this->blob_top_->height(), 2);
+  EXPECT_EQ(this->blob_top_->width(), 1);
+  EXPECT_EQ(this->blob_top_2_->num(), 2);
+  EXPECT_EQ(this->blob_top_2_->channels(), 3);
+  EXPECT_EQ(this->blob_top_2_->height(), 2);
+  EXPECT_EQ(this->blob_top_2_->width(), 1);
+}
+
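+// Fixed convolution weights, apparently captured from a failing run (hence
+// the name): 3 output channels x 3 input channels x 3x3 kernels = 81 values.
+// TestSimpleConvolution loads these instead of a random fill so the case is
+// reproducible.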
+float fail3_weight[]={
+-0.850632905960, -1.578843951225, -0.890021681786, 
+0.971448659897, -0.538104891777, 0.233876436949, 
+-1.242745161057, 2.211859703064, 0.525026142597, 
+
+-1.726792931557, -1.194667577744, 1.119420289993, 
+-1.539444208145, 1.725312829018, -1.573384165764, 
+0.519557833672, 0.376551657915, -0.615215837955, 
+
+0.758795797825, -0.498177528381, 0.254181325436, 
+-0.071698464453, -1.192728281021, 0.776199519634, 
+1.837580919266, -0.478745609522, -0.804457962513, 
+
+
+-2.220808744431, -0.892578184605, -1.422935843468, 
+-1.707052111626, -1.837757468224, -1.312300324440, 
+-1.251585721970, -1.591378808022, -0.577652215958, 
+
+1.727164268494, 0.176050186157, -1.804216146469, 
+0.547152698040, -0.024264926091, -2.040683984756, 
+-2.159983396530, 1.692966818810, -1.558626413345, 
+
+-1.242013096809, 0.122898645699, -0.146973758936, 
+-0.405744194984, -1.716119289398, 1.215066313744, 
+1.061164021492, -0.705341339111, -0.245370775461, 
+
+
+0.781007647514, -0.104610890150, 2.421228170395, 
+0.348720043898, 0.289468020201, 1.841132760048, 
+-0.835199236870, -0.242239400744, 1.169079542160, 
+
+0.165550187230, -0.418082803488, 0.479667782784, 
+-0.241552516818, 0.767971694469, -0.760977804661, 
+-2.419095993042, 0.774254024029, 0.541432976723, 
+
+0.855292022228, -0.144438281655, 0.251998007298, 
+-0.242634430528, -0.044748753309, -0.321820944548, 
+-0.487676948309, -0.761075556278, -0.646164357662
+};
+
+TYPED_TEST(ConvolutionLayerTest, TestSimpleConvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(1);
+  convolution_param->set_num_output(3);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+
+  vector<int> bottom_shape;
+  bottom_shape.push_back(1);
+  bottom_shape.push_back(3);
+  bottom_shape.push_back(5);
+  bottom_shape.push_back(5);
+
+  this->blob_bottom_->Reshape(bottom_shape);
+
+  fill_blob_data(this->blob_bottom_, 0, 1);
+
+  layer_param.set_type("Convolution");
+
+  shared_ptr<Layer<Dtype> > layer =
+      LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  // Use the recorded fail3_weight values instead of a sequential fill:
+  // fill_blob_data(layer->blobs()[0].get(), 1, 1);
+  load_blob_data(layer->blobs()[0].get(), fail3_weight);
+
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Check against reference convolution.
+  const Dtype* top_data;
+  const Dtype* ref_top_data;
+  caffe_conv(this->blob_bottom_, convolution_param, layer->blobs(),
+      this->MakeReferenceTop(this->blob_top_));
+  top_data = this->blob_top_->cpu_data();
+  ref_top_data = this->ref_blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], ref_top_data[i], 1e-4);
+  }
+
+  // Dump all tensors so a failing comparison can be inspected offline.
+  dump_blob(this->blob_bottom_, "bottom.data");
+  dump_blob(this->blob_top_, "top.data");
+  dump_blob(this->ref_blob_top_.get(), "reftop.data");
+  dump_blob(layer->blobs()[0].get(), "weight.data");
+  dump_blob(layer->blobs()[1].get(), "bias.data");
+}
+
+}  // namespace caffe
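
The expected top shapes asserted above follow the standard convolution
output-size rule; a minimal sketch (helper name is ours, not part of the
commit):

    // out = floor((in + 2*pad - kernel) / stride) + 1
    static int conv_out_size(int in, int kernel, int stride, int pad) {
      return (in + 2 * pad - kernel) / stride + 1;
    }
    // TestSetup (6x4 bottom, kernel 3, stride 2):
    //   conv_out_size(6, 3, 2, 0) == 2, conv_out_size(4, 3, 2, 0) == 1
    // TestSimpleConvolution (5x5 bottom, kernel 3, stride 1):
    //   conv_out_size(5, 3, 1, 0) == 3, i.e. a 3x3 top per output map
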
diff --git a/unit_tests/test_inner_product_layer.cpp b/unit_tests/test_inner_product_layer.cpp
new file mode 100644 (file)
index 0000000..efe3d87
--- /dev/null
@@ -0,0 +1,295 @@
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layers/inner_product_layer.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+#ifndef CPU_ONLY
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+#endif
+
+template <typename TypeParam>
+class InnerProductLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+ protected:
+  InnerProductLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+        blob_bottom_nobatch_(new Blob<Dtype>(1, 2, 3, 4)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    UniformFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~InnerProductLayerTest() {
+    delete blob_bottom_;
+    delete blob_bottom_nobatch_;
+    delete blob_top_;
+  }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_bottom_nobatch_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<CPUDevice<float> > float_only;
+
+#define TestDtypesAndDevices float_only
+
+TYPED_TEST_CASE(InnerProductLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(InnerProductLayerTest, TestSetUp) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_);
+  LayerParameter layer_param;
+  InnerProductParameter* inner_product_param =
+      layer_param.mutable_inner_product_param();
+  inner_product_param->set_num_output(10);
+
+  layer_param.set_type("InnerProduct");
+
+  shared_ptr<Layer<Dtype> > new_layer =
+      LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  shared_ptr<InnerProductLayer<Dtype> > layer =
+      boost::static_pointer_cast<InnerProductLayer<Dtype> >(new_layer);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->height(), 1);
+  EXPECT_EQ(this->blob_top_->width(), 1);
+  EXPECT_EQ(this->blob_top_->channels(), 10);
+}
+
+/** @brief TestSetUp while toggling transpose flag
+ */
+TYPED_TEST(InnerProductLayerTest, TestSetUpTransposeFalse) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_);
+  LayerParameter layer_param;
+  InnerProductParameter* inner_product_param =
+      layer_param.mutable_inner_product_param();
+  inner_product_param->set_num_output(10);
+  inner_product_param->set_transpose(false);
+
+    layer_param.set_type("InnerProduct");
+
+  shared_ptr<Layer<Dtype> > new_layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  shared_ptr<InnerProductLayer<Dtype> > layer=
+   boost::static_pointer_cast<InnerProductLayer<Dtype>  > (new_layer);
+
+
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(2, this->blob_top_->num());
+  EXPECT_EQ(1, this->blob_top_->height());
+  EXPECT_EQ(1, this->blob_top_->width());
+  EXPECT_EQ(10, this->blob_top_->channels());
+  EXPECT_EQ(2, layer->blobs()[0]->num_axes());
+  EXPECT_EQ(10, layer->blobs()[0]->shape(0));
+  EXPECT_EQ(60, layer->blobs()[0]->shape(1));
+}
+
+/** @brief TestSetUp while toggling transpose flag
+ */
+TYPED_TEST(InnerProductLayerTest, TestSetUpTransposeTrue) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_);
+  LayerParameter layer_param;
+  InnerProductParameter* inner_product_param =
+      layer_param.mutable_inner_product_param();
+  inner_product_param->set_num_output(10);
+  inner_product_param->set_transpose(true);
+
+  layer_param.set_type("InnerProduct");
+
+  shared_ptr<Layer<Dtype> > new_layer =
+      LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  shared_ptr<InnerProductLayer<Dtype> > layer =
+      boost::static_pointer_cast<InnerProductLayer<Dtype> >(new_layer);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(2, this->blob_top_->num());
+  EXPECT_EQ(1, this->blob_top_->height());
+  EXPECT_EQ(1, this->blob_top_->width());
+  EXPECT_EQ(10, this->blob_top_->channels());
+  EXPECT_EQ(2, layer->blobs()[0]->num_axes());
+  EXPECT_EQ(60, layer->blobs()[0]->shape(0));
+  EXPECT_EQ(10, layer->blobs()[0]->shape(1));
+}
+
+TYPED_TEST(InnerProductLayerTest, TestForward) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_);
+  bool IS_VALID_CUDA = false;
+#ifndef CPU_ONLY
+  IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2;
+#endif
+  if (Caffe::mode() == Caffe::CPU ||
+      sizeof(Dtype) == 4 || IS_VALID_CUDA) {
+    LayerParameter layer_param;
+    InnerProductParameter* inner_product_param =
+        layer_param.mutable_inner_product_param();
+    inner_product_param->set_num_output(10);
+    inner_product_param->mutable_weight_filler()->set_type("uniform");
+    inner_product_param->mutable_bias_filler()->set_type("uniform");
+    inner_product_param->mutable_bias_filler()->set_min(1);
+    inner_product_param->mutable_bias_filler()->set_max(2);
+
+      layer_param.set_type("InnerProduct");
+
+  shared_ptr<Layer<Dtype> > new_layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  shared_ptr<InnerProductLayer<Dtype> > layer=
+   boost::static_pointer_cast<InnerProductLayer<Dtype>  > (new_layer);
+
+
+    layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    const Dtype* data = this->blob_top_->cpu_data();
+    const int count = this->blob_top_->count();
+    for (int i = 0; i < count; ++i) {
+      EXPECT_GE(data[i], 1.);
+    }
+  } else {
+    LOG(ERROR) << "Skipping test due to old architecture.";
+  }
+}
+
+/**
+ * @brief Init. an IP layer without transpose + random weights,
+ * run Forward, save the result.
+ * Init. another IP layer with transpose.
+ * manually copy and transpose the weights from the first IP layer,
+ * then run Forward on the same input and check that the result is the same
+ */
+TYPED_TEST(InnerProductLayerTest, TestForwardTranspose) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_);
+  bool IS_VALID_CUDA = false;
+#ifndef CPU_ONLY
+  IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2;
+#endif
+  if (Caffe::mode() == Caffe::CPU ||
+      sizeof(Dtype) == 4 || IS_VALID_CUDA) {
+    LayerParameter layer_param;
+    InnerProductParameter* inner_product_param =
+        layer_param.mutable_inner_product_param();
+    inner_product_param->set_num_output(10);
+    inner_product_param->mutable_weight_filler()->set_type("uniform");
+    inner_product_param->mutable_bias_filler()->set_type("uniform");
+    inner_product_param->mutable_bias_filler()->set_min(1);
+    inner_product_param->mutable_bias_filler()->set_max(2);
+    inner_product_param->set_transpose(false);
+
+    layer_param.set_type("InnerProduct");
+
+    shared_ptr<Layer<Dtype> > new_layer =
+        LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+    shared_ptr<InnerProductLayer<Dtype> > layer =
+        boost::static_pointer_cast<InnerProductLayer<Dtype> >(new_layer);
+
+    layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    const int count = this->blob_top_->count();
+    Blob<Dtype>* const top = new Blob<Dtype>();
+    top->ReshapeLike(*this->blob_top_);
+    caffe_copy(count, this->blob_top_->cpu_data(), top->mutable_cpu_data());
+    this->blob_top_vec_.clear();
+    this->blob_top_vec_.push_back(new Blob<Dtype>());
+    inner_product_param->set_transpose(true);
+    shared_ptr<InnerProductLayer<Dtype> > ip_t(
+        new InnerProductLayer<Dtype>(layer_param));
+    ip_t->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    const int count_w = layer->blobs()[0]->count();
+    EXPECT_EQ(count_w, ip_t->blobs()[0]->count());
+    // manually copy and transpose the weights from 1st IP layer into 2nd
+    const Dtype* w = layer->blobs()[0]->cpu_data();
+    Dtype* w_t = ip_t->blobs()[0]->mutable_cpu_data();
+    const int width = layer->blobs()[0]->shape(1);
+    const int width_t = ip_t->blobs()[0]->shape(1);
+    for (int i = 0; i < count_w; ++i) {
+      int r = i / width;
+      int c = i % width;
+      w_t[c*width_t+r] = w[r*width+c];  // copy while transposing
+    }
+    // copy bias from 1st IP layer to 2nd IP layer
+    ASSERT_EQ(layer->blobs()[1]->count(), ip_t->blobs()[1]->count());
+    caffe_copy(layer->blobs()[1]->count(), layer->blobs()[1]->cpu_data(),
+        ip_t->blobs()[1]->mutable_cpu_data());
+    ip_t->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    EXPECT_EQ(count, this->blob_top_->count())
+        << "Invalid count for top blob for IP with transpose.";
+    Blob<Dtype>* const top_t = new Blob<Dtype>();
+    top_t->ReshapeLike(*this->blob_top_vec_[0]);
+    caffe_copy(count,
+      this->blob_top_vec_[0]->cpu_data(),
+      top_t->mutable_cpu_data());
+    const Dtype* data = top->cpu_data();
+    const Dtype* data_t = top_t->cpu_data();
+    for (int i = 0; i < count; ++i) {
+      EXPECT_FLOAT_EQ(data[i], data_t[i]);
+    }
+  } else {
+    LOG(ERROR) << "Skipping test due to old architecture.";
+  }
+}
+
+TYPED_TEST(InnerProductLayerTest, TestForwardNoBatch) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_nobatch_);
+  bool IS_VALID_CUDA = false;
+#ifndef CPU_ONLY
+  IS_VALID_CUDA = CAFFE_TEST_CUDA_PROP.major >= 2;
+#endif
+  if (Caffe::mode() == Caffe::CPU ||
+      sizeof(Dtype) == 4 || IS_VALID_CUDA) {
+    LayerParameter layer_param;
+    InnerProductParameter* inner_product_param =
+        layer_param.mutable_inner_product_param();
+    inner_product_param->set_num_output(10);
+    inner_product_param->mutable_weight_filler()->set_type("uniform");
+    inner_product_param->mutable_bias_filler()->set_type("uniform");
+    inner_product_param->mutable_bias_filler()->set_min(1);
+    inner_product_param->mutable_bias_filler()->set_max(2);
+
+    layer_param.set_type("InnerProduct");
+
+    shared_ptr<Layer<Dtype> > new_layer=
+    LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+    shared_ptr<InnerProductLayer<Dtype> > layer=
+    boost::static_pointer_cast<InnerProductLayer<Dtype>  > (new_layer);
+
+
+    layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    const Dtype* data = this->blob_top_->cpu_data();
+    const int count = this->blob_top_->count();
+    for (int i = 0; i < count; ++i) {
+      EXPECT_GE(data[i], 1.);
+    }
+  } else {
+    LOG(ERROR) << "Skipping test due to old architecture.";
+  }
+}
+
+}  // namespace caffe
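
The weight copy in TestForwardTranspose above is a plain row-major matrix
transpose; a standalone sketch of the same index arithmetic (helper name is
ours, not part of the commit):

    // For a row-major [rows x cols] weight matrix, element (r, c) lands at
    // (c, r) in the transposed [cols x rows] buffer.
    template <typename Dtype>
    void transpose_copy(const Dtype* w, Dtype* w_t, int rows, int cols) {
      for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
          w_t[c * rows + r] = w[r * cols + c];
    }
    // In the test: rows = num_output = 10, cols = K = 3*4*5 = 60.
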
diff --git a/unit_tests/test_lrn_layer.cpp b/unit_tests/test_lrn_layer.cpp
new file mode 100644 (file)
index 0000000..22ae8c6
--- /dev/null
@@ -0,0 +1,344 @@
+#include <algorithm>
+#include <fstream>
+#include <iostream>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layers/lrn_layer.hpp"
+
+#ifdef USE_CUDNN
+#include "caffe/layers/cudnn_lcn_layer.hpp"
+#include "caffe/layers/cudnn_lrn_layer.hpp"
+#endif
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+using std::min;
+using std::max;
+
+int test_h = 5;
+int test_w = 5;
+
+namespace caffe {
+
+// Write every element of a 4-D blob to a text file, one value per line,
+// in n/c/h/w order.
+template <typename Dtype>
+static void dump_blob(const Blob<Dtype>* blob, const char* outfile)
+{
+  std::ofstream os(outfile);
+
+  for (int i = 0; i < blob->shape(0); i++)
+    for (int j = 0; j < blob->shape(1); j++)
+      for (int k = 0; k < blob->shape(2); k++)
+        for (int l = 0; l < blob->shape(3); l++) {
+          os << blob->data_at(i, j, k, l) << std::endl;
+        }
+
+  os.close();
+}
+
+// Fill a blob either with a constant (fixed != 0) or with each element's
+// linear NCHW offset, which makes dumped data easy to cross-check.
+template <typename Dtype>
+static void fill_blob_data(Blob<Dtype>* bottom, int fixed, float val)
+{
+  Dtype* ptr = bottom->mutable_cpu_data();
+
+  for (int i = 0; i < bottom->num(); i++)
+    for (int j = 0; j < bottom->channels(); j++)
+      for (int l = 0; l < bottom->height(); l++)
+        for (int k = 0; k < bottom->width(); k++) {
+          int offset = ((i * bottom->channels() + j) * bottom->height() + l)
+              * bottom->width() + k;
+          ptr[offset] = fixed ? Dtype(val) : Dtype(offset);
+        }
+}
+
+
+template <typename TypeParam>
+class LRNLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  LRNLayerTest()
+      : epsilon_(Dtype(1e-5)),
+        blob_bottom_(new Blob<Dtype>()),
+        blob_top_(new Blob<Dtype>()) {}
+  virtual void SetUp() {
+    Caffe::set_random_seed(1701);
+    blob_bottom_->Reshape(2, 7, test_h,test_w);
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~LRNLayerTest() { delete blob_bottom_; delete blob_top_; }
+  void ReferenceLRNForward(const Blob<Dtype>& blob_bottom,
+      const LayerParameter& layer_param, Blob<Dtype>* blob_top);
+
+  Dtype epsilon_;
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+template <typename TypeParam>
+void LRNLayerTest<TypeParam>::ReferenceLRNForward(
+    const Blob<Dtype>& blob_bottom, const LayerParameter& layer_param,
+    Blob<Dtype>* blob_top) {
+  typedef typename TypeParam::Dtype Dtype;
+  blob_top->Reshape(blob_bottom.num(), blob_bottom.channels(),
+      blob_bottom.height(), blob_bottom.width());
+  Dtype* top_data = blob_top->mutable_cpu_data();
+  LRNParameter lrn_param = layer_param.lrn_param();
+  Dtype alpha = lrn_param.alpha();
+  Dtype beta = lrn_param.beta();
+  int size = lrn_param.local_size();
+  switch (lrn_param.norm_region()) {
+  case LRNParameter_NormRegion_ACROSS_CHANNELS:
+    for (int n = 0; n < blob_bottom.num(); ++n) {
+      for (int c = 0; c < blob_bottom.channels(); ++c) {
+        for (int h = 0; h < blob_bottom.height(); ++h) {
+          for (int w = 0; w < blob_bottom.width(); ++w) {
+            int c_start = c - (size - 1) / 2;
+            int c_end = min(c_start + size, blob_bottom.channels());
+            c_start = max(c_start, 0);
+            Dtype scale = 1.;
+            for (int i = c_start; i < c_end; ++i) {
+              Dtype value = blob_bottom.data_at(n, i, h, w);
+              scale += value * value * alpha / size;
+            }
+            *(top_data + blob_top->offset(n, c, h, w)) =
+              blob_bottom.data_at(n, c, h, w) / pow(scale, beta);
+          }
+        }
+      }
+    }
+    break;
+  case LRNParameter_NormRegion_WITHIN_CHANNEL:
+    for (int n = 0; n < blob_bottom.num(); ++n) {
+      for (int c = 0; c < blob_bottom.channels(); ++c) {
+        for (int h = 0; h < blob_bottom.height(); ++h) {
+          int h_start = h - (size - 1) / 2;
+          int h_end = min(h_start + size, blob_bottom.height());
+          h_start = max(h_start, 0);
+          for (int w = 0; w < blob_bottom.width(); ++w) {
+            Dtype scale = 1.;
+            int w_start = w - (size - 1) / 2;
+            int w_end = min(w_start + size, blob_bottom.width());
+            w_start = max(w_start, 0);
+
+//            std::cout<<"h,w ("<<h<<","<<w<<"): ";
+//            std::cout<<"box: ( h "<<h_start<<","<<h_end<<")";
+//           std::cout<<" (w "<<w_start<<","<<w_end<<")"<<std::endl;
+
+            for (int nh = h_start; nh < h_end; ++nh) {
+              for (int nw = w_start; nw < w_end; ++nw) {
+                Dtype value = blob_bottom.data_at(n, c, nh, nw);
+                scale += value * value * alpha / (size * size);
+              }
+            }
+            *(top_data + blob_top->offset(n, c, h, w)) =
+              blob_bottom.data_at(n, c, h, w) / pow(scale, beta);
+          }
+        }
+      }
+    }
+    break;
+  default:
+    LOG(FATAL) << "Unknown normalization region.";
+  }
+}
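+// In both branches the reference computes
+//   top = bottom / (1 + (alpha / N) * sum(v * v)) ^ beta,
+// where the sum runs over the clipped local window and N is local_size for
+// ACROSS_CHANNELS or local_size^2 for WITHIN_CHANNEL.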
+
+typedef ::testing::Types<CPUDevice<float> > float_only;
+
+#define TestDtypesAndDevices float_only
+
+TYPED_TEST_CASE(LRNLayerTest, TestDtypesAndDevices);
+
+#if 1
+TYPED_TEST(LRNLayerTest, TestSetupAcrossChannels) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  LRNLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 7);
+  EXPECT_EQ(this->blob_top_->height(), test_h);
+  EXPECT_EQ(this->blob_top_->width(), test_w);
+}
+
+TYPED_TEST(LRNLayerTest, TestForwardAcrossChannels) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+//  LRNLayer<Dtype> layer(layer_param);
+
+  layer_param.mutable_lrn_param()->set_local_size(3);
+
+  layer_param.set_type("LRN");
+
+  shared_ptr<Layer<Dtype> > new_layer =
+      LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  shared_ptr<LRNLayer<Dtype> > layer =
+      boost::static_pointer_cast<LRNLayer<Dtype> >(new_layer);
+
+  vector<int> bottom_shape;
+  bottom_shape.push_back(1);
+  bottom_shape.push_back(5);
+  bottom_shape.push_back(5);
+  bottom_shape.push_back(5);
+
+
+  this->blob_bottom_vec_[0]->Reshape(bottom_shape);
+
+  fill_blob_data(this->blob_bottom_, 1, 1);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  Blob<Dtype> top_reference;
+  this->ReferenceLRNForward(*(this->blob_bottom_), layer_param,
+      &top_reference);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i],
+                this->epsilon_);
+  }
+}
+
+
+
+TYPED_TEST(LRNLayerTest, TestForwardAcrossChannelsLargeRegion) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_lrn_param()->set_local_size(15);
+
+  layer_param.set_type("LRN");
+
+  shared_ptr<Layer<Dtype> > new_layer =
+      LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  shared_ptr<LRNLayer<Dtype> > layer =
+      boost::static_pointer_cast<LRNLayer<Dtype> >(new_layer);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  Blob<Dtype> top_reference;
+  this->ReferenceLRNForward(*(this->blob_bottom_), layer_param,
+      &top_reference);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i],
+                this->epsilon_);
+  }
+}
+
+
+TYPED_TEST(LRNLayerTest, TestSetupWithinChannel) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_lrn_param()->set_norm_region(
+      LRNParameter_NormRegion_WITHIN_CHANNEL);
+  layer_param.mutable_lrn_param()->set_local_size(3);
+
+
+  layer_param.set_type("LRN");
+
+  shared_ptr<Layer<Dtype> > new_layer =
+      LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  shared_ptr<LRNLayer<Dtype> > layer =
+      boost::static_pointer_cast<LRNLayer<Dtype> >(new_layer);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 7);
+  EXPECT_EQ(this->blob_top_->height(), test_h);
+  EXPECT_EQ(this->blob_top_->width(), test_w);
+}
+#endif
+
+#if 1
+
+TYPED_TEST(LRNLayerTest, TestForwardWithinChannel) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_lrn_param()->set_norm_region(
+      LRNParameter_NormRegion_WITHIN_CHANNEL);
+  layer_param.mutable_lrn_param()->set_local_size(3);
+//  layer_param.mutable_lrn_param()->set_beta(1);
+
+
+  layer_param.set_type("LRN");
+
+  shared_ptr<Layer<Dtype> > new_layer =
+      LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  shared_ptr<LRNLayer<Dtype> > layer =
+      boost::static_pointer_cast<LRNLayer<Dtype> >(new_layer);
+
+  // Preset the bottom vector and data.
+  vector<int> bottom_shape;
+  bottom_shape.push_back(1);
+  bottom_shape.push_back(1);
+  bottom_shape.push_back(5);
+  bottom_shape.push_back(5);
+
+  this->blob_bottom_vec_[0]->Reshape(bottom_shape);
+
+  fill_blob_data(this->blob_bottom_, 1, 1);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  Blob<Dtype> top_reference;
+  this->ReferenceLRNForward(*(this->blob_bottom_), layer_param,
+      &top_reference);
+  // The direct element-wise comparison is disabled for now; the tensors are
+  // dumped below and compared offline instead.
+//  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+//    EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i],
+//                this->epsilon_);
+//  }
+
+  dump_blob(this->blob_bottom_, "lrn.bottom.data");
+  dump_blob(this->blob_top_, "lrn.top.data");
+  dump_blob(&top_reference, "lrn.reftop.data");
+}
+
+#endif
+
+}  // namespace caffe
diff --git a/unit_tests/test_neuron_layer.cpp b/unit_tests/test_neuron_layer.cpp
new file mode 100644 (file)
index 0000000..10c101d
--- /dev/null
@@ -0,0 +1,358 @@
+#include <algorithm>
+#include <vector>
+#include <cmath>
+
+#include "google/protobuf/text_format.h"
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+
+#include "caffe/layers/absval_layer.hpp"
+#include "caffe/layers/bnll_layer.hpp"
+#include "caffe/layers/dropout_layer.hpp"
+#include "caffe/layers/elu_layer.hpp"
+#include "caffe/layers/exp_layer.hpp"
+#include "caffe/layers/inner_product_layer.hpp"
+#include "caffe/layers/log_layer.hpp"
+#include "caffe/layers/power_layer.hpp"
+#include "caffe/layers/prelu_layer.hpp"
+#include "caffe/layers/relu_layer.hpp"
+#include "caffe/layers/sigmoid_layer.hpp"
+#include "caffe/layers/tanh_layer.hpp"
+#include "caffe/layers/threshold_layer.hpp"
+
+#ifdef USE_CUDNN
+#include "caffe/layers/cudnn_relu_layer.hpp"
+#include "caffe/layers/cudnn_sigmoid_layer.hpp"
+#include "caffe/layers/cudnn_tanh_layer.hpp"
+#endif
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+typedef ::testing::Types<CPUDevice<float> > float_only;
+#define TestDtypesAndDevices float_only
+
+
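+// Create the layer through LayerRegistry (so a registered ACL implementation
+// can be picked up), then downcast to the concrete layer type.  The dead
+// if (0) branch only forces a compile-time check that name##Layer is
+// constructible from layer_param.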
+#define SET_LAYER(name) \
+  layer_param.set_type(#name); \
+  shared_ptr<Layer<Dtype> > new_layer = \
+      LayerRegistry<Dtype>::CreateLayer(layer_param); \
+  shared_ptr<name ## Layer<Dtype> > layer = \
+      boost::static_pointer_cast<name ## Layer<Dtype> >(new_layer); \
+  if (0) layer = shared_ptr<name ## Layer<Dtype> >( \
+      new name ## Layer<Dtype>(layer_param)); \
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+
+
+template <typename TypeParam>
+class NeuronLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  NeuronLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 4, 5)),
+        blob_top_(new Blob<Dtype>()) {
+    Caffe::set_random_seed(1701);
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~NeuronLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+
+
+  void TestPReLU(PReLULayer<Dtype> *layer) {
+    layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+    const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+    const Dtype* top_data = this->blob_top_->cpu_data();
+    const Dtype* slope_data = layer->blobs()[0]->cpu_data();
+    int hw = this->blob_bottom_->height() * this->blob_bottom_->width();
+    int channels = this->blob_bottom_->channels();
+    bool channel_shared = layer->layer_param().prelu_param().channel_shared();
+    for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+      int c = channel_shared ? 0 : (i / hw) % channels;
+      EXPECT_EQ(top_data[i],
+          std::max(bottom_data[i], (Dtype)(0))
+          + slope_data[c] * std::min(bottom_data[i], (Dtype)(0)));
+    }
+  }
+
+};
+
+TYPED_TEST_CASE(NeuronLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(NeuronLayerTest, TestAbsVal) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+
+  SET_LAYER(AbsVal);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data    = this->blob_top_->cpu_data();
+  const int count = this->blob_bottom_->count();
+  for (int i = 0; i < count; ++i) {
+    EXPECT_EQ(top_data[i], fabs(bottom_data[i]));
+  }
+}
+
+
+TYPED_TEST(NeuronLayerTest, TestReLU) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+
+
+  SET_LAYER(ReLU);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_GE(top_data[i], 0.);
+    EXPECT_TRUE(top_data[i] == 0 || top_data[i] == bottom_data[i]);
+  }
+}
+
+#if 1
+
+TYPED_TEST(NeuronLayerTest, TestReLUWithNegativeSlope) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CHECK(google::protobuf::TextFormat::ParseFromString(
+      "relu_param { negative_slope: 0.01 }", &layer_param));
+
+  SET_LAYER(ReLU);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    if (top_data[i] >= 0) {
+      EXPECT_FLOAT_EQ(top_data[i], bottom_data[i]);
+    } else {
+      EXPECT_FLOAT_EQ(top_data[i], bottom_data[i] * 0.01);
+    }
+  }
+}
+
+
+TYPED_TEST(NeuronLayerTest, TestSigmoid) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+
+  SET_LAYER(Sigmoid);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_FLOAT_EQ(top_data[i], 1. / (1 + exp(-bottom_data[i])));
+    // check that we squashed the value between 0 and 1
+    EXPECT_GE(top_data[i], 0.);
+    EXPECT_LE(top_data[i], 1.);
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestTanH) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+
+  // Reshape to 1x2x10x2 and overwrite the first `number` entries with large
+  // values (0, 10, ..., 90) so tanh saturation is covered; the remaining
+  // entries keep the Gaussian fill from the fixture.
+  int number = 10;
+
+  this->blob_bottom_->Reshape(1, 2, number, 2);
+
+  for (int i = 0; i < number; i++)
+    this->blob_bottom_->mutable_cpu_data()[i] = i * 10;
+
+  SET_LAYER(TanH);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  // Test exact values
+  for (int i = 0; i < this->blob_bottom_->num(); ++i) {
+    for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
+      for (int k = 0; k < this->blob_bottom_->height(); ++k) {
+        for (int l = 0; l < this->blob_bottom_->width(); ++l) {
+
+          EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4,
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) /
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1));
+          EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4,
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) - 1) /
+             (exp(2*this->blob_bottom_->data_at(i, j, k, l)) + 1));
+        }
+      }
+    }
+  }
+}
+
+
+TYPED_TEST(NeuronLayerTest, TestBNLL) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+
+  SET_LAYER(BNLL);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  // Now, check values
+  const Dtype* bottom_data = this->blob_bottom_->cpu_data();
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    Dtype target = log(1 + exp(bottom_data[i]));
+    EXPECT_NEAR(top_data[i], target, 1e-4);
+  }
+}
+#endif
+
+#if 0 /* PReLU is not exercised for now */
+
+TYPED_TEST(NeuronLayerTest, TestPReLUParam) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  PReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  const Dtype* slopes = layer.blobs()[0]->cpu_data();
+  int count = layer.blobs()[0]->count();
+  for (int i = 0; i < count; ++i, ++slopes) {
+    EXPECT_EQ(*slopes, 0.25);
+  }
+}
+
+TYPED_TEST(NeuronLayerTest, TestPReLUForward) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  PReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  filler.Fill(layer.blobs()[0].get());
+  this->TestPReLU(&layer);
+}
+
+TYPED_TEST(NeuronLayerTest, TestPReLUForwardChannelShared) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_prelu_param()->set_channel_shared(true);
+  PReLULayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  this->TestPReLU(&layer);
+}
+
+
+TYPED_TEST(NeuronLayerTest, TestPReLUConsistencyReLU) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter prelu_layer_param;
+  LayerParameter relu_layer_param;
+  relu_layer_param.mutable_relu_param()->set_negative_slope(0.25);
+  PReLULayer<Dtype> prelu(prelu_layer_param);
+  ReLULayer<Dtype> relu(relu_layer_param);
+  // Set up blobs
+  vector<Blob<Dtype>*> blob_bottom_vec_2;
+  vector<Blob<Dtype>*> blob_top_vec_2;
+  shared_ptr<Blob<Dtype> > blob_bottom_2(new Blob<Dtype>());
+  shared_ptr<Blob<Dtype> > blob_top_2(new Blob<Dtype>());
+  blob_bottom_vec_2.push_back(blob_bottom_2.get());
+  blob_top_vec_2.push_back(blob_top_2.get());
+  blob_bottom_2->CopyFrom(*this->blob_bottom_, false, true);
+  // SetUp layers
+  prelu.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  relu.SetUp(blob_bottom_vec_2, blob_top_vec_2);
+  // Check forward
+  prelu.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  relu.Forward(this->blob_bottom_vec_, blob_top_vec_2);
+  for (int s = 0; s < blob_top_2->count(); ++s) {
+    EXPECT_EQ(this->blob_top_->cpu_data()[s], blob_top_2->cpu_data()[s]);
+  }
+  // Check backward
+}
+
+TYPED_TEST(NeuronLayerTest, TestPReLUInPlace) {
+  typedef typename TypeParam::Dtype Dtype;
+  // Set layer parameters
+  LayerParameter ip_layer_param;
+  LayerParameter prelu_layer_param;
+  InnerProductParameter *ip_param =
+      ip_layer_param.mutable_inner_product_param();
+  ip_param->mutable_weight_filler()->set_type("gaussian");
+  ip_param->set_num_output(3);
+  InnerProductLayer<Dtype> ip(ip_layer_param);
+  PReLULayer<Dtype> prelu(prelu_layer_param);
+  InnerProductLayer<Dtype> ip2(ip_layer_param);
+  PReLULayer<Dtype> prelu2(prelu_layer_param);
+  // Set up blobs
+  vector<Blob<Dtype>*> blob_bottom_vec_2;
+  vector<Blob<Dtype>*> blob_middle_vec_2;
+  vector<Blob<Dtype>*> blob_top_vec_2;
+  shared_ptr<Blob<Dtype> > blob_bottom_2(new Blob<Dtype>());
+  shared_ptr<Blob<Dtype> > blob_middle_2(new Blob<Dtype>());
+  shared_ptr<Blob<Dtype> > blob_top_2(new Blob<Dtype>());
+  blob_bottom_vec_2.push_back(blob_bottom_2.get());
+  blob_middle_vec_2.push_back(blob_middle_2.get());
+  blob_top_vec_2.push_back(blob_top_2.get());
+  blob_bottom_2->CopyFrom(*this->blob_bottom_, false, true);
+  // SetUp layers
+  ip.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  prelu.SetUp(this->blob_top_vec_, this->blob_top_vec_);
+  ip2.SetUp(blob_bottom_vec_2, blob_middle_vec_2);
+  prelu2.SetUp(blob_middle_vec_2, blob_top_vec_2);
+  caffe_copy(ip2.blobs()[0]->count(), ip.blobs()[0]->cpu_data(),
+      ip2.blobs()[0]->mutable_cpu_data());
+  // Forward in-place
+  ip.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  prelu.Forward(this->blob_top_vec_, this->blob_top_vec_);
+  // Forward non-in-place
+  ip2.Forward(blob_bottom_vec_2, blob_middle_vec_2);
+  prelu2.Forward(blob_middle_vec_2, blob_top_vec_2);
+  // Check numbers
+  for (int s = 0; s < blob_top_2->count(); ++s) {
+    EXPECT_EQ(this->blob_top_->cpu_data()[s], blob_top_2->cpu_data()[s]);
+  }
+  // Fill top diff with random numbers
+  shared_ptr<Blob<Dtype> > tmp_blob(new Blob<Dtype>());
+  tmp_blob->ReshapeLike(*blob_top_2.get());
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  filler.Fill(tmp_blob.get());
+  caffe_copy(blob_top_2->count(), tmp_blob->cpu_data(),
+      this->blob_top_->mutable_cpu_diff());
+  caffe_copy(blob_top_2->count(), tmp_blob->cpu_data(),
+      blob_top_2->mutable_cpu_diff());
+  // Backward in-place
+  vector<bool> propagate_down;
+  propagate_down.push_back(true);
+  prelu.Backward(this->blob_top_vec_, propagate_down, this->blob_top_vec_);
+  ip.Backward(this->blob_top_vec_, propagate_down, this->blob_bottom_vec_);
+  // Backward non-in-place
+  prelu2.Backward(blob_top_vec_2, propagate_down, blob_middle_vec_2);
+  ip2.Backward(blob_middle_vec_2, propagate_down, blob_bottom_vec_2);
+  // Check numbers
+  for (int s = 0; s < blob_bottom_2->count(); ++s) {
+    EXPECT_EQ(this->blob_bottom_->cpu_diff()[s], blob_bottom_2->cpu_diff()[s]);
+  }
+  for (int s = 0; s < ip.blobs()[0]->count(); ++s) {
+    EXPECT_EQ(ip.blobs()[0]->cpu_diff()[s], ip2.blobs()[0]->cpu_diff()[s]);
+  }
+  for (int s = 0; s < ip.blobs()[1]->count(); ++s) {
+    EXPECT_EQ(ip.blobs()[1]->cpu_diff()[s], ip2.blobs()[1]->cpu_diff()[s]);
+  }
+  for (int s = 0; s < prelu.blobs()[0]->count(); ++s) {
+    EXPECT_EQ(prelu.blobs()[0]->cpu_diff()[s],
+        prelu2.blobs()[0]->cpu_diff()[s]);
+  }
+}
+
+#endif
+
+}  // namespace caffe
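
For readers skimming the diff: SET_LAYER(ReLU) in the tests above expands to
roughly the following (the macro also carries a dead "if (0)" line that only
type-checks the cast):

    layer_param.set_type("ReLU");
    shared_ptr<Layer<Dtype> > new_layer =
        LayerRegistry<Dtype>::CreateLayer(layer_param);
    shared_ptr<ReLULayer<Dtype> > layer =
        boost::static_pointer_cast<ReLULayer<Dtype> >(new_layer);
    layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
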
diff --git a/unit_tests/test_pooling_layer.cpp b/unit_tests/test_pooling_layer.cpp
new file mode 100644 (file)
index 0000000..bde8844
--- /dev/null
@@ -0,0 +1,652 @@
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layers/pooling_layer.hpp"
+
+#ifdef USE_CUDNN
+#include "caffe/layers/cudnn_pooling_layer.hpp"
+#endif
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+typedef ::testing::Types<CPUDevice<float> > float_only;
+#define TestDtypesAndDevices float_only
+
+
+#define SET_LAYER(name) \
+  layer_param.set_type(#name); \
+  shared_ptr<Layer<Dtype> > new_layer = \
+      LayerRegistry<Dtype>::CreateLayer(layer_param); \
+  shared_ptr<name ## Layer<Dtype> > layer = \
+      boost::static_pointer_cast<name ## Layer<Dtype> >(new_layer); \
+  if (0) layer = shared_ptr<name ## Layer<Dtype> >( \
+      new name ## Layer<Dtype>(layer_param)); \
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+
+
+template <typename TypeParam>
+class PoolingLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  PoolingLayerTest()
+      : blob_bottom_(new Blob<Dtype>()),
+        blob_top_(new Blob<Dtype>()),
+        blob_top_mask_(new Blob<Dtype>()) {}
+  virtual void SetUp() {
+    Caffe::set_random_seed(1701);
+    blob_bottom_->Reshape(2, 3, 6, 5);
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~PoolingLayerTest() {
+    delete blob_bottom_;
+    delete blob_top_;
+    delete blob_top_mask_;
+  }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  Blob<Dtype>* const blob_top_mask_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+  // Test for 2x 2 square pooling layer
+  void TestForwardSquare() {
+    LayerParameter layer_param;
+    PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+    pooling_param->set_kernel_size(2);
+    pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
+    const int num = 2;
+    const int channels = 2;
+    blob_bottom_->Reshape(num, channels, 3, 5);
+    // Input: 2x 2 channels of:
+    //     [1 2 5 2 3]
+    //     [9 4 1 4 8]
+    //     [1 2 5 2 3]
+    for (int i = 0; i < 15 * num * channels; i += 15) {
+      blob_bottom_->mutable_cpu_data()[i +  0] = 1;
+      blob_bottom_->mutable_cpu_data()[i +  1] = 2;
+      blob_bottom_->mutable_cpu_data()[i +  2] = 5;
+      blob_bottom_->mutable_cpu_data()[i +  3] = 2;
+      blob_bottom_->mutable_cpu_data()[i +  4] = 3;
+      blob_bottom_->mutable_cpu_data()[i +  5] = 9;
+      blob_bottom_->mutable_cpu_data()[i +  6] = 4;
+      blob_bottom_->mutable_cpu_data()[i +  7] = 1;
+      blob_bottom_->mutable_cpu_data()[i +  8] = 4;
+      blob_bottom_->mutable_cpu_data()[i +  9] = 8;
+      blob_bottom_->mutable_cpu_data()[i + 10] = 1;
+      blob_bottom_->mutable_cpu_data()[i + 11] = 2;
+      blob_bottom_->mutable_cpu_data()[i + 12] = 5;
+      blob_bottom_->mutable_cpu_data()[i + 13] = 2;
+      blob_bottom_->mutable_cpu_data()[i + 14] = 3;
+    }
+
+    SET_LAYER(Pooling);
+
+    EXPECT_EQ(blob_top_->num(), num);
+    EXPECT_EQ(blob_top_->channels(), channels);
+    EXPECT_EQ(blob_top_->height(), 2);
+    EXPECT_EQ(blob_top_->width(), 4);
+    if (blob_top_vec_.size() > 1) {
+      EXPECT_EQ(blob_top_mask_->num(), num);
+      EXPECT_EQ(blob_top_mask_->channels(), channels);
+      EXPECT_EQ(blob_top_mask_->height(), 2);
+      EXPECT_EQ(blob_top_mask_->width(), 4);
+    }
+    layer->Forward(blob_bottom_vec_, blob_top_vec_);
+    // Expected output: 2x 2 channels of:
+    //     [9 5 5 8]
+    //     [9 5 5 8]
+    for (int i = 0; i < 8 * num * channels; i += 8) {
+      EXPECT_EQ(blob_top_->cpu_data()[i + 0], 9);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 1], 5);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 2], 5);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 3], 8);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 4], 9);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 5], 5);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 6], 5);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 7], 8);
+    }
+    if (blob_top_vec_.size() > 1) {
+      // Expected mask output: 2x 2 channels of:
+      //     [5  2  2 9]
+      //     [5 12 12 9]
+      for (int i = 0; i < 8 * num * channels; i += 8) {
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 0],  5);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 1],  2);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 2],  2);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 3],  9);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 4],  5);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 5], 12);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 6], 12);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 7],  9);
+      }
+    }
+  }
+  // Test for 3x 2 rectangular pooling layer with kernel_h > kernel_w
+  void TestForwardRectHigh() {
+    LayerParameter layer_param;
+    PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+    pooling_param->set_kernel_h(3);
+    pooling_param->set_kernel_w(2);
+    pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
+    const int num = 2;
+    const int channels = 2;
+    blob_bottom_->Reshape(num, channels, 6, 6);
+    // Input: 2x 2 channels of:
+    // [35     1     6    26    19    24]
+    // [ 3    32     7    21    23    25]
+    // [31     9     2    22    27    20]
+    // [ 8    28    33    17    10    15]
+    // [30     5    34    12    14    16]
+    // [ 4    36    29    13    18    11]
+    // (this is generated by magic(6) in MATLAB)
+    for (int i = 0; i < 36 * num * channels; i += 36) {
+      blob_bottom_->mutable_cpu_data()[i +  0] = 35;
+      blob_bottom_->mutable_cpu_data()[i +  1] = 1;
+      blob_bottom_->mutable_cpu_data()[i +  2] = 6;
+      blob_bottom_->mutable_cpu_data()[i +  3] = 26;
+      blob_bottom_->mutable_cpu_data()[i +  4] = 19;
+      blob_bottom_->mutable_cpu_data()[i +  5] = 24;
+      blob_bottom_->mutable_cpu_data()[i +  6] = 3;
+      blob_bottom_->mutable_cpu_data()[i +  7] = 32;
+      blob_bottom_->mutable_cpu_data()[i +  8] = 7;
+      blob_bottom_->mutable_cpu_data()[i +  9] = 21;
+      blob_bottom_->mutable_cpu_data()[i + 10] = 23;
+      blob_bottom_->mutable_cpu_data()[i + 11] = 25;
+      blob_bottom_->mutable_cpu_data()[i + 12] = 31;
+      blob_bottom_->mutable_cpu_data()[i + 13] = 9;
+      blob_bottom_->mutable_cpu_data()[i + 14] = 2;
+      blob_bottom_->mutable_cpu_data()[i + 15] = 22;
+      blob_bottom_->mutable_cpu_data()[i + 16] = 27;
+      blob_bottom_->mutable_cpu_data()[i + 17] = 20;
+      blob_bottom_->mutable_cpu_data()[i + 18] = 8;
+      blob_bottom_->mutable_cpu_data()[i + 19] = 28;
+      blob_bottom_->mutable_cpu_data()[i + 20] = 33;
+      blob_bottom_->mutable_cpu_data()[i + 21] = 17;
+      blob_bottom_->mutable_cpu_data()[i + 22] = 10;
+      blob_bottom_->mutable_cpu_data()[i + 23] = 15;
+      blob_bottom_->mutable_cpu_data()[i + 24] = 30;
+      blob_bottom_->mutable_cpu_data()[i + 25] = 5;
+      blob_bottom_->mutable_cpu_data()[i + 26] = 34;
+      blob_bottom_->mutable_cpu_data()[i + 27] = 12;
+      blob_bottom_->mutable_cpu_data()[i + 28] = 14;
+      blob_bottom_->mutable_cpu_data()[i + 29] = 16;
+      blob_bottom_->mutable_cpu_data()[i + 30] = 4;
+      blob_bottom_->mutable_cpu_data()[i + 31] = 36;
+      blob_bottom_->mutable_cpu_data()[i + 32] = 29;
+      blob_bottom_->mutable_cpu_data()[i + 33] = 13;
+      blob_bottom_->mutable_cpu_data()[i + 34] = 18;
+      blob_bottom_->mutable_cpu_data()[i + 35] = 11;
+    }
+
+    SET_LAYER(Pooling);
+
+    EXPECT_EQ(blob_top_->channels(), channels);
+    EXPECT_EQ(blob_top_->height(), 4);
+    EXPECT_EQ(blob_top_->width(), 5);
+    if (blob_top_vec_.size() > 1) {
+      EXPECT_EQ(blob_top_mask_->num(), num);
+      EXPECT_EQ(blob_top_mask_->channels(), channels);
+      EXPECT_EQ(blob_top_mask_->height(), 4);
+      EXPECT_EQ(blob_top_mask_->width(), 5);
+    }
+    layer->Forward(blob_bottom_vec_, blob_top_vec_);
+    // Expected output: 2x 2 channels of:
+    // [35    32    26    27    27]
+    // [32    33    33    27    27]
+    // [31    34    34    27    27]
+    // [36    36    34    18    18]
+    for (int i = 0; i < 20 * num * channels; i += 20) {
+      EXPECT_EQ(blob_top_->cpu_data()[i +  0], 35);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  1], 32);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  2], 26);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  3], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  4], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  5], 32);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  6], 33);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  7], 33);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  8], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  9], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 10], 31);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 11], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 13], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 14], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 15], 36);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 17], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 18], 18);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18);
+    }
+    if (blob_top_vec_.size() > 1) {
+        // [ 1     8     4    17    17]
+        // [ 8    21    21    17    17]
+        // [13    27    27    17    17]
+        // [32    32    27    35    35]
+      for (int i = 0; i < 20 * num * channels; i += 20) {
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  0],  0);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  1],  7);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  2],  3);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  3], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  4], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  5],  7);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  6], 20);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  7], 20);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  8], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  9], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 12);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 31);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 34);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34);
+      }
+    }
+  }
+  // Test for rectangular pooling layer with kernel_w > kernel_h
+  void TestForwardRectWide() {
+    LayerParameter layer_param;
+    PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+    pooling_param->set_kernel_h(2);
+    pooling_param->set_kernel_w(3);
+    pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
+    const int num = 2;
+    const int channels = 2;
+    blob_bottom_->Reshape(num, channels, 6, 6);
+    // Input: 2x 2 channels of:
+    // [35     1     6    26    19    24]
+    // [ 3    32     7    21    23    25]
+    // [31     9     2    22    27    20]
+    // [ 8    28    33    17    10    15]
+    // [30     5    34    12    14    16]
+    // [ 4    36    29    13    18    11]
+    // (this is generated by magic(6) in MATLAB)
+    for (int i = 0; i < 36 * num * channels; i += 36) {
+      blob_bottom_->mutable_cpu_data()[i +  0] = 35;
+      blob_bottom_->mutable_cpu_data()[i +  1] = 1;
+      blob_bottom_->mutable_cpu_data()[i +  2] = 6;
+      blob_bottom_->mutable_cpu_data()[i +  3] = 26;
+      blob_bottom_->mutable_cpu_data()[i +  4] = 19;
+      blob_bottom_->mutable_cpu_data()[i +  5] = 24;
+      blob_bottom_->mutable_cpu_data()[i +  6] = 3;
+      blob_bottom_->mutable_cpu_data()[i +  7] = 32;
+      blob_bottom_->mutable_cpu_data()[i +  8] = 7;
+      blob_bottom_->mutable_cpu_data()[i +  9] = 21;
+      blob_bottom_->mutable_cpu_data()[i + 10] = 23;
+      blob_bottom_->mutable_cpu_data()[i + 11] = 25;
+      blob_bottom_->mutable_cpu_data()[i + 12] = 31;
+      blob_bottom_->mutable_cpu_data()[i + 13] = 9;
+      blob_bottom_->mutable_cpu_data()[i + 14] = 2;
+      blob_bottom_->mutable_cpu_data()[i + 15] = 22;
+      blob_bottom_->mutable_cpu_data()[i + 16] = 27;
+      blob_bottom_->mutable_cpu_data()[i + 17] = 20;
+      blob_bottom_->mutable_cpu_data()[i + 18] = 8;
+      blob_bottom_->mutable_cpu_data()[i + 19] = 28;
+      blob_bottom_->mutable_cpu_data()[i + 20] = 33;
+      blob_bottom_->mutable_cpu_data()[i + 21] = 17;
+      blob_bottom_->mutable_cpu_data()[i + 22] = 10;
+      blob_bottom_->mutable_cpu_data()[i + 23] = 15;
+      blob_bottom_->mutable_cpu_data()[i + 24] = 30;
+      blob_bottom_->mutable_cpu_data()[i + 25] = 5;
+      blob_bottom_->mutable_cpu_data()[i + 26] = 34;
+      blob_bottom_->mutable_cpu_data()[i + 27] = 12;
+      blob_bottom_->mutable_cpu_data()[i + 28] = 14;
+      blob_bottom_->mutable_cpu_data()[i + 29] = 16;
+      blob_bottom_->mutable_cpu_data()[i + 30] = 4;
+      blob_bottom_->mutable_cpu_data()[i + 31] = 36;
+      blob_bottom_->mutable_cpu_data()[i + 32] = 29;
+      blob_bottom_->mutable_cpu_data()[i + 33] = 13;
+      blob_bottom_->mutable_cpu_data()[i + 34] = 18;
+      blob_bottom_->mutable_cpu_data()[i + 35] = 11;
+    }
+
+    SET_LAYER(Pooling);
+
+    EXPECT_EQ(blob_top_->num(), num);
+    EXPECT_EQ(blob_top_->channels(), channels);
+    EXPECT_EQ(blob_top_->height(), 5);
+    EXPECT_EQ(blob_top_->width(), 4);
+    if (blob_top_vec_.size() > 1) {
+      EXPECT_EQ(blob_top_mask_->num(), num);
+      EXPECT_EQ(blob_top_mask_->channels(), channels);
+      EXPECT_EQ(blob_top_mask_->height(), 5);
+      EXPECT_EQ(blob_top_mask_->width(), 4);
+    }
+    layer->Forward(blob_bottom_vec_, blob_top_vec_);
+    // Expected output: 2x 2 channels of:
+    // [35    32    26    26]
+    // [32    32    27    27]
+    // [33    33    33    27]
+    // [34    34    34    17]
+    // [36    36    34    18]
+    for (int i = 0; i < 20 * num * channels; i += 20) {
+      EXPECT_EQ(blob_top_->cpu_data()[i +  0], 35);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  1], 32);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  2], 26);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  3], 26);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  4], 32);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  5], 32);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  6], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  7], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  8], 33);
+      EXPECT_EQ(blob_top_->cpu_data()[i +  9], 33);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 10], 33);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 11], 27);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 12], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 13], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 14], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 15], 17);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 16], 36);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 17], 36);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 18], 34);
+      EXPECT_EQ(blob_top_->cpu_data()[i + 19], 18);
+    }
+    if (blob_top_vec_.size() > 1) {
+        // [ 1     8     4     4]
+        // [ 8     8    17    17]
+        // [21    21    21    17]
+        // [27    27    27    22]
+        // [32    32    27    35]
+      for (int i = 0; i < 20 * num * channels; i += 20) {
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  0],  0);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  1],  7);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  2],  3);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  3],  3);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  4],  7);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  5],  7);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  6], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  7], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  8], 20);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i +  9], 20);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 10], 20);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 11], 16);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 12], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 13], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 14], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 15], 21);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 16], 31);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 17], 31);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 18], 26);
+        EXPECT_EQ(blob_top_mask_->cpu_data()[i + 19], 34);
+      }
+    }
+  }
+};
+
+
+TYPED_TEST_CASE(PoolingLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(PoolingLayerTest, TestSetup) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_kernel_size(3);
+  pooling_param->set_stride(2);
+
+  SET_LAYER(Pooling);
+  // PoolingLayer<Dtype> layer(layer_param);
+  // layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
+  EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
+  EXPECT_EQ(this->blob_top_->height(), 3);
+  EXPECT_EQ(this->blob_top_->width(), 2);
+}
+
+TYPED_TEST(PoolingLayerTest, TestSetupPadded) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_kernel_size(3);
+  pooling_param->set_stride(2);
+  pooling_param->set_pad(1);
+  pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+//  PoolingLayer<Dtype> layer(layer_param);
+//  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  SET_LAYER(Pooling);
+
+  EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
+  EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
+  EXPECT_EQ(this->blob_top_->height(), 4);
+  EXPECT_EQ(this->blob_top_->width(), 3);
+}
+
+TYPED_TEST(PoolingLayerTest, TestSetupGlobalPooling) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_global_pooling(true);
+  pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+
+  SET_LAYER(Pooling);
+
+  EXPECT_EQ(this->blob_top_->num(), this->blob_bottom_->num());
+  EXPECT_EQ(this->blob_top_->channels(), this->blob_bottom_->channels());
+  EXPECT_EQ(this->blob_top_->height(), 1);
+  EXPECT_EQ(this->blob_top_->width(), 1);
+}
+
+TYPED_TEST(PoolingLayerTest, TestForwardMax) {
+  this->TestForwardSquare();
+  this->TestForwardRectHigh();
+  this->TestForwardRectWide();
+}
+
+TYPED_TEST(PoolingLayerTest, TestForwardMaxTopMask) {
+  this->blob_top_vec_.push_back(this->blob_top_mask_);
+  this->TestForwardSquare();
+  this->TestForwardRectHigh();
+  this->TestForwardRectWide();
+}
+
+TYPED_TEST(PoolingLayerTest, TestForwardMaxPadded) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_kernel_size(3);
+  pooling_param->set_stride(2);
+  pooling_param->set_pad(2);
+  pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
+  this->blob_bottom_->Reshape(1, 1, 3, 3);
+  // Input:
+  //     [ 1 2 4 ]
+  //     [ 2 3 2 ]
+  //     [ 4 2 1 ]
+  this->blob_bottom_->mutable_cpu_data()[0] = 1;
+  this->blob_bottom_->mutable_cpu_data()[1] = 2;
+  this->blob_bottom_->mutable_cpu_data()[2] = 4;
+  this->blob_bottom_->mutable_cpu_data()[3] = 2;
+  this->blob_bottom_->mutable_cpu_data()[4] = 3;
+  this->blob_bottom_->mutable_cpu_data()[5] = 2;
+  this->blob_bottom_->mutable_cpu_data()[6] = 4;
+  this->blob_bottom_->mutable_cpu_data()[7] = 2;
+  this->blob_bottom_->mutable_cpu_data()[8] = 1;
+
+  SET_LAYER(Pooling);
+
+  EXPECT_EQ(this->blob_top_->num(), 1);
+  EXPECT_EQ(this->blob_top_->channels(), 1);
+  EXPECT_EQ(this->blob_top_->height(), 3);
+  EXPECT_EQ(this->blob_top_->width(), 3);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  Dtype epsilon = 1e-8;
+  // Output:
+  //     [ 1 4 4 ]
+  //     [ 4 4 4 ]
+  //     [ 4 4 1 ]
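+  // With pad == 2 and kernel 3, the window for corner output (0, 0) covers
+  // input rows/cols -2..0, i.e. only the corner element itself, so the
+  // padded max there is the input value 1 (and symmetrically at (2, 2)).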
+  EXPECT_NEAR(this->blob_top_->cpu_data()[0], 1, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[2], 4, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[4], 4, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[6], 4, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[8], 1, epsilon);
+}
+
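+// Fill the bottom blob with a deterministic ramp: each element's value is its
+// linear NCHW offset, which makes pooled outputs easy to verify by hand.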
+template <typename Dtype>
+void fill_bottom_data(Blob<Dtype>* bottom) {
+  Dtype* ptr = bottom->mutable_cpu_data();
+
+  for (int i = 0; i < bottom->num(); i++)
+    for (int j = 0; j < bottom->channels(); j++)
+      for (int l = 0; l < bottom->height(); l++)
+        for (int k = 0; k < bottom->width(); k++) {
+          // Linear offset of element (i, j, l, k) in NCHW layout.
+          int offset =
+              i * bottom->channels() * bottom->height() * bottom->width() +
+              j * bottom->height() * bottom->width() +
+              l * bottom->width() + k;
+
+          ptr[offset] = offset;
+        }
+}
+
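+// Reference checker for MAX pooling with zero padding: recompute the window
+// max directly from the bottom blob and compare it against the layer output.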
+template <typename Dtype>
+void check_top_data(Blob<Dtype>* bottom, Blob<Dtype>* top,
+                    int kernel_size, int stride) {
+  Dtype epsilon = 1e-5;
+  int error_count = 0;
+
+  for (int i = 0; i < top->num(); i++)
+    for (int j = 0; j < top->channels(); j++)
+      for (int l = 0; l < top->height(); l++)
+        for (int k = 0; k < top->width(); k++) {
+          Dtype pool_data = top->data_at(i, j, l, k);
+          // Effectively -inf for the Gaussian-filled test data.
+          Dtype max_bottom_data = -100000000;
+
+          // Top-left corner of this output's pooling window in the bottom blob.
+          int h_start = stride * l;
+          int w_start = stride * k;
+
+          // Recompute the max over the kernel window.
+          for (int x = 0; x < kernel_size; x++)
+            for (int y = 0; y < kernel_size; y++) {
+              Dtype d = bottom->data_at(i, j, h_start + x, w_start + y);
+              if (d > max_bottom_data)
+                max_bottom_data = d;
+            }
+
+          // Register a gtest failure on mismatch (the original only printed).
+          EXPECT_NEAR(pool_data, max_bottom_data, epsilon);
+
+          Dtype offset = pool_data - max_bottom_data;
+          if (offset > epsilon || offset < -epsilon) {
+            std::cout << error_count << ": ";
+            std::cout << i << " " << j << " " << l << " " << k << " bottom: ";
+            std::cout << h_start << " " << w_start << std::endl;
+            std::cout << "pooled: " << pool_data
+                      << " reference max: " << max_bottom_data << std::endl;
+            error_count++;
+          }
+
+          // Stop after ten mismatches to keep the log readable.
+          if (error_count == 10)
+            return;
+        }
+}
+
+TYPED_TEST(PoolingLayerTest, TestMax) {
+  typedef typename TypeParam::Dtype Dtype;
+
+  int kernel_size = 3;
+  int stride = 2;
+  int channel_number = 96;
+  int h = 55;
+  int w = 55;
+
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_kernel_size(kernel_size);
+  pooling_param->set_stride(stride);
+  pooling_param->set_pad(0);
+  pooling_param->set_pool(PoolingParameter_PoolMethod_MAX);
+  this->blob_bottom_->Reshape(1, channel_number, h, w);
+  FillerParameter filler_param;
+  filler_param.set_value(Dtype(2));
+  GaussianFiller<Dtype> filler(filler_param);
+  filler.Fill(this->blob_bottom_);
+
+  SET_LAYER(Pooling);
+  //fill_bottom_data(this->blob_bottom_);
+
+#if 0
+  EXPECT_EQ(this->blob_top_->num(), 1);
+  EXPECT_EQ(this->blob_top_->channels(), 96);
+  EXPECT_EQ(this->blob_top_->height(), 27);
+  EXPECT_EQ(this->blob_top_->width(), 27);
+#endif
+
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  check_top_data(this->blob_bottom_, this->blob_top_, kernel_size, stride);
+}
+
+TYPED_TEST(PoolingLayerTest, TestForwardAve) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
+  pooling_param->set_kernel_size(3);
+  pooling_param->set_stride(1);
+  pooling_param->set_pad(1);
+  pooling_param->set_pool(PoolingParameter_PoolMethod_AVE);
+  this->blob_bottom_->Reshape(1, 1, 3, 3);
+  FillerParameter filler_param;
+  filler_param.set_value(Dtype(2));
+  ConstantFiller<Dtype> filler(filler_param);
+  filler.Fill(this->blob_bottom_);
+  PoolingLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->num(), 1);
+  EXPECT_EQ(this->blob_top_->channels(), 1);
+  EXPECT_EQ(this->blob_top_->height(), 3);
+  EXPECT_EQ(this->blob_top_->width(), 3);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  Dtype epsilon = 1e-5;
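+  // Caffe's AVE pooling divides by the full (padded) 3x3 window area, so a
+  // corner window, which overlaps only four real cells of value 2, averages
+  // 8.0 / 9 rather than 2.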
+  EXPECT_NEAR(this->blob_top_->cpu_data()[0], 8.0 / 9, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[1], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[2], 8.0 / 9, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[3], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[4], 2.0    , epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[5], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[6], 8.0 / 9, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[7], 4.0 / 3, epsilon);
+  EXPECT_NEAR(this->blob_top_->cpu_data()[8], 8.0 / 9, epsilon);
+}
+
+}  // namespace caffe
diff --git a/unit_tests/test_softmax_layer.cpp b/unit_tests/test_softmax_layer.cpp
new file mode 100644 (file)
index 0000000..ab2f836
--- /dev/null
@@ -0,0 +1,99 @@
+#include <cmath>
+#include <vector>
+
+#include "gtest/gtest.h"
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/layers/softmax_layer.hpp"
+
+#ifdef USE_CUDNN
+#include "caffe/layers/cudnn_softmax_layer.hpp"
+#endif
+
+#include "caffe/test/test_caffe_main.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+namespace caffe {
+
+template <typename TypeParam>
+class SoftmaxLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+ protected:
+  SoftmaxLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 10, 1, 1)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~SoftmaxLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
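+// Restrict the typed tests to the CPU<float> path. The layer is still
+// created through the registry below, so whichever implementation is
+// registered for "Softmax" (e.g. the ACL-backed one) is what gets tested.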
+typedef ::testing::Types<CPUDevice<float> > float_only;
+
+#define TestDtypesAndDevices float_only
+
+
+TYPED_TEST_CASE(SoftmaxLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(SoftmaxLayerTest, TestForward) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+
+  layer_param.set_type("Softmax");
+
+  shared_ptr<Layer<Dtype> > new_layer =
+      LayerRegistry<Dtype>::CreateLayer(layer_param);
+
+  shared_ptr<SoftmaxLayer<Dtype> > layer =
+      boost::static_pointer_cast<SoftmaxLayer<Dtype> >(new_layer);
+
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+
+  // Test sum
+  for (int i = 0; i < this->blob_bottom_->num(); ++i) {
+    for (int k = 0; k < this->blob_bottom_->height(); ++k) {
+      for (int l = 0; l < this->blob_bottom_->width(); ++l) {
+        Dtype sum = 0;
+        for (int j = 0; j < this->blob_top_->channels(); ++j) {
+          sum += this->blob_top_->data_at(i, j, k, l);
+        }
+        EXPECT_GE(sum, 0.999);
+        EXPECT_LE(sum, 1.001);
+        // Test exact values
+        Dtype scale = 0;
+        for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
+          scale += exp(this->blob_bottom_->data_at(i, j, k, l));
+        }
+        for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
+          EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4,
+              exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)
+              << "debug: " << i << " " << j;
+          EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4,
+              exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)
+              << "debug: " << i << " " << j;
+        }
+      }
+    }
+  }
+}
+
+}  // namespace caffe
diff --git a/unit_tests/testbed.c b/unit_tests/testbed.c
new file mode 100644 (file)
index 0000000..d33433d
--- /dev/null
@@ -0,0 +1,146 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <sys/time.h>
+#include <unistd.h>
+#include <getopt.h>
+#include <string.h>
+
+#include "pmu.h"
+#include "testbed.h"
+
+struct armv8_event
+{
+       char * name;
+       int id;
+       uint32_t init_val;
+       char * note;
+};
+
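+/*
+ * ARMv8 PMUv3 common architectural event numbers: 0x08 INST_RETIRED,
+ * 0x11 CPU_CYCLES, 0x03 L1D_CACHE_REFILL, 0x04 L1D_CACHE,
+ * 0x17 L2D_CACHE_REFILL, 0x16 L2D_CACHE. The list is named after the
+ * Cortex-A57, but these IDs are architecturally defined.
+ */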
+static struct armv8_event a57_list[6]=
+{
+       {"INST",0x8,0,"instruction retired"},
+       {"CYCL",0x11,0,"CPU running cycle"},
+       {"L1D MISS",0x3,0,"L1D CACHE MISS/REFILL"},
+       {"L1D ACCESS",0x4,0,"L1D CACHE ACCESS"},
+       {"L2 MISS",0x17,0,"L2 CACHE MISS/REFILL"},
+       {"L2 ACCESS",0x16,0,"L2 CACHE ACCESS"}
+};
+
+static int e[6];        /* event handles returned by create_pmu_event() */
+
+void init_testbed(void)
+{
+        struct armv8_event *p_list;
+        int i;
+
+        init_pmu_registers();
+
+        p_list = a57_list;
+
+        for (i = 0; i < 6; i++) {
+                e[i] = create_pmu_event(p_list[i].name, p_list[i].id,
+                                p_list[i].init_val, p_list[i].note);
+        }
+}
+
+void run_test(int repetition, int warm_up, void (*test_func)(void *), void *arg)
+{
+        uint32_t t0, t1;
+        uint32_t freq;
+        uint32_t cycle;
+        uint64_t total_time = 0;
+        uint32_t loop_count = 0;
+        int i;
+
+        /* one optional untimed run to warm caches and the branch predictor */
+        if (warm_up)
+                test_func(arg);
+
+        /* CNTFRQ_EL0 holds the generic timer frequency */
+        freq = read_32bit_sysreg(CNTFRQ_EL0);
+
+        for (i = 0; i < 6; i++)
+                start_pmu_event(e[i]);
+
+        for (i = 0; i < 6; i++)
+                set_pmu_event_base(e[i]);
+
+        t0 = read_32bit_sysreg(CNTVCT_EL0);
+
+        for (i = 0; i < repetition; i++) {
+                test_func(arg);
+
+                record_pmu_event(e[0], 0, 1, 1);
+                record_pmu_event(e[1], 0, 1, 1);
+                record_pmu_event(e[2], 0, 1, 1);
+                record_pmu_event(e[3], 0, 1, 1);
+                record_pmu_event(e[4], 0, 1, 1);
+                record_pmu_event(e[5], 0, 1, 1);
+
+                /* accumulate virtual counter (CNTVCT_EL0) ticks per iteration */
+                t1 = read_32bit_sysreg(CNTVCT_EL0);
+                loop_count++;
+                total_time += (t1 - t0);
+                t0 = t1;
+        }
+
+        for (i = 0; i < 6; i++)
+                stop_pmu_event(e[i]);
+
+        for (i = 0; i < 6; i++)
+                dump_pmu_event_stat(e[i]);
+
+        printf("\n------------------------------------\n\n");
+
+        cycle = get_pmu_stat_avg(e[1]);
+        t0 = total_time / loop_count;
+
+        printf("timer freq is 0x%x\n", freq);
+        printf("virtual counter ticks per run: 0x%x (0x%lx/%u)\n",
+               t0, total_time, loop_count);
+        printf("converted to ms: %.3f\n", 1000.0 * t0 / freq);
+        printf("CPU freq: %.2f MHz (cycle:0x%x)\n",
+               (float)freq * cycle / t0 / 1000000, cycle);
+        printf("IPC is: %.2f\n", (float)get_pmu_stat_avg(e[0]) / cycle);
+        printf("L1 cache miss ratio: %.2f\n",
+               (float)get_pmu_stat_avg(e[2]) / get_pmu_stat_avg(e[3]));
+        printf("L2 cache miss ratio: %.2f\n",
+               (float)get_pmu_stat_avg(e[4]) / get_pmu_stat_avg(e[5]));
+}
+
+void release_testbed(void)
+{
+        int i;
+
+        for (i = 0; i < 6; i++)
+                release_pmu_event(e[i]);
+}
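+
+/*
+ * Minimal usage sketch (hypothetical caller, not part of this patch;
+ * my_kernel stands for any workload function of matching signature):
+ *
+ *   static void my_kernel(void *arg)
+ *   {
+ *           ... workload under test ...
+ *   }
+ *
+ *   init_testbed();
+ *   run_test(10, 1, my_kernel, NULL);   -- 10 timed runs after one warm-up
+ *   release_testbed();
+ */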
diff --git a/unit_tests/testbed.h b/unit_tests/testbed.h
new file mode 100644 (file)
index 0000000..7a7f3df
--- /dev/null
@@ -0,0 +1,10 @@
+#ifndef __TESTBED_H__
+#define __TESTBED_H__
+
+void init_testbed(void);
+
+void run_test(int repetition, int warm_up, void (*test_func)(void *), void *arg);
+
+void release_testbed(void);
+
+#endif