3dc00345743b0e63ad48c16d917ab2c0fd55e244
[platform/upstream/caffeonacl.git] / src / caffe / test / test_memory_data_layer.cpp
1 #include <string>
2 #include <vector>
3
4 #include "caffe/data_layers.hpp"
5 #include "caffe/filler.hpp"
6
7 #include "caffe/test/test_caffe_main.hpp"
8
9 namespace caffe {
10
11 template <typename TypeParam>
12 class MemoryDataLayerTest : public MultiDeviceTest<TypeParam> {
13   typedef typename TypeParam::Dtype Dtype;
14
15  protected:
16   MemoryDataLayerTest()
17     : data_(new Blob<Dtype>()),
18       labels_(new Blob<Dtype>()),
19       data_blob_(new Blob<Dtype>()),
20       label_blob_(new Blob<Dtype>()) {}
21   virtual void SetUp() {
22     batch_size_ = 8;
23     batches_ = 12;
24     channels_ = 4;
25     height_ = 7;
26     width_ = 11;
27     blob_top_vec_.push_back(data_blob_);
28     blob_top_vec_.push_back(label_blob_);
29     // pick random input data
30     FillerParameter filler_param;
31     GaussianFiller<Dtype> filler(filler_param);
32     data_->Reshape(batches_ * batch_size_, channels_, height_, width_);
33     labels_->Reshape(batches_ * batch_size_, 1, 1, 1);
34     filler.Fill(this->data_);
35     filler.Fill(this->labels_);
36   }
37
38   virtual ~MemoryDataLayerTest() {
39     delete data_blob_;
40     delete label_blob_;
41     delete data_;
42     delete labels_;
43   }
44   int batch_size_;
45   int batches_;
46   int channels_;
47   int height_;
48   int width_;
49   // we don't really need blobs for the input data, but it makes it
50   //  easier to call Filler
51   Blob<Dtype>* const data_;
52   Blob<Dtype>* const labels_;
53   // blobs for the top of MemoryDataLayer
54   Blob<Dtype>* const data_blob_;
55   Blob<Dtype>* const label_blob_;
56   vector<Blob<Dtype>*> blob_bottom_vec_;
57   vector<Blob<Dtype>*> blob_top_vec_;
58 };
59
60 TYPED_TEST_CASE(MemoryDataLayerTest, TestDtypesAndDevices);
61
62 TYPED_TEST(MemoryDataLayerTest, TestSetup) {
63   typedef typename TypeParam::Dtype Dtype;
64
65   LayerParameter layer_param;
66   MemoryDataParameter* md_param = layer_param.mutable_memory_data_param();
67   md_param->set_batch_size(this->batch_size_);
68   md_param->set_channels(this->channels_);
69   md_param->set_height(this->height_);
70   md_param->set_width(this->width_);
71   shared_ptr<Layer<Dtype> > layer(
72       new MemoryDataLayer<Dtype>(layer_param));
73   layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
74   EXPECT_EQ(this->data_blob_->num(), this->batch_size_);
75   EXPECT_EQ(this->data_blob_->channels(), this->channels_);
76   EXPECT_EQ(this->data_blob_->height(), this->height_);
77   EXPECT_EQ(this->data_blob_->width(), this->width_);
78   EXPECT_EQ(this->label_blob_->num(), this->batch_size_);
79   EXPECT_EQ(this->label_blob_->channels(), 1);
80   EXPECT_EQ(this->label_blob_->height(), 1);
81   EXPECT_EQ(this->label_blob_->width(), 1);
82 }
83
84 // run through a few batches and check that the right data appears
85 TYPED_TEST(MemoryDataLayerTest, TestForward) {
86   typedef typename TypeParam::Dtype Dtype;
87
88   LayerParameter layer_param;
89   MemoryDataParameter* md_param = layer_param.mutable_memory_data_param();
90   md_param->set_batch_size(this->batch_size_);
91   md_param->set_channels(this->channels_);
92   md_param->set_height(this->height_);
93   md_param->set_width(this->width_);
94   shared_ptr<MemoryDataLayer<Dtype> > layer(
95       new MemoryDataLayer<Dtype>(layer_param));
96   layer->DataLayerSetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
97   layer->Reset(this->data_->mutable_cpu_data(),
98       this->labels_->mutable_cpu_data(), this->data_->num());
99   for (int i = 0; i < this->batches_ * 6; ++i) {
100     int batch_num = i % this->batches_;
101     layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
102     for (int j = 0; j < this->data_blob_->count(); ++j) {
103       EXPECT_EQ(this->data_blob_->cpu_data()[j],
104           this->data_->cpu_data()[
105               this->data_->offset(1) * this->batch_size_ * batch_num + j]);
106     }
107     for (int j = 0; j < this->label_blob_->count(); ++j) {
108       EXPECT_EQ(this->label_blob_->cpu_data()[j],
109           this->labels_->cpu_data()[this->batch_size_ * batch_num + j]);
110     }
111   }
112 }
113
114 TYPED_TEST(MemoryDataLayerTest, AddDatumVectorDefaultTransform) {
115   typedef typename TypeParam::Dtype Dtype;
116
117   LayerParameter param;
118   MemoryDataParameter* memory_data_param = param.mutable_memory_data_param();
119   memory_data_param->set_batch_size(this->batch_size_);
120   memory_data_param->set_channels(this->channels_);
121   memory_data_param->set_height(this->height_);
122   memory_data_param->set_width(this->width_);
123   MemoryDataLayer<Dtype> layer(param);
124   layer.SetUp(this->blob_bottom_vec_, &this->blob_top_vec_);
125
126   vector<Datum> datum_vector(this->batch_size_);
127   const size_t count = this->channels_ * this->height_ * this->width_;
128   size_t pixel_index = 0;
129   for (int i = 0; i < this->batch_size_; ++i) {
130     LOG(ERROR) << "i " << i;
131     datum_vector[i].set_channels(this->channels_);
132     datum_vector[i].set_height(this->height_);
133     datum_vector[i].set_width(this->width_);
134     datum_vector[i].set_label(i);
135     vector<char> pixels(count);
136     for (int j = 0; j < count; ++j) {
137       pixels[j] = pixel_index++ % 256;
138     }
139     datum_vector[i].set_data(&(pixels[0]), count);
140   }
141
142   layer.AddDatumVector(datum_vector);
143
144   int data_index;
145   // Go through the data 5 times
146   for (int iter = 0; iter < 5; ++iter) {
147     layer.Forward(this->blob_bottom_vec_, &this->blob_top_vec_);
148     const Dtype* data = this->data_blob_->cpu_data();
149     size_t index = 0;
150     for (int i = 0; i < this->batch_size_; ++i) {
151       const string& data_string = datum_vector[i].data();
152       EXPECT_EQ(i, this->label_blob_->cpu_data()[i]);
153       for (int c = 0; c < this->channels_; ++c) {
154         for (int h = 0; h < this->height_; ++h) {
155           for (int w = 0; w < this->width_; ++w) {
156             data_index = (c * this->height_ + h) * this->width_ + w;
157             EXPECT_EQ(static_cast<Dtype>(
158                 static_cast<uint8_t>(data_string[data_index])),
159                       data[index++]);
160           }
161         }
162       }
163     }
164   }
165 }
166
167 }  // namespace caffe