softmax layer, test to be written
authorYangqing Jia <jiayq84@gmail.com>
Wed, 25 Sep 2013 22:38:04 +0000 (15:38 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Wed, 25 Sep 2013 22:38:04 +0000 (15:38 -0700)
src/caffe/blob.cpp
src/caffe/blob.hpp
src/caffe/layer_factory.hpp
src/caffe/layers/softmax_layer.cpp [new file with mode: 0644]
src/caffe/net.cpp [new file with mode: 0644]
src/caffe/net.cpp.work [deleted file]
src/caffe/net.hpp
src/caffe/test/test_softmax_layer.cpp [new file with mode: 0644]
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.hpp
src/caffe/vision_layers.hpp

index fdd7036..aacb05c 100644 (file)
@@ -37,9 +37,21 @@ Blob<Dtype>::Blob(const Blob<Dtype>& source) {
   } else {
     Reshape(source.num(), source.channels(), source.height(),
         source.width());
-    // create the synced memories.
-    data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
-    diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+    if (count_ > 0) {
+      // Copy the data.
+      memcpy(data_->mutable_cpu_data(), source.cpu_data(),
+          count_ * sizeof(Dtype));
+      memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
+          count_ * sizeof(Dtype));
+    }
+  }
+}
+
+template <typename Dtype>
+const Blob<Dtype>& Blob<Dtype>::operator=(const Blob<Dtype>& source) {
+  Reshape(source.num(), source.channels(), source.height(),
+        source.width());
+  if (count_ > 0) {
     // Copy the data.
     memcpy(data_->mutable_cpu_data(), source.cpu_data(),
         count_ * sizeof(Dtype));
@@ -110,9 +122,11 @@ void Blob<Dtype>::FromProto(const BlobProto& proto) {
   for (int i = 0; i < count_; ++i) {
     data_vec[i] = proto.data(i);
   }
-  Dtype* diff_vec = mutable_cpu_diff();
-  for (int i = 0; i < count_; ++i) {
-    diff_vec[i] = proto.diff(i);
+  if (proto.diff_size() > 0) {
+    Dtype* diff_vec = mutable_cpu_diff();
+    for (int i = 0; i < count_; ++i) {
+      diff_vec[i] = proto.diff(i);
+    }
   }
 }
 
index 7818a05..a14c046 100644 (file)
@@ -18,6 +18,7 @@ class Blob {
   explicit Blob(const int num, const int channels, const int height,
     const int width);
   Blob(const Blob<Dtype>& source);
+  const Blob<Dtype>& operator=(const Blob<Dtype>& src);
   virtual ~Blob() {}
   void Reshape(const int num, const int height,
       const int width, const int channels);
index 46b5516..90e6d66 100644 (file)
@@ -21,6 +21,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
   const std::string& type = param.type();
   if (type == "conv") {
     return new ConvolutionLayer<Dtype>(param);
+  } else if (type == "data") {
+    return new DataLayer<Dtype>(param);
   } else if (type == "dropout") {
     return new DropoutLayer<Dtype>(param);
   } else if (type == "im2col") {
diff --git a/src/caffe/layers/softmax_layer.cpp b/src/caffe/layers/softmax_layer.cpp
new file mode 100644 (file)
index 0000000..31f25c3
--- /dev/null
@@ -0,0 +1,85 @@
+// Copyright 2013 Yangqing Jia
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+#include <algorithm>
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 1) << "Softmax Layer takes a single blob as input.";
+  CHECK_EQ(top->size(), 1) << "Softmax Layer takes a single blob as output.";
+  (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  sum_multiplier_.Reshape(1, bottom[0]->channels(),
+      bottom[0]->height(), bottom[0]->width());
+  Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
+  for (int i = 0; i < bottom[0]->num(); ++i) {
+    multiplier_data[i] = 1.;
+  }
+  scale_.Reshape(bottom[0]->num(), 1, 1, 1);
+};
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    vector<Blob<Dtype>*>* top) {
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = (*top)[0]->mutable_cpu_data();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  int num = bottom[0]->num();
+  int dim = bottom[0]->count() / bottom[0]->num();
+  memcpy(top_data, bottom_data, sizeof(Dtype) * bottom[0]->count());
+  // we need to subtract the sum to avoid numerical issues, compute the exp,
+  // and then normalize.
+  // Compute sum
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., bottom_data,
+      sum_multiplier_.cpu_data(), 0., scale_data);
+  // subtraction
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+    scale_data, sum_multiplier_.cpu_data(), 1., top_data);
+  // Perform exponentiation
+  caffe_exp<Dtype>(num * dim, top_data, top_data);
+  // sum after exp
+  caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
+      sum_multiplier_.cpu_data(), 0., scale_data);
+  // Do division
+  for (int i = 0; i < num; ++i) {
+    caffe_scal<Dtype>(dim, Dtype(1.) / scale_data[i], top_data + i * dim);
+  }
+}
+
+template <typename Dtype>
+Dtype SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  const Dtype* top_diff = top[0]->cpu_diff();
+  const Dtype* top_data = top[0]->cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  Dtype* scale_data = scale_.mutable_cpu_data();
+  int num = top[0]->num();
+  int dim = top[0]->count() / top[0]->num();
+  memcpy(bottom_diff, top_diff, sizeof(Dtype) * top[0]->count());
+  // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
+  for (int i = 0; i < num; ++i) {
+    scale_data[i] = caffe_cpu_dot<Dtype>(dim, top_diff + i * dim,
+        top_data + i * dim);
+  }
+  // subtraction
+  caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+      scale_data, sum_multiplier_.cpu_data(), 1., bottom_diff);
+  // elementwise multiplication
+  caffe_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
+  return Dtype(0);
+}
+
+// TODO: implement the GPU version of softmax.
+
+INSTANTIATE_CLASS(SoftmaxLayer);
+
+
+}  // namespace caffe
diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp
new file mode 100644 (file)
index 0000000..93969b1
--- /dev/null
@@ -0,0 +1,160 @@
+// Copyright Yangqing Jia 2013
+
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/layer_factory.hpp"
+#include "caffe/net.hpp"
+
+using std::pair;
+using std::map;
+using std::set;
+
+namespace caffe {
+
+template <typename Dtype>
+Net<Dtype>::Net(const NetParameter& param,
+    const vector<Blob<Dtype>* >& bottom) {
+  // Basically, build all the layers and set up its connections.
+  name_ = param.name();
+  map<string, int> blob_name_to_idx;
+  set<string> available_blobs;
+  int num_layers = param.layers_size();
+  CHECK_EQ(bottom.size(), param.bottom_size())
+      << "Incorrect bottom blob size.";
+  // set the input blobs
+  for (int i = 0; i < param.bottom_size(); ++i) {
+    const string& blob_name = param.bottom(i);
+    blobs_.push_back(Blob<Dtype>(*bottom[i]));
+    blob_names_.push_back(blob_name);
+    net_input_blob_indices_.push_back(i);
+    blob_name_to_idx[blob_name] = i;
+    available_blobs.insert(blob_name);
+  }
+  // For each layer, set up their input and output
+  layers_.resize(param.layers_size());
+  bottom_vecs_.resize(param.layers_size());
+  top_vecs_.resize(param.layers_size());
+  for (int i = 0; i < param.top_size(); ++i) {
+    const LayerConnection& layer_connection = param.layers(i);
+    const LayerParameter& layer_param = layer_connection.layer();
+    layers_[i].reset(GetLayer<Dtype>(layer_param));
+    // Figure out this layer's input and output
+    for (int j = 0; j < layer_connection.bottom_size(); ++j) {
+      const string& blob_name = layer_connection.bottom(j);
+      if (available_blobs.find(blob_name) == available_blobs.end()) {
+        LOG(FATAL) << "Unknown blob input " << blob_name <<
+            " to layer" << j;
+      }
+      bottom_vecs_[i].push_back(
+          &blobs_[blob_name_to_idx[blob_name]]);
+      available_blobs.erase(blob_name);
+    }
+    for (int j = 0; j < layer_connection.top_size(); ++j) {
+      const string& blob_name = layer_connection.top(j);
+      if (blob_name_to_idx.find(blob_name) != blob_name_to_idx.end()) {
+        LOG(FATAL) << "Duplicate blobs produced by multiple sources.";
+      }
+      blobs_.push_back(Blob<Dtype>());
+      blob_names_.push_back(blob_name);
+      blob_name_to_idx[blob_name] = blob_names_.size() - 1;
+      available_blobs.insert(blob_name);
+      top_vecs_[i].push_back(&blobs_[blob_names_.size() - 1]);
+    }
+  }
+  // In the end, check if all remaining available blobs are top blobs.
+  for (int i = 0; i < param.top_size(); ++i) {
+    const string& blob_name = param.top(i);
+    if (blob_name_to_idx.find(blob_name) == blob_name_to_idx.end()) {
+      LOG(FATAL) << "Unknown blob input " << blob_name;
+    }
+    net_output_blob_indices_.push_back(blob_name_to_idx[blob_name]);
+    available_blobs.erase(blob_name);
+  }
+  if (!available_blobs.empty()) {
+    LOG(WARNING) << "There are some internal blobs not used:";
+    for (set<string>::iterator it = available_blobs.begin();
+        it != available_blobs.end(); ++it) {
+      LOG(WARNING) << "    " << *it;
+    }
+  }
+
+  LOG(INFO) << "Setting up the layers.";
+  for (int i = 0; i < layers_.size(); ++i) {
+    layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
+  }
+}
+
+template <typename Dtype>
+void Net<Dtype>::Forward(const vector<Blob<Dtype>*> & bottom,
+    vector<Blob<Dtype>*>* top) {
+  // Copy bottom to internal bottom
+  for (int i = 0; i < bottom.size(); ++i) {
+    blobs_[net_input_blob_indices_[i]] = *bottom[i];
+  }
+  for (int i = 0; i < layers_.size(); ++i) {
+    layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
+  }
+  // Copy internal top to top
+  for (int i = 0; i < (*top).size(); ++i) {
+    NOT_IMPLEMENTED;
+  }
+}
+
+template <typename Dtype>
+Dtype Net<Dtype>::Backward() {
+  Dtype loss = 0;
+  // TODO(Yangqing): figure out those layers that do not need backward.
+  for (int i = layers_.size() - 1; i >= 0; --i) {
+    loss += layers_[i]->Backward(top_vecs_[i], true, &bottom_vecs_[i]);
+  }
+  return loss;
+}
+
+template <typename Dtype>
+void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
+  int num_source_layers = param.layers_size();
+  for (int i = 0; i < num_source_layers; ++i) {
+    const LayerParameter& source_layer = param.layers(i).layer();
+    const string& source_layer_name = source_layer.name();
+    int target_layer_id = 0;
+    while (target_layer_id != layer_names_.size() &&
+        layer_names_[target_layer_id] != source_layer_name) {
+      ++target_layer_id;
+    }
+    if (target_layer_id == layer_names_.size()) {
+      LOG(INFO) << "Ignoring source layer " << source_layer_name;
+      continue;
+    }
+    LOG(INFO) << "Loading source layer " << source_layer_name;
+    vector<Blob<Dtype> >& target_blobs = layers_[target_layer_id]->params();
+    CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
+        << "Incompatible number of blobs for layer " << source_layer_name;
+    for (int j = 0; j < target_blobs.size(); ++j) {
+      target_blobs[j].FromProto(source_layer.blobs(j));
+    }
+  }
+}
+
+template <typename Dtype>
+void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
+  param->Clear();
+  param->set_name(name_);
+  // Add bottom and top
+  for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
+    param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
+  }
+  for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
+    param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
+  }
+  for (int i = 0; i < layers_.size(); ++i) {
+    LayerConnection* layer_connection = param->add_layers();
+  }
+}
+
+INSTANTIATE_CLASS(Net);
+
+}  // namespace caffe
diff --git a/src/caffe/net.cpp.work b/src/caffe/net.cpp.work
deleted file mode 100644 (file)
index 0ab0afb..0000000
+++ /dev/null
@@ -1,6 +0,0 @@
-// Copyright Yangqing Jia 2013
-
-#include <string>
-#include <vector>
-
-#include "caffe/proto/layer_param.proto"
index c8426ea..debb59e 100644 (file)
@@ -1,7 +1,7 @@
 // Copyright 2013 Yangqing Jia
 
