From e9c08e5a8c39508ecff052ef36df63bf883d2dad Mon Sep 17 00:00:00 2001 From: Sergio Guadarrama Date: Thu, 11 Sep 2014 15:19:57 -0700 Subject: [PATCH] Fixed MemoryDataLayer to make it work with pycaffe --- include/caffe/data_layers.hpp | 3 +++ include/caffe/data_transformer.hpp | 13 +------------ python/caffe/_caffe.cpp | 4 ++-- src/caffe/data_transformer.cpp | 14 +++++++++++++- src/caffe/layers/data_layer.cpp | 1 - src/caffe/layers/window_data_layer.cpp | 2 +- 6 files changed, 20 insertions(+), 17 deletions(-) diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index bf06865..e598a71 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -283,6 +283,9 @@ class MemoryDataLayer : public BaseDataLayer { void Reset(Dtype* data, Dtype* label, int n); int batch_size() { return batch_size_; } + int channels() { return channels_; } + int height() { return height_; } + int width() { return width_; } protected: virtual void Forward_cpu(const vector*>& bottom, diff --git a/include/caffe/data_transformer.hpp b/include/caffe/data_transformer.hpp index 33afa72..fcd1011 100644 --- a/include/caffe/data_transformer.hpp +++ b/include/caffe/data_transformer.hpp @@ -14,18 +14,7 @@ namespace caffe { template class DataTransformer { public: - explicit DataTransformer(const TransformationParameter& param) - : param_(param) { - phase_ = Caffe::phase(); - // check if we want to have mean - if (param_.has_mean_file()) { - const string& mean_file = param.mean_file(); - LOG(INFO) << "Loading mean file from" << mean_file; - BlobProto blob_proto; - ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); - data_mean_.FromProto(blob_proto); - } - } + explicit DataTransformer(const TransformationParameter& param); virtual ~DataTransformer() {} void InitRand(); diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index 5a8d99d..33e68fe 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -106,8 +106,8 @@ void PyNet::set_input_arrays(bp::object data_obj, bp::object labels_obj) { reinterpret_cast(data_obj.ptr()); PyArrayObject* labels_arr = reinterpret_cast(labels_obj.ptr()); - check_contiguous_array(data_arr, "data array", md_layer->datum_channels(), - md_layer->datum_height(), md_layer->datum_width()); + check_contiguous_array(data_arr, "data array", md_layer->channels(), + md_layer->height(), md_layer->width()); check_contiguous_array(labels_arr, "labels array", 1, 1, 1); if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) { throw std::runtime_error("data and labels must have the same first" diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp index 81cb17b..553717a 100644 --- a/src/caffe/data_transformer.cpp +++ b/src/caffe/data_transformer.cpp @@ -6,12 +6,24 @@ namespace caffe { +template +DataTransformer::DataTransformer(const TransformationParameter& param) + : param_(param) { + phase_ = Caffe::phase(); + // check if we want to have mean + if (param_.has_mean_file()) { + const string& mean_file = param.mean_file(); + LOG(INFO) << "Loading mean file from" << mean_file; + BlobProto blob_proto; + ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); + data_mean_.FromProto(blob_proto); + } +} template void DataTransformer::Transform(const int batch_item_id, const Datum& datum, Dtype* transformed_data) { - CHECK_GT(datum.channels(), 0); CHECK_GE(datum.height(), param_.crop_size()); CHECK_GE(datum.height(), param_.crop_size()); diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 3867a87..40c4873 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -134,7 +134,6 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, this->prefetch_label_.Reshape(this->layer_param_.data_param().batch_size(), 1, 1, 1); } - } // This function is used to create a thread that prefetches the data. diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp index 88afe82..6b70aa3 100644 --- a/src/caffe/layers/window_data_layer.cpp +++ b/src/caffe/layers/window_data_layer.cpp @@ -170,7 +170,7 @@ void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, // data mean if (this->layer_param_.window_data_param().has_mean_file()) { - const string& mean_file = + const string& mean_file = this->layer_param_.window_data_param().mean_file(); LOG(INFO) << "Loading mean file from" << mean_file; BlobProto blob_proto; -- 2.7.4