} else {
Reshape(source.num(), source.channels(), source.height(),
source.width());
- // create the synced memories.
- data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
- diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
+ if (count_ > 0) {
+ // Copy the data.
+ memcpy(data_->mutable_cpu_data(), source.cpu_data(),
+ count_ * sizeof(Dtype));
+ memcpy(diff_->mutable_cpu_data(), source.cpu_diff(),
+ count_ * sizeof(Dtype));
+ }
+ }
+}
+
+template <typename Dtype>
+const Blob<Dtype>& Blob<Dtype>::operator=(const Blob<Dtype>& source) {
+ Reshape(source.num(), source.channels(), source.height(),
+ source.width());
+ if (count_ > 0) {
// Copy the data.
memcpy(data_->mutable_cpu_data(), source.cpu_data(),
count_ * sizeof(Dtype));
for (int i = 0; i < count_; ++i) {
data_vec[i] = proto.data(i);
}
- Dtype* diff_vec = mutable_cpu_diff();
- for (int i = 0; i < count_; ++i) {
- diff_vec[i] = proto.diff(i);
+ if (proto.diff_size() > 0) {
+ Dtype* diff_vec = mutable_cpu_diff();
+ for (int i = 0; i < count_; ++i) {
+ diff_vec[i] = proto.diff(i);
+ }
}
}
explicit Blob(const int num, const int channels, const int height,
const int width);
Blob(const Blob<Dtype>& source);
+ const Blob<Dtype>& operator=(const Blob<Dtype>& src);
virtual ~Blob() {}
void Reshape(const int num, const int channels,
const int height, const int width);
const std::string& type = param.type();
if (type == "conv") {
return new ConvolutionLayer<Dtype>(param);
+ } else if (type == "data") {
+ return new DataLayer<Dtype>(param);
} else if (type == "dropout") {
return new DropoutLayer<Dtype>(param);
} else if (type == "im2col") {
--- /dev/null
+// Copyright 2013 Yangqing Jia
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+#include <algorithm>
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ CHECK_EQ(bottom.size(), 1) << "Softmax Layer takes a single blob as input.";
+ CHECK_EQ(top->size(), 1) << "Softmax Layer takes a single blob as output.";
+ (*top)[0]->Reshape(bottom[0]->num(), bottom[0]->channels(),
+ bottom[0]->height(), bottom[0]->width());
+ sum_multiplier_.Reshape(1, bottom[0]->channels(),
+ bottom[0]->height(), bottom[0]->width());
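+ // sum_multiplier_ is filled with ones so that the gemv/gemm calls against it
+ // compute per-sample sums and broadcast per-sample scalars.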
+ Dtype* multiplier_data = sum_multiplier_.mutable_cpu_data();
+ for (int i = 0; i < sum_multiplier_.count(); ++i) {
+ multiplier_data[i] = 1.;
+ }
+ scale_.Reshape(bottom[0]->num(), 1, 1, 1);
+};
+
+template <typename Dtype>
+void SoftmaxLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top) {
+ const Dtype* bottom_data = bottom[0]->cpu_data();
+ Dtype* top_data = (*top)[0]->mutable_cpu_data();
+ Dtype* scale_data = scale_.mutable_cpu_data();
+ int num = bottom[0]->num();
+ int dim = bottom[0]->count() / bottom[0]->num();
+ memcpy(top_data, bottom_data, sizeof(Dtype) * bottom[0]->count());
+ // We need to subtract the per-sample max to avoid numerical issues, compute
+ // the exp, and then normalize.
+ // Compute max
+ for (int i = 0; i < num; ++i) {
+ scale_data[i] = bottom_data[i * dim];
+ for (int j = 0; j < dim; ++j) {
+ scale_data[i] = max(scale_data[i], bottom_data[i * dim + j]);
+ }
+ }
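+ // Subtracting a per-sample constant is safe because softmax is
+ // shift-invariant: exp(x_i - c) / sum_j exp(x_j - c) == exp(x_i) / sum_j exp(x_j).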
+ // subtraction
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+ scale_data, sum_multiplier_.cpu_data(), 1., top_data);
+ // Perform exponentiation
+ caffe_exp<Dtype>(num * dim, top_data, top_data);
+ // sum after exp
+ caffe_cpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
+ sum_multiplier_.cpu_data(), 0., scale_data);
+ // Do division
+ for (int i = 0; i < num; ++i) {
+ caffe_scal<Dtype>(dim, Dtype(1.) / scale_data[i], top_data + i * dim);
+ }
+}
+
+template <typename Dtype>
+Dtype SoftmaxLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down,
+ vector<Blob<Dtype>*>* bottom) {
+ const Dtype* top_diff = top[0]->cpu_diff();
+ const Dtype* top_data = top[0]->cpu_data();
+ Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+ Dtype* scale_data = scale_.mutable_cpu_data();
+ int num = top[0]->num();
+ int dim = top[0]->count() / top[0]->num();
+ memcpy(bottom_diff, top_diff, sizeof(Dtype) * top[0]->count());
+ // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
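+ // For y = softmax(x), dL/dx_i = y_i * (dL/dy_i - sum_j dL/dy_j * y_j); the
+ // dot products, subtraction, and elementwise multiplication below compute this.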
+ for (int i = 0; i < num; ++i) {
+ scale_data[i] = caffe_cpu_dot<Dtype>(dim, top_diff + i * dim,
+ top_data + i * dim);
+ }
+ // subtraction
+ caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
+ scale_data, sum_multiplier_.cpu_data(), 1., bottom_diff);
+ // elementwise multiplication
+ caffe_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
+ return Dtype(0);
+}
+
+// TODO: implement the GPU version of softmax.
+
+INSTANTIATE_CLASS(SoftmaxLayer);
+
+
+} // namespace caffe
--- /dev/null
+// Copyright Yangqing Jia 2013
+
+#include <map>
+#include <set>
+#include <string>
+#include <vector>
+
+#include "caffe/proto/caffe.pb.h"
+#include "caffe/layer_factory.hpp"
+#include "caffe/net.hpp"
+
+using std::pair;
+using std::map;
+using std::set;
+
+namespace caffe {
+
+template <typename Dtype>
+Net<Dtype>::Net(const NetParameter& param,
+ const vector<Blob<Dtype>* >& bottom) {
+ // Basically, build all the layers and set up their connections.
+ name_ = param.name();
+ map<string, int> blob_name_to_idx;
+ set<string> available_blobs;
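+ // blob_name_to_idx maps a blob name to its index in blobs_; available_blobs
+ // tracks blobs that have been produced but not yet consumed by a layer.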
+ int num_layers = param.layers_size();
+ CHECK_EQ(bottom.size(), param.bottom_size())
+ << "Incorrect bottom blob size.";
+ // set the input blobs
+ for (int i = 0; i < param.bottom_size(); ++i) {
+ const string& blob_name = param.bottom(i);
+ blobs_.push_back(Blob<Dtype>(*bottom[i]));
+ blob_names_.push_back(blob_name);
+ net_input_blob_indices_.push_back(i);
+ blob_name_to_idx[blob_name] = i;
+ available_blobs.insert(blob_name);
+ }
+ // For each layer, set up their input and output
+ layers_.resize(param.layers_size());
+ bottom_vecs_.resize(param.layers_size());
+ top_vecs_.resize(param.layers_size());
+ for (int i = 0; i < param.layers_size(); ++i) {
+ const LayerConnection& layer_connection = param.layers(i);
+ const LayerParameter& layer_param = layer_connection.layer();
+ layers_[i].reset(GetLayer<Dtype>(layer_param));
+ // Figure out this layer's input and output
+ for (int j = 0; j < layer_connection.bottom_size(); ++j) {
+ const string& blob_name = layer_connection.bottom(j);
+ if (available_blobs.find(blob_name) == available_blobs.end()) {
+ LOG(FATAL) << "Unknown blob input " << blob_name <<
+ " to layer" << j;
+ }
+ bottom_vecs_[i].push_back(
+ &blobs_[blob_name_to_idx[blob_name]]);
+ available_blobs.erase(blob_name);
+ }
+ for (int j = 0; j < layer_connection.top_size(); ++j) {
+ const string& blob_name = layer_connection.top(j);
+ if (blob_name_to_idx.find(blob_name) != blob_name_to_idx.end()) {
+ LOG(FATAL) << "Duplicate blobs produced by multiple sources.";
+ }
+ blobs_.push_back(Blob<Dtype>());
+ blob_names_.push_back(blob_name);
+ blob_name_to_idx[blob_name] = blob_names_.size() - 1;
+ available_blobs.insert(blob_name);
+ top_vecs_[i].push_back(&blobs_[blob_names_.size() - 1]);
+ }
+ }
+ // In the end, check if all remaining available blobs are top blobs.
+ for (int i = 0; i < param.top_size(); ++i) {
+ const string& blob_name = param.top(i);
+ if (blob_name_to_idx.find(blob_name) == blob_name_to_idx.end()) {
+ LOG(FATAL) << "Unknown blob input " << blob_name;
+ }
+ net_output_blob_indices_.push_back(blob_name_to_idx[blob_name]);
+ available_blobs.erase(blob_name);
+ }
+ if (!available_blobs.empty()) {
+ LOG(WARNING) << "There are some internal blobs not used:";
+ for (set<string>::iterator it = available_blobs.begin();
+ it != available_blobs.end(); ++it) {
+ LOG(WARNING) << " " << *it;
+ }
+ }
+
+ LOG(INFO) << "Setting up the layers.";
+ for (int i = 0; i < layers_.size(); ++i) {
+ layers_[i]->SetUp(bottom_vecs_[i], &top_vecs_[i]);
+ }
+}
+
+template <typename Dtype>
+void Net<Dtype>::Forward(const vector<Blob<Dtype>*> & bottom,
+ vector<Blob<Dtype>*>* top) {
+ // Copy bottom to internal bottom
+ for (int i = 0; i < bottom.size(); ++i) {
+ blobs_[net_input_blob_indices_[i]] = *bottom[i];
+ }
+ for (int i = 0; i < layers_.size(); ++i) {
+ layers_[i]->Forward(bottom_vecs_[i], &top_vecs_[i]);
+ }
+ // Copy internal top to top
+ for (int i = 0; i < (*top).size(); ++i) {
+ NOT_IMPLEMENTED;
+ }
+}
+
+template <typename Dtype>
+Dtype Net<Dtype>::Backward() {
+ Dtype loss = 0;
+ // TODO(Yangqing): figure out those layers that do not need backward.
+ for (int i = layers_.size() - 1; i >= 0; --i) {
+ loss += layers_[i]->Backward(top_vecs_[i], true, &bottom_vecs_[i]);
+ }
+ return loss;
+}
+
+template <typename Dtype>
+void Net<Dtype>::CopyTrainedLayersFrom(const NetParameter& param) {
+ int num_source_layers = param.layers_size();
+ for (int i = 0; i < num_source_layers; ++i) {
+ const LayerParameter& source_layer = param.layers(i).layer();
+ const string& source_layer_name = source_layer.name();
+ int target_layer_id = 0;
+ while (target_layer_id != layer_names_.size() &&
+ layer_names_[target_layer_id] != source_layer_name) {
+ ++target_layer_id;
+ }
+ if (target_layer_id == layer_names_.size()) {
+ LOG(INFO) << "Ignoring source layer " << source_layer_name;
+ continue;
+ }
+ LOG(INFO) << "Loading source layer " << source_layer_name;
+ vector<Blob<Dtype> >& target_blobs = layers_[target_layer_id]->params();
+ CHECK_EQ(target_blobs.size(), source_layer.blobs_size())
+ << "Incompatible number of blobs for layer " << source_layer_name;
+ for (int j = 0; j < target_blobs.size(); ++j) {
+ target_blobs[j].FromProto(source_layer.blobs(j));
+ }
+ }
+}
+
+template <typename Dtype>
+void Net<Dtype>::ToProto(NetParameter* param, bool write_diff) {
+ param->Clear();
+ param->set_name(name_);
+ // Add bottom and top
+ for (int i = 0; i < net_input_blob_indices_.size(); ++i) {
+ param->add_bottom(blob_names_[net_input_blob_indices_[i]]);
+ }
+ for (int i = 0; i < net_output_blob_indices_.size(); ++i) {
+ param->add_top(blob_names_[net_output_blob_indices_[i]]);
+ }
+ for (int i = 0; i < layers_.size(); ++i) {
+ LayerConnection* layer_connection = param->add_layers();
+ // TODO: fill layer_connection with the layer's bottom/top blob names and
+ // its parameters (honoring write_diff).
+ }
+}
+
+INSTANTIATE_CLASS(Net);
+
+} // namespace caffe
+++ /dev/null
-// Copyright Yangqing Jia 2013
-
-#include <string>
-#include <vector>
-
-#include "caffe/proto/layer_param.proto"
// Copyright 2013 Yangqing Jia
-#ifndef CAFFE_LAYER_H_
-#define CAFFE_LAYER_H_
+#ifndef CAFFE_NET_HPP_
+#define CAFFE_NET_HPP_
#include <map>
#include <string>
template <typename Dtype>
class Net {
public:
- explicit Net(const NetParameter& param);
- ~Net();
- void Forward(const vector<Blob<Dtype*>> & bottom,
- vector<Blob<Dtype*>* top);
- Dtype Backward(const vector<Blob<Dtype*>> & bottom,
- vector<Blob<Dtype*>* top);
+ Net(const NetParameter& param,
+ const vector<Blob<Dtype>* >& bottom);
+ ~Net() {}
+ void Forward(const vector<Blob<Dtype>* > & bottom,
+ vector<Blob<Dtype>*>* top);
+ // The network Backward() takes no input or output blobs: it solely computes
+ // the gradients w.r.t. the parameters, and the data has already been
+ // provided during the forward pass.
+ Dtype Backward();
// For an already initialized net, CopyTrainedLayersFrom() copies the already
// trained layers from another net parameter instance.
// Writes the net to a proto.
void ToProto(NetParameter* param, bool write_diff = false);
+ // returns the network name.
+ const string& name() { return name_; }
+
protected:
// Individual layers in the net
vector<shared_ptr<Layer<Dtype> > > layers_;
- vector<shared_ptr<Layer<Dtype> > > layer_names_;
- // bottom_vecs stores the vectors containing the input for each layer
- vector<vector<Blob<Dtype>*> > bottom_vecs_;
- // top_vecs stores the vectors containing the output for each layer
- vector<vector<Blob<Dtype>* > top_vecs_;
+ vector<string> layer_names_;
// blobs stores the blobs that store intermediate results between the
// layers.
- vector<shared_ptr<Blob<Dtype> > blobs_;
- vector<shared_ptr<Blob<Dtype> > blob_names_;
+ vector<Blob<Dtype> > blobs_;
+ vector<string> blob_names_;
+ // bottom_vecs stores the vectors containing the input for each layer, except
+ // for the first layer whose bottom vec is provided by the network's input.
+ vector<vector<Blob<Dtype>*> > bottom_vecs_;
+ // top_vecs stores the vectors containing the output for each layer, except
+ // for the last layer (likewise)
+ vector<vector<Blob<Dtype>*> > top_vecs_;
+ // blob indices for the input and the output of the net.
+ vector<int> net_input_blob_indices_;
+ vector<int> net_output_blob_indices_;
+ string name_;
};
} // namespace caffe
-#endif // CAFFE_LAYER_H_
+#endif // CAFFE_NET_HPP_
--- /dev/null
+// Copyright 2013 Yangqing Jia
+
+#include <cstring>
+#include <cuda_runtime.h>
+
+#include "gtest/gtest.h"
+#include "caffe/blob.hpp"
+#include "caffe/common.hpp"
+#include "caffe/filler.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/test/test_gradient_check_util.hpp"
+
+#include "caffe/test/test_caffe_main.hpp"
+
+namespace caffe {
+
+extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
+
+template <typename Dtype>
+class SoftmaxLayerTest : public ::testing::Test {
+ protected:
+ SoftmaxLayerTest()
+ : blob_bottom_(new Blob<Dtype>(2, 10, 1, 1)),
+ blob_top_(new Blob<Dtype>()) {
+ // fill the values
+ FillerParameter filler_param;
+ GaussianFiller<Dtype> filler(filler_param);
+ filler.Fill(this->blob_bottom_);
+ blob_bottom_vec_.push_back(blob_bottom_);
+ blob_top_vec_.push_back(blob_top_);
+ };
+ virtual ~SoftmaxLayerTest() { delete blob_bottom_; delete blob_top_; }
+ Blob<Dtype>* const blob_bottom_;
+ Blob<Dtype>* const blob_top_;
+ vector<Blob<Dtype>*> blob_bottom_vec_;
+ vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+typedef ::testing::Types<float, double> Dtypes;
+TYPED_TEST_CASE(SoftmaxLayerTest, Dtypes);
+
+TYPED_TEST(SoftmaxLayerTest, TestForwardCPU) {
+ LayerParameter layer_param;
+ Caffe::set_mode(Caffe::CPU);
+ SoftmaxLayer<TypeParam> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+ // The softmax output of every sample should sum to one.
+ const TypeParam* top_data = this->blob_top_->cpu_data();
+ int dim = this->blob_top_->count() / this->blob_top_->num();
+ for (int i = 0; i < this->blob_top_->num(); ++i) {
+ TypeParam sum = 0;
+ for (int j = 0; j < dim; ++j) sum += top_data[i * dim + j];
+ EXPECT_GE(sum, 0.999);
+ EXPECT_LE(sum, 1.001);
+ }
+}
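+
+// A gradient check sketch: it assumes the GradientChecker utility from the
+// test_gradient_check_util.hpp header included above, with the
+// CheckGradientExhaustive(layer, bottom, top) call used by the other layer
+// tests; adjust the call if that signature differs.
+TYPED_TEST(SoftmaxLayerTest, TestGradientCPU) {
+ LayerParameter layer_param;
+ Caffe::set_mode(Caffe::CPU);
+ SoftmaxLayer<TypeParam> layer(layer_param);
+ GradientChecker<TypeParam> checker(1e-2, 1e-3);
+ checker.CheckGradientExhaustive(layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}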
+
+
+
+} // namespace caffe
Caffe::vsl_stream(), n, r, a, sigma));
}
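+
+// caffe_exp wraps the MKL vector math routines (vsExp/vdExp), and
+// caffe_cpu_dot wraps the CBLAS dot products (cblas_sdot/cblas_ddot).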
+template <>
+void caffe_exp<float>(const int n, const float* a, float* y) {
+ vsExp(n, a, y);
+}
+
+template <>
+void caffe_exp<double>(const int n, const double* a, double* y) {
+ vdExp(n, a, y);
+}
+
+template <>
+float caffe_cpu_dot<float>(const int n, const float* x, const float* y) {
+ return cblas_sdot(n, x, 1, y, 1);
+}
+
+template <>
+double caffe_cpu_dot<double>(const int n, const double* x, const double* y) {
+ return cblas_ddot(n, x, 1, y, 1);
+}
} // namespace caffe
void caffe_vRngGaussian(const int n, Dtype* r, const Dtype a,
const Dtype sigma);
+template <typename Dtype>
+void caffe_exp(const int n, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
+Dtype caffe_cpu_dot(const int n, const Dtype* x, const Dtype* y);
+
} // namespace caffe
int datum_size_;
};
+
+template <typename Dtype>
+class SoftmaxLayer : public Layer<Dtype> {
+ public:
+ explicit SoftmaxLayer(const LayerParameter& param)
+ : Layer<Dtype>(param) {}
+ virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+
+ protected:
+ virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+ vector<Blob<Dtype>*>* top);
+ //virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ // vector<Blob<Dtype>*>* top);
+ virtual Dtype Backward_cpu(const vector<Blob<Dtype>*>& top,
+ const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+ //virtual Dtype Backward_gpu(const vector<Blob<Dtype>*>& top,
+ // const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+ // sum_multiplier is just used to carry out sum using blas
+ Blob<Dtype> sum_multiplier_;
+ // scale is an intermediate blob to hold temporary results.
+ Blob<Dtype> scale_;
+};
+
+
} // namespace caffe
#endif // CAFFE_VISION_LAYERS_HPP_