-#ifndef CAFFE_LAYER_H_
-#define CAFFE_LAYER_H_
+#ifndef CAFFE_NET_HPP_
+#define CAFFE_NET_HPP_
 
 #include <map>
 #include <string>
@@ -20,12 +20,15 @@ namespace caffe {
 template <typename Dtype>
 class Net {
  public:
-  explicit Net(const NetParameter& param);
-  ~Net();
-  void Forward(const vector<Blob<Dtype*>> & bottom,
-      vector<Blob<Dtype*>* top);
-  Dtype Backward(const vector<Blob<Dtype*>> & bottom,
-      vector<Blob<Dtype*>* top);
+  Net(const NetParameter& param,
+      const vector<Blob<Dtype>* >& bottom);
+  ~Net() {}
+  void Forward(const vector<Blob<Dtype>* > & bottom,
+      vector<Blob<Dtype>*>* top);
+  // The network backward should take no input and output, since it solely
+  // computes the gradient w.r.t the parameters, and the data has already
+  // been provided during the forward pass.
+  Dtype Backward();
 
   // For an already initialized net, CopyTrainedLayersFrom() copies the already
   // trained layers from another net parameter instance.
@@ -33,21 +36,30 @@ class Net {
   // Writes the net to a proto.
   void ToProto(NetParameter* param, bool write_diff = false);
 
+  // returns the network name.
+  const string& name() { return name_; }
+
  protected:
   // Individual layers in the net
   vector<shared_ptr<Layer<Dtype> > > layers_;
-  vector<shared_ptr<Layer<Dtype> > > layer_names_;
-  // bottom_vecs stores the vectors containing the input for each layer
-  vector<vector<Blob<Dtype>*> > bottom_vecs_;
-  // top_vecs stores the vectors containing the output for each layer
-  vector<vector<Blob<Dtype>* > top_vecs_;
+  vector<string> layer_names_;
   // blobs stores the blobs that store intermediate results between the
   // layers.
-  vector<shared_ptr<Blob<Dtype> > blobs_;
-  vector<shared_ptr<Blob<Dtype> > blob_names_;
+  vector<Blob<Dtype> > blobs_;
+  vector<string> blob_names_;
+  // bottom_vecs stores the vectors containing the input for each layer, except
+  // for the first layer whose bottom vec is provided by the network's input.
+  vector<vector<Blob<Dtype>*> > bottom_vecs_;
+  // top_vecs stores the vectors containing the output for each layer, except
+  // for the last layer (likewise)
+  vector<vector<Blob<Dtype>*> > top_vecs_;
+  // blob indices for the input and the output of the net.
+  vector<int> net_input_blob_indices_;
+  vector<int> net_output_blob_indices_;
+  string name_;
 };
 
 
 }  // namespace caffe
 
-#endif  // CAFFE_LAYER_H_
+#endif  // CAFFE_NET_HPP_
diff --git a/src/caffe/test/test_softmax_layer.cpp b/src/caffe/test/test_softmax_layer.cpp
new file mode 100644 (file)
index 0000000..37391ea
--- /dev/null
@@ -0,0 +1,53 @@
+// Copyright 2013 Yangqing Jia
+
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class SoftmaxLayerTest : public ::testing::Test {
+ protected:
+  SoftmaxLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 10, 1, 1)),
+        blob_top_(new Blob<Dtype>()) {
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  };
+  virtual ~SoftmaxLayerTest() { delete blob_bottom_; delete blob_top_; }
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(SoftmaxLayerTest, Dtypes);
+
+TYPED_TEST(SoftmaxLayerTest, TestReLUCPU) {
+  LayerParameter layer_param;
+  Caffe::set_mode(Caffe::CPU);
+  SoftmaxLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  NOT_IMPLEMENTED;
+}
+
+
+
+}
index ecb0fe0..1949a70 100644 (file)
@@ -186,5 +186,24 @@ void caffe_vRngGaussian<double>(const int n, double* r, const double a,
       Caffe::vsl_stream(), n, r, a, sigma));
 }
 
