// Reset should accept const pointers, but can't, because the memory
// will be given to Blob, which is mutable
void Reset(Dtype* data, Dtype* label, int n);
- void ChangeBatchSize(int new_size);
+ void set_batch_size(int new_size);
int batch_size() { return batch_size_; }
int channels() { return channels_; }
Blob<Dtype> added_data_;
Blob<Dtype> added_label_;
bool has_new_data_;
- bool needs_reshape_;
};
/**
added_label_.Reshape(batch_size_, 1, 1, 1);
data_ = NULL;
labels_ = NULL;
- needs_reshape_ = false;
added_data_.cpu_data();
added_label_.cpu_data();
}
}
template <typename Dtype>
-void MemoryDataLayer<Dtype>::ChangeBatchSize(int new_size) {
+void MemoryDataLayer<Dtype>::set_batch_size(int new_size) {
CHECK(!has_new_data_) <<
"Can't change batch_size before all data haven't been consumed"
<< " by the upper layers";
batch_size_ = new_size;
added_data_.Reshape(batch_size_, channels_, height_, width_);
added_label_.Reshape(batch_size_, 1, 1, 1);
- needs_reshape_ = true;
}
template <typename Dtype>
void MemoryDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
CHECK(data_) << "MemoryDataLayer needs to be initalized by calling Reset";
- if (needs_reshape_) {
- top[0]->Reshape(batch_size_, channels_, height_, width_);
- top[1]->Reshape(batch_size_, 1, 1, 1);
- needs_reshape_ = false;
- }
+ top[0]->Reshape(batch_size_, channels_, height_, width_);
+ top[1]->Reshape(batch_size_, 1, 1, 1);
top[0]->set_cpu_data(data_ + pos_ * size_);
top[1]->set_cpu_data(labels_ + pos_);
pos_ = (pos_ + batch_size_) % n_;
}
}
-TYPED_TEST(MemoryDataLayerTest, TestChangeBatchSize) {
+TYPED_TEST(MemoryDataLayerTest, TestSetBatchSize) {
typedef typename TypeParam::Dtype Dtype;
LayerParameter param;
MemoryDataParameter* memory_data_param = param.mutable_memory_data_param();
}
// and then add new data with different batch_size
int new_batch_size = 16;
- layer.ChangeBatchSize(new_batch_size);
+ layer.set_batch_size(new_batch_size);
mat_vector.clear();
mat_vector.resize(new_batch_size * num_iter);
label_vector.clear();