misc update...
authorYangqing Jia <jiayq84@gmail.com>
Tue, 1 Oct 2013 21:00:18 +0000 (14:00 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Tue, 1 Oct 2013 21:00:18 +0000 (14:00 -0700)
src/Makefile
src/caffe/layer.hpp
src/caffe/layers/loss_layer.cu
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.cu [new file with mode: 0644]
src/caffe/util/math_functions.hpp
src/examples/convolution.cpp [deleted file]
src/programs/demo_mnist.cpp [new file with mode: 0644]

index f86db8f..05b7bc0 100644 (file)
@@ -12,17 +12,20 @@ CXX_SRCS := $(shell find caffe ! -name "test_*.cpp" -name "*.cpp")
 CU_SRCS := $(shell find caffe -name "*.cu")
 TEST_SRCS := $(shell find caffe -name "test_*.cpp")
 GTEST_SRC := gtest/gtest-all.cpp
+PROGRAM_SRCS := $(shell find programs -name "*.cpp")
 PROTO_SRCS := $(wildcard caffe/proto/*.proto)
 PROTO_GEN_HEADER := ${PROTO_SRCS:.proto=.pb.h}
 PROTO_GEN_CC := ${PROTO_SRCS:.proto=.pb.cc}
 PROTO_GEN_PY := ${PROTO_SRCS:.proto=_pb2.py}
 CXX_OBJS := ${CXX_SRCS:.cpp=.o}
 CU_OBJS := ${CU_SRCS:.cu=.cuo}
+PROGRAM_OBJS := ${PROGRAM_SRCS:.cpp=.o}
 PROTO_OBJS := ${PROTO_SRCS:.proto=.pb.o}
 OBJS := $(PROTO_OBJS) $(CXX_OBJS) $(CU_OBJS)
 TEST_OBJS := ${TEST_SRCS:.cpp=.o}
 GTEST_OBJ := ${GTEST_SRC:.cpp=.o}
 TEST_BINS := ${TEST_OBJS:.o=.testbin}
+PROGRAM_BINS :=${PROGRAM_OBJS:.o=.bin}
 
 CUDA_DIR := /usr/local/cuda
 CUDA_ARCH := -arch=sm_20
@@ -46,7 +49,7 @@ LDFLAGS += $(foreach library,$(LIBRARIES),-l$(library))
 LINK = $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(LDFLAGS) $(WARNINGS)
 NVCC = nvcc ${CXXFLAGS:-fPIC=-Xcompiler -fPIC} $(CPPFLAGS) $(CUDA_ARCH)
 
-.PHONY: all test clean distclean linecount
+.PHONY: all test clean distclean linecount program
 
 all: $(NAME)
 
@@ -55,12 +58,17 @@ linecount: clean
 
 test: $(OBJS) $(GTEST_OBJ) $(TEST_BINS)
 
+program: $(OBJS) $(PROGRAM_BINS)
+
 runtest: test
        for testbin in $(TEST_BINS); do $$testbin; done
 
 $(TEST_BINS): %.testbin : %.o
        $(CXX) $< $(OBJS) $(GTEST_OBJ) -o $@ $(LDFLAGS) $(WARNINGS)
 
+$(PROGRAM_BINS): %.bin : %.o
+       $(CXX) $< $(OBJS) -o $@ $(LDFLAGS) $(WARNINGS)
+
 $(NAME): $(PROTO_GEN_CC) $(OBJS)
        $(LINK) -shared $(OBJS) -o $(NAME)
 
index f575935..9898cbc 100644 (file)
@@ -54,7 +54,7 @@ class Layer {
   // If no gpu code is provided, we will simply use cpu code.
   virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
       vector<Blob<Dtype>*>* top) {
-    LOG(WARNING) << "Using CPU code as backup.";
+    // LOG(WARNING) << "Using CPU code as backup.";
     Forward_cpu(bottom, top);
   };
 
@@ -67,7 +67,7 @@ class Layer {
   virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
       const bool propagate_down,
       vector<Blob<Dtype>*>* bottom) {
-    LOG(WARNING) << "Using CPU code as backup.";
+    // LOG(WARNING) << "Using CPU code as backup.";
     return Backward_cpu(top, propagate_down, bottom);
   };
 
index 9aedc3d..1ea0626 100644 (file)
@@ -111,7 +111,7 @@ void AccuracyLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
     }
   }
   accuracy /= num;
-  LOG(INFO) << "Accuracy: " << accuracy;
+  // LOG(INFO) << "Accuracy: " << accuracy;
   (*top)[0]->mutable_cpu_data()[0] = accuracy;
 }
 
index a074545..60656b8 100644 (file)
@@ -281,4 +281,16 @@ double caffe_cpu_dot<double>(const int n, const double* x, const double* y) {
   return cblas_ddot(n, x, 1, y, 1);
 }
 
+template <>
+void caffe_gpu_dot<float>(const int n, const float* x, const float* y,
+    float* out) {
+  CUBLAS_CHECK(cublasSdot(Caffe::cublas_handle(), n, x, 1, y, 1, out));
+}
+
+template <>
+void caffe_gpu_dot<double>(const int n, const double* x, const double* y,
+    double * out) {
+  CUBLAS_CHECK(cublasDdot(Caffe::cublas_handle(), n, x, 1, y, 1, out));
+}
+
 }  // namespace caffe
diff --git a/src/caffe/util/math_functions.cu b/src/caffe/util/math_functions.cu
new file mode 100644 (file)
index 0000000..e930581
--- /dev/null
@@ -0,0 +1,36 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+
+#include "caffe/common.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+__global__ void mul_kernel(const int n, const Dtype* a,
+    const Dtype* b, Dtype* y) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    y[index] = a[index] * b[index];
+  }
+}
+
+template <>
+void caffe_gpu_mul<float>(const int N, const float* a,
+    const float* b, float* y) {
+  mul_kernel<float><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+      N, a, b, y);
+}
+
+template <>
+void caffe_gpu_mul<double>(const int N, const double* a,
+    const double* b, double* y) {
+  mul_kernel<double><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+      N, a, b, y);
+}
+
+
+}  // namespace caffe
index a71f28e..e9e2db8 100644 (file)
@@ -76,6 +76,9 @@ template <typename Dtype>
 void caffe_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
 
 template <typename Dtype>
+void caffe_gpu_mul(const int N, const Dtype* a, const Dtype* b, Dtype* y);
+
+template <typename Dtype>
 void caffe_div(const int N, const Dtype* a, const Dtype* b, Dtype* y);
 
 template <typename Dtype>
@@ -94,6 +97,9 @@ void caffe_exp(const int n, const Dtype* a, Dtype* y);
 template <typename Dtype>
 Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y);
 
+template <typename Dtype>
+void caffe_gpu_dot(const int n, const Dtype* x, const Dtype* y, Dtype* out);
+
 }  // namespace caffe
 
 
diff --git a/src/examples/convolution.cpp b/src/examples/convolution.cpp
deleted file mode 100644 (file)
index e69de29..0000000
diff --git a/src/programs/demo_mnist.cpp b/src/programs/demo_mnist.cpp
new file mode 100644 (file)
index 0000000..7c0937b
--- /dev/null
@@ -0,0 +1,93 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cuda_runtime.h>
+#include <fcntl.h>
+#include <google/protobuf/text_format.h>
+
+#include <cstring>
+
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/net.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/util/io.hpp"
+#include "caffe/optimization/solver.hpp"
+
+using namespace caffe;
+
+int main(int argc, char** argv) {
+  cudaSetDevice(1);
+  Caffe::set_mode(Caffe::GPU);
+
+  NetParameter net_param;
+  ReadProtoFromTextFile("caffe/test/data/lenet.prototxt",
+      &net_param);
+  vector<Blob<float>*> bottom_vec;
+  Net<float> caffe_net(net_param, bottom_vec);
+
+  // Run the network without training.
+  LOG(ERROR) << "Performing Forward";
+  caffe_net.Forward(bottom_vec);
+  LOG(ERROR) << "Performing Backward";
+  LOG(ERROR) << "Initial loss: " << caffe_net.Backward();
+
+  SolverParameter solver_param;
+  solver_param.set_base_lr(0.01);
+  solver_param.set_display(0);
+  solver_param.set_max_iter(6000);
+  solver_param.set_lr_policy("inv");
+  solver_param.set_gamma(0.0001);
+  solver_param.set_power(0.75);
+  solver_param.set_momentum(0.9);
+
+  LOG(ERROR) << "Starting Optimization";
+  SGDSolver<float> solver(solver_param);
+  solver.Solve(&caffe_net);
+  LOG(ERROR) << "Optimization Done.";
+
+  // Run the network after training.
+  LOG(ERROR) << "Performing Forward";
+  caffe_net.Forward(bottom_vec);
+  LOG(ERROR) << "Performing Backward";
+  float loss = caffe_net.Backward();
+  LOG(ERROR) << "Final loss: " << loss;
+
+  NetParameter trained_net_param;
+  caffe_net.ToProto(&trained_net_param);
+
+  NetParameter traintest_net_param;
+  ReadProtoFromTextFile("caffe/test/data/lenet_traintest.prototxt",
+      &traintest_net_param);
+  Net<float> caffe_traintest_net(traintest_net_param, bottom_vec);
+  caffe_traintest_net.CopyTrainedLayersFrom(trained_net_param);
+
+  // Test run
+  double train_accuracy = 0;
+  int batch_size = traintest_net_param.layers(0).layer().batchsize();
+  for (int i = 0; i < 60000 / batch_size; ++i) {
+    const vector<Blob<float>*>& result =
+        caffe_traintest_net.Forward(bottom_vec);
+    train_accuracy += result[0]->cpu_data()[0];
+  }
+  train_accuracy /= 60000 / batch_size;
+  LOG(ERROR) << "Train accuracy:" << train_accuracy;
+
+  NetParameter test_net_param;
+  ReadProtoFromTextFile("caffe/test/data/lenet_test.prototxt", &test_net_param);
+  Net<float> caffe_test_net(test_net_param, bottom_vec);
+  caffe_test_net.CopyTrainedLayersFrom(trained_net_param);
+
+  // Test run
+  double test_accuracy = 0;
+  batch_size = test_net_param.layers(0).layer().batchsize();
+  for (int i = 0; i < 10000 / batch_size; ++i) {
+    const vector<Blob<float>*>& result =
+        caffe_test_net.Forward(bottom_vec);
+    test_accuracy += result[0]->cpu_data()[0];
+  }
+  test_accuracy /= 10000 / batch_size;
+  LOG(ERROR) << "Test accuracy:" << test_accuracy;
+
+  return 0;
+}