+template <>
+void caffe_exp<float>(const int n, const float* a, float* y) {
+  vsExp(n, a, y);
+}
+
+template <>
+void caffe_exp<double>(const int n, const double* a, double* y) {
+  vdExp(n, a, y);
+}
+
+template <>
+float caffe_cpu_dot<float>(const int n, const float* x, const float* y) {
+  return cblas_sdot(n, x, 1, y, 1);
+}
+
+template <>
+double caffe_cpu_dot<double>(const int n, const double* x, const double* y) {
+  return cblas_ddot(n, x, 1, y, 1);
+}
 
 }  // namespace caffe
index e06a3a7..822ef31 100644 (file)
@@ -64,6 +64,12 @@ template <typename Dtype>
 void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
     const Dtype sigma);
 
+template <typename Dtype>
+void caffe_exp(const int n, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
+Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y);
+
 }  // namespace caffe
 
 
index 336ed0f..2c7af47 100644 (file)
@@ -255,6 +255,32 @@ class DataLayer : public Layer<Dtype> {
   int datum_size_;
 };
 
+
+template <typename Dtype>
+class SoftmaxLayer : public Layer<Dtype> {
+ public:
+  explicit SoftmaxLayer(const LayerParameter& param)
+      : Layer<Dtype>(param) {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+  //    vector<Blob<Dtype>*>* top);
+  virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+  //    const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  // sum_multiplier is just used to carry out sum using blas
+  Blob<Dtype> sum_multiplier_;
+  // scale is an intermediate blob to hold temporary results.
+  Blob<Dtype> scale_;
+};
+
+
 }  // namespace caffe
 
 #endif  // CAFFE_VISION_LAYERS_HPP_