public:
Blob()
: data_(), diff_(), num_(0), channels_(0), height_(0), width_(0),
- count_(0) {}
+ count_(0), capacity_(0) {}
explicit Blob(const int num, const int channels, const int height,
const int width);
void Reshape(const int num, const int channels, const int height,
int height_;
int width_;
int count_;
+ int capacity_;
DISABLE_COPY_AND_ASSIGN(Blob);
}; // class Blob
height_ = height;
width_ = width;
count_ = num_ * channels_ * height_ * width_;
- if (count_) {
- data_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
- diff_.reset(new SyncedMemory(count_ * sizeof(Dtype)));
- } else {
- data_.reset(reinterpret_cast<SyncedMemory*>(NULL));
- diff_.reset(reinterpret_cast<SyncedMemory*>(NULL));
+ if (count_ > capacity_) {
+ capacity_ = count_;
+ data_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
+ diff_.reset(new SyncedMemory(capacity_ * sizeof(Dtype)));
}
}
template <typename Dtype>
Blob<Dtype>::Blob(const int num, const int channels, const int height,
- const int width) {
+ const int width)
+ // capacity_ must be initialized before calling Reshape
+ : capacity_(0) {
Reshape(num, channels, height, width);
}